From 89416dd80542cc62f45af214ca0722c32e6624ca Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:09:50 +0200 Subject: [PATCH 1/5] Enable passing nil context --- .gitignore | 3 +- doc.go | 3 + pgconn.go | 187 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 116 insertions(+), 77 deletions(-) diff --git a/.gitignore b/.gitignore index 6eb9d442..e980f555 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .envrc -vendor/ \ No newline at end of file +vendor/ +.vscode diff --git a/doc.go b/doc.go index cde58cd8..12ed6630 100644 --- a/doc.go +++ b/doc.go @@ -23,6 +23,9 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. +A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional +slight performance increase, if you don't need the operation to be cancellable. + The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. */ diff --git a/pgconn.go b/pgconn.go index 4c75d367..3b90b802 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + if ctx == nil { + ctx = context.Background() + } + // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -362,13 +366,15 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() n, err := pgConn.conn.Write(buf) if err != nil { @@ -392,13 +398,15 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() if err != nil { @@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + if ctx != nil { + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } // Ignore any errors sending Terminate message and waiting for server to close connection. // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully @@ -586,13 +596,15 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + _ctx := ctx + if _ctx == nil { + _ctx = context.Background() + } + cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) - contextWatcher.Watch(ctx) - defer contextWatcher.Unwatch() + if ctx != nil { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -712,14 +730,16 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } for { msg, err := pgConn.receiveMessage() @@ -752,16 +772,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader + if ctx == nil { + pgConn.resultReader.ctx = context.Background() + } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -859,20 +885,22 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - select { - case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) - result.closed = true - pgConn.unlock() - return result - default: + if ctx != nil { + select { + case <-ctx.Done(): + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) return result } -func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -893,14 +921,16 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - select { - case <-ctx.Done(): - pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -952,13 +982,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -1344,16 +1376,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From 719623452110bc4bce0e2358db9d3df658777eeb Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:10:04 +0200 Subject: [PATCH 2/5] Benchmark nil context execution --- benchmark_test.go | 156 +++++++++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 63 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 8067c985..1914e07a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,9 +14,14 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string + ctx context.Context }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + // The first benchmark in the list sometimes executes faster, no matter how + // you reorder it. Nil context is still faster on average. + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.Background()}, + {"TCP", "PGX_TEST_TCP_CONN_STRING", context.Background()}, + {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, + {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, } for _, bm := range benchmarks { @@ -28,10 +33,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(context.Background(), connString) + conn, err := pgconn.Connect(bm.ctx, connString) require.Nil(b, err) - err = conn.Close(context.Background()) + err = conn.Close(bm.ctx) require.Nil(b, err) } }) @@ -39,46 +44,58 @@ func BenchmarkConnect(b *testing.B) { } func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } - b.ResetTimer() + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - for i := 0; i < b.N; i++ { - mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + b.ResetTimer() - for mrr.NextResult() { - rr := mrr.ResultReader() + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - if err != nil { - b.Fatal(err) + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } - } - - err := mrr.Close() - if err != nil { - b.Fatal(err) - } + }) } } @@ -130,40 +147,53 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - require.Nil(b, err) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} - b.ResetTimer() + benchmarks := []struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } - for i := 0; i < b.N; i++ { - rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - - if err != nil { - b.Fatal(err) - } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } + }) } } From 4d345164f1027d985717335e841868f60ca69ac2 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 14:36:38 +0200 Subject: [PATCH 3/5] Branch tests for nil context --- README.md | 4 +- helper_test.go | 22 + pgconn_test.go | 1500 +++++++++++++++++++++++++----------------------- 3 files changed, 818 insertions(+), 708 deletions(-) diff --git a/README.md b/README.md index 5d14e914..ddbfeaf3 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/helper_test.go b/helper_test.go index 1a3ca75e..1cb05fd2 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,3 +29,25 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } + +// Run subtest both with a context.Background() and nil context +func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { + t.Helper() + + cases := [...]struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } + + for i := range cases { + c := cases[i] + t.Run(c.name, func(t *testing.T) { + t.Helper() + t.Parallel() + test(t, c.ctx) + }) + } +} diff --git a/pgconn_test.go b/pgconn_test.go index 6b57dd09..30d20229 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,31 +27,33 @@ import ( ) func TestConnect(t *testing.T) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } + closeConn(t, conn) + }) + } + }) } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -59,19 +61,21 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) + closeConn(t, conn) + }) } type pgmockWaitStep time.Duration @@ -138,233 +142,259 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(context.Background(), config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } + _, err = pgconn.ConnectConfig(ctx, config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } + }) } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") - if err == nil { - conn.Close(context.Background()) - t.Fatal("Expected error establishing connection to bad port") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") + if err == nil { + conn.Close(ctx) + t.Fatal("Expected error establishing connection to bad port") + } + }) } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) + }) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) + }) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) + result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) + }) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + }) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) - } - - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) } - return nil - } - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if !assert.NotNil(t, err) { - conn.Close(context.Background()) - } + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } + }) } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) - results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(ctx, "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + }) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - config := &pgconn.Config{} + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config := &pgconn.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue( + t, + "config must be created by ParseConfig", + func() { pgconn.ConnectConfig(ctx, config) }, + ) + }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -388,116 +418,126 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), ";") + multiResult := pgConn.Exec(ctx, ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecContextCanceled(t *testing.T) { @@ -538,95 +578,103 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsCanceled(t *testing.T) { @@ -671,86 +719,92 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -800,63 +854,67 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + }) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -895,76 +953,82 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } + }) } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) + result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) + }) } func TestConnLocking(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(ctx, "select 'Hello, world'") + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestCommandTag(t *testing.T) { @@ -993,91 +1057,97 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), `do $$ -begin - raise notice 'hello, world'; -end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(ctx, `do $$ + begin + raise notice 'hello, world'; + end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnOnNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(context.Background()) - require.NoError(t, err) + err = pgConn.WaitForNotification(ctx) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1119,94 +1189,100 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) - - inputBytes := make([]byte, 0) - - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + inputBytes := make([]byte, 0) - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - ensureConnValid(t, pgConn) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToCanceled(t *testing.T) { @@ -1250,37 +1326,39 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - srcBuf := &bytes.Buffer{} - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + srcBuf := &bytes.Buffer{} - assert.Equal(t, inputRows, result.Rows) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - ensureConnValid(t, pgConn) + ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1358,153 +1436,163 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) - - gw := gzip.NewWriter(f) - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - err = gw.Close() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + gw := gzip.NewWriter(f) - ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - err = gr.Close() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - err = f.Close() - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - err = os.Remove(f.Name()) - require.NoError(t, err) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - assert.Equal(t, inputRows, result.Rows) + err = gr.Close() + require.NoError(t, err) - ensureConnValid(t, pgConn) + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnEscapeString(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, - } - - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, } - } - ensureConnValid(t, pgConn) + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) + }) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) - require.NoError(t, err) + err = pgConn.CancelRequest(ctx) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1547,13 +1635,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(context.Background()) + defer pgConn.Close(nil) - result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) } From 93722181071cd124ad5bb67122d33b31d4ada632 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 19:34:56 +0200 Subject: [PATCH 4/5] Don't synchronize with context.Background() --- benchmark_test.go | 12 +++++++---- doc.go | 4 +--- pgconn.go | 52 +++++++++++++++++++++++++++++++++-------------- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 1914e07a..4cce5a97 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -18,8 +18,10 @@ func BenchmarkConnect(b *testing.B) { }{ // The first benchmark in the list sometimes executes faster, no matter how // you reorder it. Nil context is still faster on average. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.Background()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.Background()}, + // + // Using and empty context other than context.Background() to compare. + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, + {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, } @@ -49,7 +51,8 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - {"background context", context.Background()}, + // Using and empty context other than context.Background() to compare. + {"empty context", context.TODO()}, {"nil context", nil}, } @@ -153,7 +156,8 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - {"background context", context.Background()}, + // Using and empty context other than context.Background() to compare. + {"empty context", context.TODO()}, {"nil context", nil}, } diff --git a/doc.go b/doc.go index 12ed6630..25382c68 100644 --- a/doc.go +++ b/doc.go @@ -22,9 +22,7 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. - -A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional -slight performance increase, if you don't need the operation to be cancellable. +A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/pgconn.go b/pgconn.go index 3b90b802..b8ea9df7 100644 --- a/pgconn.go +++ b/pgconn.go @@ -366,7 +366,9 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -398,7 +400,9 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -497,7 +501,9 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -596,7 +602,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -695,7 +703,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -730,7 +740,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return ctx.Err() @@ -772,7 +784,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - if ctx != nil { + switch ctx { + case nil: + pgConn.multiResultReader.ctx = context.Background() + case context.Background(): + default: select { case <-ctx.Done(): multiResult.closed = true @@ -782,8 +798,6 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { default: } pgConn.contextWatcher.Watch(ctx) - } else { - pgConn.multiResultReader.ctx = context.Background() } buf := pgConn.wbuf @@ -885,7 +899,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -921,7 +937,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): pgConn.unlock() @@ -982,7 +1000,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1376,7 +1396,11 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - if ctx != nil { + switch ctx { + case nil: + pgConn.multiResultReader.ctx = context.Background() + case context.Background(): + default: select { case <-ctx.Done(): multiResult.closed = true @@ -1386,8 +1410,6 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR default: } pgConn.contextWatcher.Watch(ctx) - } else { - pgConn.multiResultReader.ctx = context.Background() } batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From 9decdbc2ec3357581cd2911141b3ead982f5026f Mon Sep 17 00:00:00 2001 From: bakape Date: Sat, 11 Jan 2020 16:53:50 +0200 Subject: [PATCH 5/5] Revert nil context support --- README.md | 4 +- benchmark_test.go | 25 +- doc.go | 1 - helper_test.go | 22 - pgconn.go | 62 +- pgconn_test.go | 1500 +++++++++++++++++++++------------------------ 6 files changed, 731 insertions(+), 883 deletions(-) diff --git a/README.md b/README.md index ddbfeaf3..5d14e914 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/benchmark_test.go b/benchmark_test.go index 4cce5a97..3295a90f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,16 +14,9 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string - ctx context.Context }{ - // The first benchmark in the list sometimes executes faster, no matter how - // you reorder it. Nil context is still faster on average. - // - // Using and empty context other than context.Background() to compare. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, - {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, - {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, } for _, bm := range benchmarks { @@ -35,10 +28,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(bm.ctx, connString) + conn, err := pgconn.Connect(context.Background(), connString) require.Nil(b, err) - err = conn.Close(bm.ctx) + err = conn.Close(context.Background()) require.Nil(b, err) } }) @@ -51,9 +44,10 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { @@ -156,9 +150,10 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { diff --git a/doc.go b/doc.go index 25382c68..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -22,7 +22,6 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. -A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/helper_test.go b/helper_test.go index 1cb05fd2..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,25 +29,3 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } - -// Run subtest both with a context.Background() and nil context -func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { - t.Helper() - - cases := [...]struct { - name string - ctx context.Context - }{ - {"background context", context.Background()}, - {"nil context", nil}, - } - - for i := range cases { - c := cases[i] - t.Run(c.name, func(t *testing.T) { - t.Helper() - t.Parallel() - test(t, c.ctx) - }) - } -} diff --git a/pgconn.go b/pgconn.go index b8ea9df7..9763b319 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,10 +116,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } - if ctx == nil { - ctx = context.Background() - } - // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -366,9 +362,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -400,9 +394,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -501,9 +493,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -602,9 +592,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -693,19 +681,13 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - _ctx := ctx - if _ctx == nil { - _ctx = context.Background() - } - cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -740,9 +722,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return ctx.Err() @@ -784,11 +764,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true @@ -882,9 +858,6 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader - if ctx == nil { - pgConn.resultReader.ctx = context.Background() - } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -899,9 +872,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -937,9 +908,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() @@ -1000,9 +969,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1396,11 +1363,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 30d20229..6b57dd09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,33 +27,31 @@ import ( ) func TestConnect(t *testing.T) { - splitOnContext(t, func(t *testing.T, ctx context.Context) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } - }) + closeConn(t, conn) + }) + } } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -61,21 +59,19 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) - }) + closeConn(t, conn) } type pgmockWaitStep time.Duration @@ -142,259 +138,233 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(ctx, config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } - }) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") - if err == nil { - conn.Close(ctx) - t.Fatal("Expected error establishing connection to bad port") - } - }) + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close(context.Background()) + t.Fatal("Expected error establishing connection to bad port") + } } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) - }) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") } + return nil + } - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") - } - return nil - } - - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(ctx, config) - if !assert.NotNil(t, err) { - conn.Close(ctx) - } - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close(context.Background()) + } } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) - results, err := conn.Exec(ctx, "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) - }) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config := &pgconn.Config{} + config := &pgconn.Config{} - require.PanicsWithValue( - t, - "config must be created by ParseConfig", - func() { pgconn.ConnectConfig(ctx, config) }, - ) - }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -418,126 +388,116 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, ";") + multiResult := pgConn.Exec(context.Background(), ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -578,103 +538,95 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -719,92 +671,86 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -854,67 +800,63 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) - }) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -953,82 +895,76 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } - }) + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) - }) + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) } func TestConnLocking(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(ctx, "select 'Hello, world'") - _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -1057,97 +993,91 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, `do $$ - begin - raise notice 'hello, world'; - end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnOnNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(ctx) - require.NoError(t, err) + err = pgConn.WaitForNotification(context.Background()) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1189,100 +1119,94 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() require.NoError(t, err) - defer closeConn(t, pgConn) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - inputBytes := make([]byte, 0) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) - - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToCanceled(t *testing.T) { @@ -1326,39 +1250,37 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - srcBuf := &bytes.Buffer{} + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + assert.Equal(t, inputRows, result.Rows) - ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1436,163 +1358,153 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - gw := gzip.NewWriter(f) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - err = gw.Close() - require.NoError(t, err) + err = gr.Close() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + err = f.Close() + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + err = os.Remove(f.Name()) + require.NoError(t, err) - ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - err = gr.Close() - require.NoError(t, err) + assert.Equal(t, inputRows, result.Rows) - err = f.Close() - require.NoError(t, err) - - err = os.Remove(f.Name()) - require.NoError(t, err) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnEscapeString(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) } + } - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) - } - } - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(ctx) - require.NoError(t, err) + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1635,13 +1547,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(nil) + defer pgConn.Close(context.Background()) - result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) }