Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
Previously, it wasn't loaded until NextRow was called the first time.
This commit is contained in:
@@ -84,6 +84,8 @@ type PgConn struct {
|
|||||||
bufferingReceiveMsg pgproto3.BackendMessage
|
bufferingReceiveMsg pgproto3.BackendMessage
|
||||||
bufferingReceiveErr error
|
bufferingReceiveErr error
|
||||||
|
|
||||||
|
peekedMsg pgproto3.BackendMessage
|
||||||
|
|
||||||
// Reusable / preallocated resources
|
// Reusable / preallocated resources
|
||||||
wbuf []byte // write buffer
|
wbuf []byte // write buffer
|
||||||
resultReader ResultReader
|
resultReader ResultReader
|
||||||
@@ -427,8 +429,12 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
|
|||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// receiveMessage receives a message without setting up context cancellation
|
// peekMessage peeks at the next message without setting up context cancellation.
|
||||||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
||||||
|
if pgConn.peekedMsg != nil {
|
||||||
|
return pgConn.peekedMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
var msg pgproto3.BackendMessage
|
var msg pgproto3.BackendMessage
|
||||||
var err error
|
var err error
|
||||||
if pgConn.bufferingReceive {
|
if pgConn.bufferingReceive {
|
||||||
@@ -455,6 +461,23 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pgConn.peekedMsg = msg
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveMessage receives a message without setting up context cancellation
|
||||||
|
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
|
msg, err := pgConn.peekMessage()
|
||||||
|
if err != nil {
|
||||||
|
// Close on anything other than timeout error - everything else is fatal
|
||||||
|
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
|
||||||
|
pgConn.asyncClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pgConn.peekedMsg = nil
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
pgConn.txStatus = msg.TxStatus
|
pgConn.txStatus = msg.TxStatus
|
||||||
@@ -1044,7 +1067,10 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
|
|||||||
pgConn.contextWatcher.Unwatch()
|
pgConn.contextWatcher.Unwatch()
|
||||||
result.closed = true
|
result.closed = true
|
||||||
pgConn.unlock()
|
pgConn.unlock()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result.readUntilRowDescription()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CopyTo executes the copy command sql and copies the results to w.
|
// CopyTo executes the copy command sql and copies the results to w.
|
||||||
@@ -1454,6 +1480,26 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
|||||||
return rr.commandTag, rr.err
|
return rr.commandTag, rr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any
|
||||||
|
// error will be stored in the ResultReader.
|
||||||
|
func (rr *ResultReader) readUntilRowDescription() {
|
||||||
|
for !rr.commandConcluded {
|
||||||
|
// Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method.
|
||||||
|
// This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are
|
||||||
|
// manually used to construct a query that does not issue a describe statement.
|
||||||
|
msg, _ := rr.pgConn.peekMessage()
|
||||||
|
if _, ok := msg.(*pgproto3.DataRow); ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume the message
|
||||||
|
msg, _ = rr.receiveMessage()
|
||||||
|
if _, ok := msg.(*pgproto3.RowDescription); ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
||||||
if rr.multiResultReader == nil {
|
if rr.multiResultReader == nil {
|
||||||
msg, err = rr.pgConn.receiveMessage()
|
msg, err = rr.pgConn.receiveMessage()
|
||||||
|
|||||||
+36
-2
@@ -481,6 +481,34 @@ func TestConnExecMultipleQueries(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")
|
||||||
|
|
||||||
|
require.True(t, mrr.NextResult())
|
||||||
|
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
||||||
|
assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name)
|
||||||
|
_, err = mrr.ResultReader().Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.True(t, mrr.NextResult())
|
||||||
|
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
||||||
|
assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name)
|
||||||
|
_, err = mrr.ResultReader().Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.False(t, mrr.NextResult())
|
||||||
|
|
||||||
|
require.NoError(t, mrr.Close())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecMultipleQueriesError(t *testing.T) {
|
func TestConnExecMultipleQueriesError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -578,7 +606,10 @@ func TestConnExecParams(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
||||||
|
require.Len(t, result.FieldDescriptions(), 1)
|
||||||
|
assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
|
||||||
|
|
||||||
rowCount := 0
|
rowCount := 0
|
||||||
for result.NextRow() {
|
for result.NextRow() {
|
||||||
rowCount += 1
|
rowCount += 1
|
||||||
@@ -734,13 +765,16 @@ func TestConnExecPrepared(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, psd)
|
require.NotNil(t, psd)
|
||||||
assert.Len(t, psd.ParamOIDs, 1)
|
assert.Len(t, psd.ParamOIDs, 1)
|
||||||
assert.Len(t, psd.Fields, 1)
|
assert.Len(t, psd.Fields, 1)
|
||||||
|
|
||||||
result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
||||||
|
require.Len(t, result.FieldDescriptions(), 1)
|
||||||
|
assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
|
||||||
|
|
||||||
rowCount := 0
|
rowCount := 0
|
||||||
for result.NextRow() {
|
for result.NextRow() {
|
||||||
rowCount += 1
|
rowCount += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user