2
0

Split ValidateConnect from AfterConnect

This avoids the foot-gun of ParseConfig setting AfterConnect because of
target_session_attrs and the user inadvertently overriding it with an
AfterConnect designed to setup the connection.

Now target_session_attrs will be handled with ValidateConnect.
This commit is contained in:
Jack Christensen
2019-07-13 10:22:09 -05:00
parent 59941377c8
commit 3dec184811
4 changed files with 63 additions and 23 deletions
+12 -6
View File
@@ -22,6 +22,7 @@ import (
) )
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. // Config is the settings used to establish a connection to a PostgreSQL server.
type Config struct { type Config struct {
@@ -36,9 +37,14 @@ type Config struct {
Fallbacks []*FallbackConfig Fallbacks []*FallbackConfig
// AfterConnect is called after successful connection. It can be used to set up the connection or to validate that // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next
// allows implementing high availability behavior such as libpq does with target_session_attrs. // fallback config is tried. This allows implementing high availability behavior such as libpq does with
// target_session_attrs.
ValidateConnect ValidateConnectFunc
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
// or prepare statements). If this returns an error the connection attempt fails.
AfterConnect AfterConnectFunc AfterConnect AfterConnectFunc
// OnNotice is a callback function called when a notice response is received. // OnNotice is a callback function called when a notice response is received.
@@ -245,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) {
} }
if settings["target_session_attrs"] == "read-write" { if settings["target_session_attrs"] == "read-write" {
config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" { } else if settings["target_session_attrs"] != "any" {
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
} }
@@ -481,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
return d.DialContext, nil return d.DialContext, nil
} }
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write. // target_session_attrs=read-write.
func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil { if result.Err != nil {
return result.Err return result.Err
+9 -8
View File
@@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) {
name: "target_session_attrs", name: "target_session_attrs",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
}, },
}, },
} }
@@ -416,6 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not. // Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
+17 -5
View File
@@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
for _, fc := range fallbackConfigs { for _, fc := range fallbackConfigs {
pgConn, err = connect(ctx, config, fc) pgConn, err = connect(ctx, config, fc)
if err == nil { if err == nil {
return pgConn, nil break
} else if err, ok := err.(*PgError); ok { } else if err, ok := err.(*PgError); ok {
return nil, err return nil, err
} }
} }
return nil, err if err != nil {
return nil, err
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, errors.Errorf("AfterConnect: %v", err)
}
}
return pgConn, nil
} }
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
@@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle pgConn.status = connStatusIdle
if config.AfterConnect != nil { if config.ValidateConnect != nil {
err := config.AfterConnect(ctx, pgConn) err := config.ValidateConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.Errorf("AfterConnect: %v", err) return nil, errors.Errorf("ValidateConnect: %v", err)
} }
} }
return pgConn, nil return pgConn, nil
+25 -4
View File
@@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) {
closeConn(t, conn) closeConn(t, conn)
} }
func TestConnectWithAfterConnect(t *testing.T) { func TestConnectWithValidateConnect(t *testing.T) {
t.Parallel() t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
@@ -200,7 +200,7 @@ func TestConnectWithAfterConnect(t *testing.T) {
} }
acceptConnCount := 0 acceptConnCount := 0
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
acceptConnCount++ acceptConnCount++
if acceptConnCount < 2 { if acceptConnCount < 2 {
return errors.New("reject first conn") return errors.New("reject first conn")
@@ -226,13 +226,13 @@ func TestConnectWithAfterConnect(t *testing.T) {
assert.True(t, acceptConnCount > 1) assert.True(t, acceptConnCount > 1)
} }
func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
t.Parallel() t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err) require.NoError(t, err)
config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
config.RuntimeParams["default_transaction_read_only"] = "on" config.RuntimeParams["default_transaction_read_only"] = "on"
conn, err := pgconn.ConnectConfig(context.Background(), config) conn, err := pgconn.ConnectConfig(context.Background(), config)
@@ -241,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
} }
} }
func TestConnectWithAfterConnect(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
return err
}
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
require.NoError(t, err)
defer closeConn(t, conn)
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
}
func TestConnPrepareSyntaxError(t *testing.T) { func TestConnPrepareSyntaxError(t *testing.T) {
t.Parallel() t.Parallel()