From 0330052b0a5d985f572b05cf38a95495febad6c8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:10:16 -0600 Subject: [PATCH] Use result readers in next/get fashion --- pgconn/pgconn.go | 27 ++++++++++++++++----------- pgconn/pgconn_stress_test.go | 29 ++++++++++++++--------------- pgconn/pgconn_test.go | 26 +++++++++++++------------- 3 files changed, 43 insertions(+), 39 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index d9755f6c..8511d5b9 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -538,10 +538,9 @@ type PgResultReader struct { cleanupContext func() } -// GetResult returns a PgResultReader for the next result. If all results are consumed it returns nil. If an error -// occurs it will be reported on the returned PgResultReader. Returned PgResultReader is only valid until next call of -// GetResult. -func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { +// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available. +// Use ResultReader() to acquire a reader for the result. +func (pgConn *PgConn) NextResult(ctx context.Context) bool { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { @@ -549,29 +548,34 @@ func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { if err != nil { cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} - return &pgConn.resultReader + return true } switch msg := msg.(type) { case *pgproto3.RowDescription: pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} - return &pgConn.resultReader + return true case *pgproto3.DataRow: pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} - return &pgConn.resultReader + return true case *pgproto3.CommandComplete: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} - return &pgConn.resultReader + return true case *pgproto3.ErrorResponse: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} - return &pgConn.resultReader + return true } } cleanupContext() - return nil + return false +} + +// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed. +func (pgConn *PgConn) ResultReader() *PgResultReader { + return &pgConn.resultReader } // NextRow returns advances the PgResultReader to the next row and returns true if a row is available. @@ -806,7 +810,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { var result *PgResult - for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { + for pgConn.NextResult(ctx) { + resultReader := pgConn.ResultReader() rows := [][][]byte{} for resultReader.NextRow() { row := make([][]byte, len(resultReader.Values())) diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go index cc6acab8..9aa94539 100644 --- a/pgconn/pgconn_stress_test.go +++ b/pgconn/pgconn_stress_test.go @@ -84,10 +84,10 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // Query 1 - resultReader := pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader := pgConn.ResultReader() for resultReader.NextRow() { } @@ -97,10 +97,10 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // Query 2 - resultReader = pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader = pgConn.ResultReader() for resultReader.NextRow() { } @@ -110,8 +110,7 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // No more - resultReader = pgConn.GetResult(context.Background()) - if resultReader != nil { + if pgConn.NextResult(context.Background()) { return errors.New("unexpected result reader") } @@ -162,10 +161,10 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error { } // Query 1 - resultReader := pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader := pgConn.ResultReader() for resultReader.NextRow() { } @@ -176,11 +175,11 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error { // Query 2 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - resultReader = pgConn.GetResult(ctx) - cancel() - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(ctx) { + return errors.New("missing result") } + cancel() + resultReader = pgConn.ResultReader() for resultReader.NextRow() { } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 35f5b536..8b578d42 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -373,8 +373,8 @@ func TestConnBatchedQueries(t *testing.T) { err = pgConn.Flush(context.Background()) // "select 'SendExec 1'" - resultReader := pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader := pgConn.ResultReader() rows := [][][]byte{} for resultReader.NextRow() { @@ -391,8 +391,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecParams 1" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -409,8 +409,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecPrepared 1" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -427,8 +427,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExec 2" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -445,8 +445,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecParams 2" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -463,8 +463,7 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // Done - resultReader = pgConn.GetResult(context.Background()) - assert.Nil(t, resultReader) + require.False(t, pgConn.NextResult(context.Background())) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -505,7 +504,8 @@ func TestConnCancelQuery(t *testing.T) { err = pgConn.CancelRequest(context.Background()) require.Nil(t, err) - _, err = pgConn.GetResult(context.Background()).Close() + require.True(t, pgConn.NextResult(context.Background())) + _, err = pgConn.ResultReader().Close() if err, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "57014", err.Code) } else {