diff --git a/pgconn.go b/pgconn.go index e2ab5c13..ff812069 100644 --- a/pgconn.go +++ b/pgconn.go @@ -84,6 +84,8 @@ type PgConn struct { bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage + // Reusable / preallocated resources wbuf []byte // write buffer resultReader ResultReader @@ -427,8 +429,12 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa return msg, err } -// receiveMessage receives a message without setting up context cancellation -func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -455,6 +461,23 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return nil, err } + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + // Close on anything other than timeout error - everything else is fatal + if err, ok := err.(net.Error); !(ok && err.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + pgConn.peekedMsg = nil + switch msg := msg.(type) { case *pgproto3.ReadyForQuery: pgConn.txStatus = msg.TxStatus @@ -1044,7 +1067,10 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() + return } + + result.readUntilRowDescription() } // CopyTo executes the copy command sql and copies the results to w. @@ -1454,6 +1480,26 @@ func (rr *ResultReader) Close() (CommandTag, error) { return rr.commandTag, rr.err } +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { msg, err = rr.pgConn.receiveMessage() diff --git a/pgconn_test.go b/pgconn_test.go index f6750a60..24200e73 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -481,6 +481,34 @@ func TestConnExecMultipleQueries(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num") + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.False(t, mrr.NextResult()) + + require.NoError(t, mrr.Close()) + + ensureConnValid(t, pgConn) +} + func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() @@ -578,7 +606,10 @@ func TestConnExecParams(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + rowCount := 0 for result.NextRow() { rowCount += 1 @@ -734,13 +765,16 @@ func TestConnExecPrepared(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil) require.NoError(t, err) require.NotNil(t, psd) assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + rowCount := 0 for result.NextRow() { rowCount += 1