From 3dec1848118789c4430914ca04d2f6fd0542c3d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jul 2019 10:22:09 -0500 Subject: [PATCH] Split ValidateConnect from AfterConnect This avoids the foot-gun of ParseConfig setting AfterConnect because of target_session_attrs and the user inadvertently overriding it with an AfterConnect designed to setup the connection. Now target_session_attrs will be handled with ValidateConnect. --- config.go | 18 ++++++++++++------ config_test.go | 17 +++++++++-------- pgconn.go | 22 +++++++++++++++++----- pgconn_test.go | 29 +++++++++++++++++++++++++---- 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/config.go b/config.go index 533791c2..9b74945e 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,7 @@ import ( ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -36,9 +37,14 @@ type Config struct { Fallbacks []*FallbackConfig - // AfterConnect is called after successful connection. It can be used to set up the connection or to validate that - // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This - // allows implementing high availability behavior such as libpq does with target_session_attrs. + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with + // target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. @@ -245,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) { } if settings["target_session_attrs"] == "read-write" { - config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } @@ -481,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { return d.DialContext, nil } -// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err diff --git a/config_test.go b/config_test.go index b222d8cc..23d86529 100644 --- a/config_test.go +++ b/config_test.go @@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, }, }, } @@ -416,6 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { diff --git a/pgconn.go b/pgconn.go index 2db35587..6e1fb7e3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { - return pgConn, nil + break } else if err, ok := err.(*PgError); ok { return nil, err } } - return nil, err + if err != nil { + return nil, err + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, errors.Errorf("AfterConnect: %v", err) + } + } + + return pgConn, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { @@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle - if config.AfterConnect != nil { - err := config.AfterConnect(ctx, pgConn) + if config.ValidateConnect != nil { + err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, errors.Errorf("ValidateConnect: %v", err) } } return pgConn, nil diff --git a/pgconn_test.go b/pgconn_test.go index 028d5e94..feb78641 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAfterConnect(t *testing.T) { +func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) @@ -200,7 +200,7 @@ func TestConnectWithAfterConnect(t *testing.T) { } acceptConnCount := 0 - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") @@ -226,13 +226,13 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.True(t, acceptConnCount > 1) } -func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" conn, err := pgconn.ConnectConfig(context.Background(), config) @@ -241,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel()