From 31cb2b4e72d2b3171e1b661d642e4a314b6d3803 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:37:28 -0600 Subject: [PATCH] Big restructure to better handle context cancel --- pgconn/benchmark_test.go | 33 +- pgconn/config.go | 6 +- pgconn/helper_test.go | 4 +- pgconn/pgconn.go | 995 +++++++++++++++++++---------------- pgconn/pgconn_stress_test.go | 116 +--- pgconn/pgconn_test.go | 289 ++++------ 6 files changed, 686 insertions(+), 757 deletions(-) diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index fc4b6057..ffb1455c 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -44,7 +44,7 @@ func BenchmarkExec(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() require.Nil(b, err) } } @@ -60,7 +60,7 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { defer cancel() for i := 0; i < b.N; i++ { - _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() require.Nil(b, err) } } @@ -71,12 +71,13 @@ func BenchmarkExecPrepared(b *testing.B) { defer closeConn(b, conn) _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) - require.Nil(b, err) + result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).ReadAll() + require.Nil(b, result.Err) } } @@ -89,32 +90,12 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { defer cancel() _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) - require.Nil(b, err) - } -} - -func BenchmarkSendExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() for i := 0; i < b.N; i++ { - conn.SendExecPrepared("ps1", nil, nil, nil) - err := conn.Flush(context.Background()) - require.Nil(b, err) - - for conn.NextResult(context.Background()) { - _, err := conn.ResultReader().Close() - require.Nil(b, err) - } + result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).ReadAll() + require.Nil(b, result.Err) } } diff --git a/pgconn/config.go b/pgconn/config.go index bd1fec9b..fb0719cd 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -470,9 +470,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result, err := pgConn.Exec(ctx, "show transaction_read_only") - if err != nil { - return err + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).ReadAll() + if result.Err != nil { + return result.Err } if string(result.Rows[0][0]) == "on" { diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go index 1053310b..a50f7cb1 100644 --- a/pgconn/helper_test.go +++ b/pgconn/helper_test.go @@ -20,10 +20,10 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil) + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).ReadAll() cancel() - require.Nil(t, err) + require.Nil(t, result.Err) assert.Equal(t, 3, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "2", string(result.Rows[1][0])) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ee8127bf..cfacc7bb 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -17,8 +17,6 @@ import ( "github.com/jackc/pgx/pgproto3" ) -const batchBufferSize = 4096 - var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See @@ -76,14 +74,9 @@ type PgConn struct { Config *Config - batchBuf []byte - batchCount int32 - - pendingReadyForQueryCount int32 + controller chan interface{} closed bool - - resultReader PgResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -140,6 +133,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config + pgConn.controller = make(chan interface{}, 1) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -268,23 +262,22 @@ func hexMD5(s string) string { func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.Frontend.Receive() if err != nil { + // Close on anything other than timeout error - everything else is fatal + if err, ok := err.(net.Error); !ok && err.Timeout() { + pgConn.hardClose() + } + return nil, err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - // Under normal circumstances pendingReadyForQueryCount will be > 0 when a - // ReadyForQuery is received. However, this is not the case on initial - // connection. - if pgConn.pendingReadyForQueryCount > 0 { - pgConn.pendingReadyForQueryCount -= 1 - } pgConn.TxStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - // TODO - close pgConn + pgConn.hardClose() return nil, errorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -338,6 +331,15 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } +// hardClose closes the underlying connection without sending the exit message. +func (pgConn *PgConn) hardClose() error { + if pgConn.closed { + return nil + } + pgConn.closed = true + return pgConn.conn.Close() +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -363,229 +365,6 @@ func (ct CommandTag) String() string { return string(ct) } -// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. -// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains -// transaction control statements. It is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExec(sql string) { - pgConn.batchBuf = (&pgproto3.Query{String: sql}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 -} - -// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. -// -// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for -// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. -// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// Query is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { - panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) - } - - pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) - pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) -} - -// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// Query is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: stmtName}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Execute{}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 -} - -type PgResultReader struct { - pgConn *PgConn - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - err error - complete bool - preloadedRowValues bool - ctx context.Context - cleanupContext func() -} - -// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available. -// Use ResultReader() to acquire a reader for the result. -func (pgConn *PgConn) NextResult(ctx context.Context) bool { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - - for pgConn.pendingReadyForQueryCount > 0 { - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} - return true - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} - return true - case *pgproto3.DataRow: - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} - return true - case *pgproto3.CommandComplete: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} - return true - case *pgproto3.EmptyQueryResponse: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, complete: true} - return true - case *pgproto3.ErrorResponse: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} - return true - } - } - - cleanupContext() - return false -} - -// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed. -func (pgConn *PgConn) ResultReader() *PgResultReader { - return &pgConn.resultReader -} - -// NextRow returns advances the PgResultReader to the next row and returns true if a row is available. -func (rr *PgResultReader) NextRow() bool { - if rr.complete { - return false - } - - if rr.preloadedRowValues { - rr.preloadedRowValues = false - return true - } - - for { - msg, err := rr.pgConn.ReceiveMessage() - if err != nil { - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - rr.close() - return false - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - rr.fieldDescriptions = msg.Fields - case *pgproto3.DataRow: - rr.rowValues = msg.Values - return true - case *pgproto3.CommandComplete: - rr.commandTag = CommandTag(msg.CommandTag) - rr.close() - return false - case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) - rr.close() - return false - } - } -} - -// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the PgResultReader is closed. -func (rr *PgResultReader) FieldDescriptions() []pgproto3.FieldDescription { - return rr.fieldDescriptions -} - -// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to -// retain a reference to and mutate. -func (rr *PgResultReader) Values() [][]byte { - return rr.rowValues -} - -// Close consumes any remaining result data and returns the command tag or -// error. -func (rr *PgResultReader) Close() (CommandTag, error) { - if rr.complete { - return rr.commandTag, rr.err - } - defer rr.close() - - for { - msg, err := rr.pgConn.ReceiveMessage() - if err != nil { - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - return rr.commandTag, rr.err - } - - switch msg := msg.(type) { - case *pgproto3.CommandComplete: - rr.commandTag = CommandTag(msg.CommandTag) - return rr.commandTag, rr.err - case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) - return rr.commandTag, rr.err - } - } -} - -func (rr *PgResultReader) close() { - if rr.complete { - return - } - - rr.cleanupContext() - rr.rowValues = nil - rr.complete = true -} - -// Flush sends the enqueued execs to the server. -func (pgConn *PgConn) Flush(ctx context.Context) error { - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - err := pgConn.flush() - cleanup() - return preferContextOverNetTimeoutError(ctx, err) -} - -// flush sends the enqueued execs to the server without handling a context. -func (pgConn *PgConn) flush() error { - n, err := pgConn.conn.Write(pgConn.batchBuf) - if err != nil && n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - - if err == nil { - pgConn.pendingReadyForQueryCount += pgConn.batchCount - } - - pgConn.resetBatch() - - return err -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -595,63 +374,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. This must be -// called after any context cancellation. This is not done automatically as RecoverFromTimeout may need to signal the -// server to abort the in-progress query and read and ignore data already sent from the server. This potentially can -// block indefinitely. Use ctx to guard against this. If recovery is successful true is returned. If recovery is not -// successful the connection is closed and false is returned. Recovery should usually be possible except in the case of -// a partial write. -func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { - if pgConn.closed { - return false - } - pgConn.resetBatch() - - // Clear any existing timeout - pgConn.conn.SetDeadline(time.Time{}) - - // Try to cancel any in-progress requests - for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { - pgConn.CancelRequest(ctx) - } - - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() - - err := pgConn.ensureReadyForQuery() - if err != nil { - preferContextOverNetTimeoutError(ctx, err) - pgConn.Close(context.Background()) - return false - } - - result, err := pgConn.Exec( - context.Background(), // do not use ctx again because deadline goroutine already started above - "select 'RecoverFromTimeout'", - ) - if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" { - pgConn.Close(context.Background()) - return false - } - - return true -} - -// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public -// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may -// be leaked. The cleanup function is safe to call multiple times. -func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) { - cleanup = contextDoneToConnDeadline(ctx, pgConn.conn) - - err = pgConn.ensureReadyForQuery() - if err != nil { - cleanup() - return cleanup, preferContextOverNetTimeoutError(ctx, err) - } - - return cleanup, nil -} - // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from // ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to // call multiple times. @@ -665,7 +387,6 @@ func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func conn.SetDeadline(deadlineTime) deadlineWasSet = true <-doneChan - // TODO case <-doneChan: } }() @@ -685,135 +406,6 @@ func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func return func() {} } -// ensureReadyForQuery reads until pendingReadyForQueryCount == 0. -func (pgConn *PgConn) ensureReadyForQuery() error { - for pgConn.pendingReadyForQueryCount > 0 { - _, err := pgConn.ReceiveMessage() - if err != nil { - return err - } - } - - return nil -} - -func (pgConn *PgConn) resetBatch() { - pgConn.batchCount = 0 - if len(pgConn.batchBuf) > batchBufferSize { - pgConn.batchBuf = make([]byte, 0, batchBufferSize) - } else { - pgConn.batchBuf = pgConn.batchBuf[0:0] - } -} - -type PgResult struct { - Rows [][][]byte - CommandTag CommandTag -} - -// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may -// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a -// transactions unless a transaction is already in progress or sql contains transaction control statements. -// -// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExec(sql) - err = pgConn.flush() - if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - return pgConn.bufferLastResult(ctx) -} - -func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { - var result *PgResult - - for pgConn.NextResult(ctx) { - resultReader := pgConn.ResultReader() - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - } - - if result == nil { - return nil, errors.New("unexpected missing result") - } - - return result, nil -} - -// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See -// SendExecParams for parameter descriptions. -// -// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) - err = pgConn.flush() - if err != nil { - return nil, err - } - - return pgConn.bufferLastResult(ctx) -} - -// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and -// returns it. See SendExecPrepared for parameter descriptions. -// -// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) - err = pgConn.flush() - if err != nil { - return nil, err - } - - return pgConn.bufferLastResult(ctx) -} - type PreparedStatementDescription struct { Name string SQL string @@ -823,30 +415,38 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") + select { + case <-ctx.Done(): + return nil, ctx.Err() + case pgConn.controller <- pgConn: } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() + var buf []byte + buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) - pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 - err = pgConn.flush() + n, err := pgConn.conn.Write(buf) if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + return nil, preferContextOverNetTimeoutError(ctx, err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} - for pgConn.pendingReadyForQueryCount > 0 { +readloop: + for { msg, err := pgConn.ReceiveMessage() if err != nil { + go pgConn.recoverFromTimeout() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -858,10 +458,14 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: + go pgConn.recoverFromTimeout() return nil, errorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop } } + <-pgConn.controller return psd, nil } @@ -892,10 +496,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { return (*Notice)(pgerr) } -// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) CancelRequest(ctx context.Context) error { +func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -926,3 +530,514 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { return nil } + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { + multiResult := &PgMultiResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + case pgConn.controller <- multiResult: + } + multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + multiResult.cleanupContextDeadline() + multiResult.closed = true + multiResult.err = preferContextOverNetTimeoutError(ctx, err) + <-pgConn.controller + return multiResult + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Result must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *PgResult { + result := &PgResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + result.concludeCommand(nil, ctx.Err()) + result.closed = true + return result + case pgConn.controller <- result: + } + result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + + // TODO - refactor ExecParams and ExecPrepared - these lines only difference + buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + result.concludeCommand(nil, err) + result.cleanupContextDeadline() + result.closed = true + <-pgConn.controller + } + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Result must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *PgResult { + result := &PgResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + result.concludeCommand(nil, ctx.Err()) + result.closed = true + return result + case pgConn.controller <- result: + } + result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + result.concludeCommand(nil, err) + result.cleanupContextDeadline() + result.closed = true + <-pgConn.controller + } + + return result +} + +type PgMultiResult struct { + pgConn *PgConn + ctx context.Context + cleanupContextDeadline func() + + pgResult *PgResult + + closed bool + err error +} + +func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { + var results []*BufferedResult + + for mr.NextResult() { + results = append(results, mr.Result().ReadAll()) + } + err := mr.Close() + + return results, err +} + +func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mr.pgConn.ReceiveMessage() + + if err != nil { + mr.cleanupContextDeadline() + mr.err = preferContextOverNetTimeoutError(mr.ctx, err) + mr.closed = true + + if err, ok := err.(net.Error); ok && err.Timeout() { + go mr.pgConn.recoverFromTimeout() + } else { + <-mr.pgConn.controller + } + + return nil, mr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mr.cleanupContextDeadline() + mr.closed = true + <-mr.pgConn.controller + case *pgproto3.ErrorResponse: + mr.err = errorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the PgMultiResult to the next result and returns true if a result is available. +func (mr *PgMultiResult) NextResult() bool { + for !mr.closed && mr.err == nil { + msg, err := mr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mr.pgResult = &PgResult{ + pgConn: mr.pgConn, + pgMultiResult: mr, + ctx: mr.ctx, + cleanupContextDeadline: func() {}, + fieldDescriptions: msg.Fields, + } + return true + case *pgproto3.CommandComplete: + mr.pgResult = &PgResult{ + commandTag: CommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return true + case *pgproto3.EmptyQueryResponse: + return false + } + } + + return false +} + +func (mr *PgMultiResult) Result() *PgResult { + return mr.pgResult +} + +func (mr *PgMultiResult) Close() error { + for !mr.closed { + _, err := mr.receiveMessage() + if err != nil { + return mr.err + } + } + + return mr.err +} + +type PgResult struct { + pgConn *PgConn + pgMultiResult *PgMultiResult + ctx context.Context + cleanupContextDeadline func() + + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +type BufferedResult struct { + FieldDescriptions []pgproto3.FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +func (rr *PgResult) ReadAll() *BufferedResult { + br := &BufferedResult{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + row := make([][]byte, len(rr.Values())) + copy(row, rr.Values()) + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the PgResult to the next row and returns true if a row is available. +func (rr *PgResult) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the PgResult is closed. +func (rr *PgResult) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the PgResult is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *PgResult) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *PgResult) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + } + + if rr.pgMultiResult == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + rr.cleanupContextDeadline() + <-rr.pgConn.controller + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.pgMultiResult == nil { + msg, err = rr.pgConn.ReceiveMessage() + } else { + msg, err = rr.pgMultiResult.receiveMessage() + } + + if err != nil { + rr.concludeCommand(nil, err) + rr.cleanupContextDeadline() + rr.closed = true + if rr.pgMultiResult == nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + go rr.pgConn.recoverFromTimeout() + } else { + <-rr.pgConn.controller + } + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.CommandComplete: + rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.ErrorResponse: + rr.concludeCommand(nil, errorResponseToPgError(msg)) + } + + return msg, nil +} + +func (rr *PgResult) concludeCommand(commandTag CommandTag, err error) { + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.fieldDescriptions = nil + rr.rowValues = nil + rr.commandConcluded = true +} + +func (pgConn *PgConn) recoverFromTimeout() { + // Regardless of recovery outcome the lock on the pgConn must be released. + defer func() { <-pgConn.controller }() + + // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not + // try further to recover the connection. + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + err := pgConn.cancelRequest(ctx) + cancel() + if err != nil { + pgConn.hardClose() + return + } + + // Limit time to wait for ReadyForQuery message. + err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + pgConn.hardClose() + return + } + + // A cancel query request will always return a "57014" error response, even if no query was in progress. This error + // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. + needError57014 := true + needReadyForQuery := true + + for needError57014 || needReadyForQuery { + msg, err := pgConn.ReceiveMessage() + if err != nil { + pgConn.hardClose() + return + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + if msg.Code == "57014" { + needError57014 = false + } + case *pgproto3.ReadyForQuery: + needReadyForQuery = false + } + } + + err = pgConn.conn.SetDeadline(time.Time{}) + if err != nil { + pgConn.hardClose() + } +} + +type Batch struct { + buf []byte +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + // TODO - refactor ExecParams and ExecPrepared - these lines only difference + batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +} + +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *PgMultiResult { + multiResult := &PgMultiResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + case pgConn.controller <- multiResult: + } + multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + n, err := pgConn.conn.Write(batch.buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + multiResult.cleanupContextDeadline() + multiResult.closed = true + multiResult.err = preferContextOverNetTimeoutError(ctx, err) + <-pgConn.controller + return multiResult + } + + return multiResult +} diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go index 9aa94539..17d344b7 100644 --- a/pgconn/pgconn_stress_test.go +++ b/pgconn/pgconn_stress_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/jackc/pgx/pgconn" - "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -22,9 +21,9 @@ func TestConnStress(t *testing.T) { defer closeConn(t, pgConn) actionCount := 100 - if s := os.Getenv("PTX_TEST_STRESS_FACTOR"); s != "" { + if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { stressFactor, err := strconv.ParseInt(s, 10, 64) - require.Nil(t, err, "Failed to parse PTX_TEST_STRESS_FACTOR") + require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") actionCount *= int(stressFactor) } @@ -61,138 +60,61 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { insert into widgets(name, description) values ('Foo', 'bar'), ('baz', 'Something really long Something really long Something really long Something really long Something really long'), - ('a', 'b')`) + ('a', 'b')`).ReadAll() require.Nil(t, err) } func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets") + _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() return err } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - return err + result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + return result.Err } func stressBatch(pgConn *pgconn.PgConn) error { - pgConn.SendExec("select * from widgets") - pgConn.SendExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - err := pgConn.Flush(context.Background()) - if err != nil { - return err - } + batch := &pgconn.Batch{} - // Query 1 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader := pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // Query 2 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader = pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // No more - if pgConn.NextResult(context.Background()) { - return errors.New("unexpected result reader") - } - - return nil + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + return err } func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets") + _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() cancel() if err != context.DeadlineExceeded { return err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() cancel() - if err != context.DeadlineExceeded { - return err + if result.Err != context.DeadlineExceeded { + return result.Err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } func stressBatchCanceled(pgConn *pgconn.PgConn) error { - - pgConn.SendExec("select * from widgets") - pgConn.SendExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - err := pgConn.Flush(context.Background()) - if err != nil { - return err - } - - // Query 1 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader := pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // Query 2 + batch := &pgconn.Batch{} + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - if !pgConn.NextResult(ctx) { - return errors.New("missing result") - } + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() cancel() - resultReader = pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() if err != context.DeadlineExceeded { return err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index e436d739..a2eb7838 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -134,13 +134,13 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result, err := conn.Exec(context.Background(), "show application_name") - require.Nil(t, err) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).ReadAll() + require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result, err = conn.Exec(context.Background(), "show search_path") - require.Nil(t, err) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).ReadAll() + require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) } @@ -239,10 +239,14 @@ func TestConnExec(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select current_database()") - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Nil(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, pgConn) } @@ -254,10 +258,16 @@ func TestConnExecEmpty(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), ";") - require.Nil(t, err) - assert.Nil(t, result.CommandTag) - assert.Equal(t, 0, len(result.Rows)) + multiResult := pgConn.Exec(context.Background(), ";") + + resultCount := 0 + for multiResult.NextResult() { + resultCount += 1 + multiResult.Result().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -269,10 +279,20 @@ func TestConnExecMultipleQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select current_database(); select 1") - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "1", string(result.Rows[0][0])) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.Nil(t, err) + + assert.Len(t, results, 2) + + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) ensureConnValid(t, pgConn) } @@ -284,15 +304,18 @@ func TestConnExecMultipleQueriesError(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() require.NotNil(t, err) - require.Nil(t, result) if pgErr, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "22012", pgErr.Code) } else { t.Errorf("unexpected error: %v", err) } + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + ensureConnValid(t, pgConn) } @@ -305,11 +328,12 @@ func TestConnExecContextCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - assert.Nil(t, result) - assert.Equal(t, context.DeadlineExceeded, err) + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) } @@ -321,10 +345,16 @@ func TestConnExecParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -338,12 +368,16 @@ func TestConnExecParamsCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) - assert.Nil(t, result) + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Nil(t, commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) - ensureConnValid(t, pgConn) } @@ -360,10 +394,16 @@ func TestConnExecPrepared(t *testing.T) { assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) - result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -380,16 +420,20 @@ func TestConnExecPreparedCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) - assert.Nil(t, result) + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Nil(t, commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) - ensureConnValid(t, pgConn) } -func TestConnBatchedQueries(t *testing.T) { +func TestConnExecBatch(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) @@ -399,160 +443,26 @@ func TestConnBatchedQueries(t *testing.T) { _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) - pgConn.SendExec("select 'SendExec 1'") - pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) - pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil) - pgConn.SendExec("select 'SendExec 2'") - pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) - err = pgConn.Flush(context.Background()) + batch := &pgconn.Batch{} - // "select 'SendExec 1'" - require.True(t, pgConn.NextResult(context.Background())) - resultReader := pgConn.ResultReader() - - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExec 1", string(rows[0][0])) - - commandTag, err := resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecParams 1" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecParams 1", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecPrepared 1" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecPrepared 1", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExec 2" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExec 2", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecParams 2" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecParams 2", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // Done - require.False(t, pgConn.NextResult(context.Background())) - - ensureConnValid(t, pgConn) -} - -func TestConnRecoverFromTimeout(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() require.Nil(t, err) - defer closeConn(t, pgConn) + require.Len(t, results, 3) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - cancel() - require.Nil(t, result) - assert.Equal(t, context.DeadlineExceeded, err) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - if assert.True(t, pgConn.RecoverFromTimeout(ctx)) { - result, err := pgConn.Exec(ctx, "select 1") - require.Nil(t, err) - assert.Len(t, result.Rows, 1) - assert.Len(t, result.Rows[0], 1) - assert.Equal(t, "1", string(result.Rows[0][0])) - } - cancel() + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - ensureConnValid(t, pgConn) -} - -func TestConnCancelQuery(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) - defer closeConn(t, pgConn) - - pgConn.SendExec("select current_database(), pg_sleep(5)") - err = pgConn.Flush(context.Background()) - require.Nil(t, err) - - err = pgConn.CancelRequest(context.Background()) - require.Nil(t, err) - - require.True(t, pgConn.NextResult(context.Background())) - _, err = pgConn.ResultReader().Close() - if err, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "57014", err.Code) - } else { - t.Errorf("expected pgconn.PgError got %v", err) - } - - require.False(t, pgConn.NextResult(context.Background())) - - ensureConnValid(t, pgConn) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestCommandTag(t *testing.T) { @@ -593,10 +503,11 @@ func TestConnOnNotice(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `do $$ + multiResult := pgConn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; end$$;`) + err = multiResult.Close() require.Nil(t, err) assert.Equal(t, "hello, world", msg)