From 9decdbc2ec3357581cd2911141b3ead982f5026f Mon Sep 17 00:00:00 2001 From: bakape Date: Sat, 11 Jan 2020 16:53:50 +0200 Subject: [PATCH] Revert nil context support --- README.md | 4 +- benchmark_test.go | 25 +- doc.go | 1 - helper_test.go | 22 - pgconn.go | 62 +- pgconn_test.go | 1500 +++++++++++++++++++++------------------------ 6 files changed, 731 insertions(+), 883 deletions(-) diff --git a/README.md b/README.md index ddbfeaf3..5d14e914 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/benchmark_test.go b/benchmark_test.go index 4cce5a97..3295a90f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,16 +14,9 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string - ctx context.Context }{ - // The first benchmark in the list sometimes executes faster, no matter how - // you reorder it. Nil context is still faster on average. - // - // Using and empty context other than context.Background() to compare. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, - {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, - {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, } for _, bm := range benchmarks { @@ -35,10 +28,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(bm.ctx, connString) + conn, err := pgconn.Connect(context.Background(), connString) require.Nil(b, err) - err = conn.Close(bm.ctx) + err = conn.Close(context.Background()) require.Nil(b, err) } }) @@ -51,9 +44,10 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { @@ -156,9 +150,10 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { diff --git a/doc.go b/doc.go index 25382c68..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -22,7 +22,6 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. -A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/helper_test.go b/helper_test.go index 1cb05fd2..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,25 +29,3 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } - -// Run subtest both with a context.Background() and nil context -func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { - t.Helper() - - cases := [...]struct { - name string - ctx context.Context - }{ - {"background context", context.Background()}, - {"nil context", nil}, - } - - for i := range cases { - c := cases[i] - t.Run(c.name, func(t *testing.T) { - t.Helper() - t.Parallel() - test(t, c.ctx) - }) - } -} diff --git a/pgconn.go b/pgconn.go index b8ea9df7..9763b319 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,10 +116,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } - if ctx == nil { - ctx = context.Background() - } - // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -366,9 +362,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -400,9 +394,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -501,9 +493,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -602,9 +592,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -693,19 +681,13 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - _ctx := ctx - if _ctx == nil { - _ctx = context.Background() - } - cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -740,9 +722,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return ctx.Err() @@ -784,11 +764,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true @@ -882,9 +858,6 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader - if ctx == nil { - pgConn.resultReader.ctx = context.Background() - } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -899,9 +872,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -937,9 +908,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() @@ -1000,9 +969,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1396,11 +1363,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 30d20229..6b57dd09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,33 +27,31 @@ import ( ) func TestConnect(t *testing.T) { - splitOnContext(t, func(t *testing.T, ctx context.Context) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } - }) + closeConn(t, conn) + }) + } } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -61,21 +59,19 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) - }) + closeConn(t, conn) } type pgmockWaitStep time.Duration @@ -142,259 +138,233 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(ctx, config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } - }) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") - if err == nil { - conn.Close(ctx) - t.Fatal("Expected error establishing connection to bad port") - } - }) + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close(context.Background()) + t.Fatal("Expected error establishing connection to bad port") + } } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) - }) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") } + return nil + } - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") - } - return nil - } - - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(ctx, config) - if !assert.NotNil(t, err) { - conn.Close(ctx) - } - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close(context.Background()) + } } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) - results, err := conn.Exec(ctx, "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) - }) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config := &pgconn.Config{} + config := &pgconn.Config{} - require.PanicsWithValue( - t, - "config must be created by ParseConfig", - func() { pgconn.ConnectConfig(ctx, config) }, - ) - }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -418,126 +388,116 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, ";") + multiResult := pgConn.Exec(context.Background(), ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -578,103 +538,95 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -719,92 +671,86 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -854,67 +800,63 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) - }) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -953,82 +895,76 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } - }) + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) - }) + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) } func TestConnLocking(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(ctx, "select 'Hello, world'") - _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -1057,97 +993,91 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, `do $$ - begin - raise notice 'hello, world'; - end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnOnNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(ctx) - require.NoError(t, err) + err = pgConn.WaitForNotification(context.Background()) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1189,100 +1119,94 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() require.NoError(t, err) - defer closeConn(t, pgConn) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - inputBytes := make([]byte, 0) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) - - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToCanceled(t *testing.T) { @@ -1326,39 +1250,37 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - srcBuf := &bytes.Buffer{} + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + assert.Equal(t, inputRows, result.Rows) - ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1436,163 +1358,153 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - gw := gzip.NewWriter(f) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - err = gw.Close() - require.NoError(t, err) + err = gr.Close() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + err = f.Close() + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + err = os.Remove(f.Name()) + require.NoError(t, err) - ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - err = gr.Close() - require.NoError(t, err) + assert.Equal(t, inputRows, result.Rows) - err = f.Close() - require.NoError(t, err) - - err = os.Remove(f.Name()) - require.NoError(t, err) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnEscapeString(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) } + } - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) - } - } - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(ctx) - require.NoError(t, err) + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1635,13 +1547,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(nil) + defer pgConn.Close(context.Background()) - result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) }