2
0

Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded

Previously, it wasn't loaded until NextRow was called the first time.
This commit is contained in:
Jack Christensen
2020-09-05 13:14:11 -05:00
parent 5db484908c
commit 0d4f029683
2 changed files with 84 additions and 4 deletions
+48 -2
View File
@@ -84,6 +84,8 @@ type PgConn struct {
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources
wbuf []byte // write buffer
resultReader ResultReader
@@ -427,8 +429,12 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
return msg, err
}
// receiveMessage receives a message without setting up context cancellation
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
// peekMessage peeks at the next message without setting up context cancellation.
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
if pgConn.peekedMsg != nil {
return pgConn.peekedMsg, nil
}
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
@@ -455,6 +461,23 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
return nil, err
}
pgConn.peekedMsg = msg
return msg, nil
}
// receiveMessage receives a message without setting up context cancellation
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := pgConn.peekMessage()
if err != nil {
// Close on anything other than timeout error - everything else is fatal
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
pgConn.asyncClose()
}
return nil, err
}
pgConn.peekedMsg = nil
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
pgConn.txStatus = msg.TxStatus
@@ -1044,7 +1067,10 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return
}
result.readUntilRowDescription()
}
// CopyTo executes the copy command sql and copies the results to w.
@@ -1454,6 +1480,26 @@ func (rr *ResultReader) Close() (CommandTag, error) {
return rr.commandTag, rr.err
}
// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any
// error will be stored in the ResultReader.
func (rr *ResultReader) readUntilRowDescription() {
for !rr.commandConcluded {
// Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method.
// This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are
// manually used to construct a query that does not issue a describe statement.
msg, _ := rr.pgConn.peekMessage()
if _, ok := msg.(*pgproto3.DataRow); ok {
return
}
// Consume the message
msg, _ = rr.receiveMessage()
if _, ok := msg.(*pgproto3.RowDescription); ok {
return
}
}
}
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
if rr.multiResultReader == nil {
msg, err = rr.pgConn.receiveMessage()
+36 -2
View File
@@ -481,6 +481,34 @@ func TestConnExecMultipleQueries(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")
require.True(t, mrr.NextResult())
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name)
_, err = mrr.ResultReader().Close()
require.NoError(t, err)
require.True(t, mrr.NextResult())
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name)
_, err = mrr.ResultReader().Close()
require.NoError(t, err)
require.False(t, mrr.NextResult())
require.NoError(t, mrr.Close())
ensureConnValid(t, pgConn)
}
func TestConnExecMultipleQueriesError(t *testing.T) {
t.Parallel()
@@ -578,7 +606,10 @@ func TestConnExecParams(t *testing.T) {
require.NoError(t, err)
defer closeConn(t, pgConn)
result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
require.Len(t, result.FieldDescriptions(), 1)
assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
rowCount := 0
for result.NextRow() {
rowCount += 1
@@ -734,13 +765,16 @@ func TestConnExecPrepared(t *testing.T) {
require.NoError(t, err)
defer closeConn(t, pgConn)
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
require.NoError(t, err)
require.NotNil(t, psd)
assert.Len(t, psd.ParamOIDs, 1)
assert.Len(t, psd.Fields, 1)
result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
require.Len(t, result.FieldDescriptions(), 1)
assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
rowCount := 0
for result.NextRow() {
rowCount += 1