From 4d345164f1027d985717335e841868f60ca69ac2 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 14:36:38 +0200 Subject: [PATCH] Branch tests for nil context --- README.md | 4 +- helper_test.go | 22 + pgconn_test.go | 1500 +++++++++++++++++++++++++----------------------- 3 files changed, 818 insertions(+), 708 deletions(-) diff --git a/README.md b/README.md index 5d14e914..ddbfeaf3 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/helper_test.go b/helper_test.go index 1a3ca75e..1cb05fd2 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,3 +29,25 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } + +// Run subtest both with a context.Background() and nil context +func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { + t.Helper() + + cases := [...]struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } + + for i := range cases { + c := cases[i] + t.Run(c.name, func(t *testing.T) { + t.Helper() + t.Parallel() + test(t, c.ctx) + }) + } +} diff --git a/pgconn_test.go b/pgconn_test.go index 6b57dd09..30d20229 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,31 +27,33 @@ import ( ) func TestConnect(t *testing.T) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } + closeConn(t, conn) + }) + } + }) } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -59,19 +61,21 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) + closeConn(t, conn) + }) } type pgmockWaitStep time.Duration @@ -138,233 +142,259 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(context.Background(), config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } + _, err = pgconn.ConnectConfig(ctx, config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } + }) } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") - if err == nil { - conn.Close(context.Background()) - t.Fatal("Expected error establishing connection to bad port") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") + if err == nil { + conn.Close(ctx) + t.Fatal("Expected error establishing connection to bad port") + } + }) } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) + }) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) + }) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) + result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) + }) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + }) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) - } - - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) } - return nil - } - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if !assert.NotNil(t, err) { - conn.Close(context.Background()) - } + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } + }) } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) - results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(ctx, "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + }) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - config := &pgconn.Config{} + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config := &pgconn.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue( + t, + "config must be created by ParseConfig", + func() { pgconn.ConnectConfig(ctx, config) }, + ) + }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -388,116 +418,126 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), ";") + multiResult := pgConn.Exec(ctx, ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecContextCanceled(t *testing.T) { @@ -538,95 +578,103 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsCanceled(t *testing.T) { @@ -671,86 +719,92 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -800,63 +854,67 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + }) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -895,76 +953,82 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } + }) } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) + result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) + }) } func TestConnLocking(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(ctx, "select 'Hello, world'") + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestCommandTag(t *testing.T) { @@ -993,91 +1057,97 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), `do $$ -begin - raise notice 'hello, world'; -end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(ctx, `do $$ + begin + raise notice 'hello, world'; + end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnOnNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(context.Background()) - require.NoError(t, err) + err = pgConn.WaitForNotification(ctx) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1119,94 +1189,100 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) - - inputBytes := make([]byte, 0) - - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + inputBytes := make([]byte, 0) - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - ensureConnValid(t, pgConn) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToCanceled(t *testing.T) { @@ -1250,37 +1326,39 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - srcBuf := &bytes.Buffer{} - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + srcBuf := &bytes.Buffer{} - assert.Equal(t, inputRows, result.Rows) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - ensureConnValid(t, pgConn) + ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1358,153 +1436,163 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) - - gw := gzip.NewWriter(f) - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - err = gw.Close() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + gw := gzip.NewWriter(f) - ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - err = gr.Close() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - err = f.Close() - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - err = os.Remove(f.Name()) - require.NoError(t, err) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - assert.Equal(t, inputRows, result.Rows) + err = gr.Close() + require.NoError(t, err) - ensureConnValid(t, pgConn) + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnEscapeString(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, - } - - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, } - } - ensureConnValid(t, pgConn) + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) + }) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) - require.NoError(t, err) + err = pgConn.CancelRequest(ctx) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1547,13 +1635,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(context.Background()) + defer pgConn.Close(nil) - result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) }