Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
Previously, it wasn't loaded until NextRow was called the first time.
This commit is contained in:
@@ -84,6 +84,8 @@ type PgConn struct {
|
||||
bufferingReceiveMsg pgproto3.BackendMessage
|
||||
bufferingReceiveErr error
|
||||
|
||||
peekedMsg pgproto3.BackendMessage
|
||||
|
||||
// Reusable / preallocated resources
|
||||
wbuf []byte // write buffer
|
||||
resultReader ResultReader
|
||||
@@ -427,8 +429,12 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// receiveMessage receives a message without setting up context cancellation
|
||||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
// peekMessage peeks at the next message without setting up context cancellation.
|
||||
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
||||
if pgConn.peekedMsg != nil {
|
||||
return pgConn.peekedMsg, nil
|
||||
}
|
||||
|
||||
var msg pgproto3.BackendMessage
|
||||
var err error
|
||||
if pgConn.bufferingReceive {
|
||||
@@ -455,6 +461,23 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pgConn.peekedMsg = msg
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// receiveMessage receives a message without setting up context cancellation
|
||||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
msg, err := pgConn.peekMessage()
|
||||
if err != nil {
|
||||
// Close on anything other than timeout error - everything else is fatal
|
||||
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
|
||||
pgConn.asyncClose()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
pgConn.peekedMsg = nil
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
pgConn.txStatus = msg.TxStatus
|
||||
@@ -1044,7 +1067,10 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return
|
||||
}
|
||||
|
||||
result.readUntilRowDescription()
|
||||
}
|
||||
|
||||
// CopyTo executes the copy command sql and copies the results to w.
|
||||
@@ -1454,6 +1480,26 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
||||
return rr.commandTag, rr.err
|
||||
}
|
||||
|
||||
// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any
|
||||
// error will be stored in the ResultReader.
|
||||
func (rr *ResultReader) readUntilRowDescription() {
|
||||
for !rr.commandConcluded {
|
||||
// Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method.
|
||||
// This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are
|
||||
// manually used to construct a query that does not issue a describe statement.
|
||||
msg, _ := rr.pgConn.peekMessage()
|
||||
if _, ok := msg.(*pgproto3.DataRow); ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Consume the message
|
||||
msg, _ = rr.receiveMessage()
|
||||
if _, ok := msg.(*pgproto3.RowDescription); ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
||||
if rr.multiResultReader == nil {
|
||||
msg, err = rr.pgConn.receiveMessage()
|
||||
|
||||
+36
-2
@@ -481,6 +481,34 @@ func TestConnExecMultipleQueries(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")
|
||||
|
||||
require.True(t, mrr.NextResult())
|
||||
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
||||
assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name)
|
||||
_, err = mrr.ResultReader().Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, mrr.NextResult())
|
||||
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
||||
assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name)
|
||||
_, err = mrr.ResultReader().Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.False(t, mrr.NextResult())
|
||||
|
||||
require.NoError(t, mrr.Close())
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecMultipleQueriesError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -578,7 +606,10 @@ func TestConnExecParams(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
||||
result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
||||
require.Len(t, result.FieldDescriptions(), 1)
|
||||
assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
|
||||
|
||||
rowCount := 0
|
||||
for result.NextRow() {
|
||||
rowCount += 1
|
||||
@@ -734,13 +765,16 @@ func TestConnExecPrepared(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
assert.Len(t, psd.ParamOIDs, 1)
|
||||
assert.Len(t, psd.Fields, 1)
|
||||
|
||||
result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
||||
require.Len(t, result.FieldDescriptions(), 1)
|
||||
assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
|
||||
|
||||
rowCount := 0
|
||||
for result.NextRow() {
|
||||
rowCount += 1
|
||||
|
||||
Reference in New Issue
Block a user