diff --git a/.gitignore b/.gitignore index 6eb9d442..e980f555 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .envrc -vendor/ \ No newline at end of file +vendor/ +.vscode diff --git a/benchmark_test.go b/benchmark_test.go index 8067c985..3295a90f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -39,46 +39,60 @@ func BenchmarkConnect(b *testing.B) { } func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } - b.ResetTimer() + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - for i := 0; i < b.N; i++ { - mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + b.ResetTimer() - for mrr.NextResult() { - rr := mrr.ResultReader() + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - if err != nil { - b.Fatal(err) + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } - } - - err := mrr.Close() - if err != nil { - b.Fatal(err) - } + }) } } @@ -130,40 +144,55 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - require.Nil(b, err) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} - b.ResetTimer() + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } - for i := 0; i < b.N; i++ { - rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - - if err != nil { - b.Fatal(err) - } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } + }) } } diff --git a/pgconn.go b/pgconn.go index 70c33c4f..c46dc6a6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -362,13 +362,15 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + return &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() n, err := pgConn.conn.Write(buf) if err != nil { @@ -392,13 +394,15 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() if err != nil { @@ -489,8 +493,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + if ctx != context.Background() { + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } // Ignore any errors sending Terminate message and waiting for server to close connection. // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully @@ -600,13 +606,15 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -693,12 +701,14 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) - contextWatcher.Watch(ctx) - defer contextWatcher.Unwatch() + if ctx != context.Background() { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -726,14 +736,16 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + if ctx != context.Background() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } for { msg, err := pgConn.receiveMessage() @@ -766,16 +778,17 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -822,7 +835,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -848,7 +861,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -873,20 +886,22 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - select { - case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) - result.closed = true - pgConn.unlock() - return result - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) return result } -func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -907,14 +922,16 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - select { - case <-ctx.Done(): - pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -966,13 +983,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -1359,15 +1378,17 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR } multiResult := &pgConn.multiResultReader - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)