Split ValidateConnect from AfterConnect
This avoids the foot-gun of ParseConfig setting AfterConnect because of target_session_attrs and the user inadvertently overriding it with an AfterConnect designed to setup the connection. Now target_session_attrs will be handled with ValidateConnect.
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
|
||||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -36,9 +37,14 @@ type Config struct {
|
|||||||
|
|
||||||
Fallbacks []*FallbackConfig
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
// AfterConnect is called after successful connection. It can be used to set up the connection or to validate that
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This
|
// It can be used validate that server is acceptable. If this returns an error the connection is closed and the next
|
||||||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
// fallback config is tried. This allows implementing high availability behavior such as libpq does with
|
||||||
|
// target_session_attrs.
|
||||||
|
ValidateConnect ValidateConnectFunc
|
||||||
|
|
||||||
|
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||||
|
// or prepare statements). If this returns an error the connection attempt fails.
|
||||||
AfterConnect AfterConnectFunc
|
AfterConnect AfterConnectFunc
|
||||||
|
|
||||||
// OnNotice is a callback function called when a notice response is received.
|
// OnNotice is a callback function called when a notice response is received.
|
||||||
@@ -245,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings["target_session_attrs"] == "read-write" {
|
if settings["target_session_attrs"] == "read-write" {
|
||||||
config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||||
} else if settings["target_session_attrs"] != "any" {
|
} else if settings["target_session_attrs"] != "any" {
|
||||||
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
|
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
|
||||||
}
|
}
|
||||||
@@ -481,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
|||||||
return d.DialContext, nil
|
return d.DialContext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
|
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||||
// target_session_attrs=read-write.
|
// target_session_attrs=read-write.
|
||||||
func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
return result.Err
|
return result.Err
|
||||||
|
|||||||
+9
-8
@@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) {
|
|||||||
name: "target_session_attrs",
|
name: "target_session_attrs",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 5432,
|
Port: 5432,
|
||||||
Database: "mydb",
|
Database: "mydb",
|
||||||
TLSConfig: nil,
|
TLSConfig: nil,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite,
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -416,6 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
|
|||||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||||
|
|
||||||
// Can't test function equality, so just test that they are set or not.
|
// Can't test function equality, so just test that they are set or not.
|
||||||
|
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
|
||||||
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
||||||
|
|
||||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||||
|
|||||||
@@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||||||
for _, fc := range fallbackConfigs {
|
for _, fc := range fallbackConfigs {
|
||||||
pgConn, err = connect(ctx, config, fc)
|
pgConn, err = connect(ctx, config, fc)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return pgConn, nil
|
break
|
||||||
} else if err, ok := err.(*PgError); ok {
|
} else if err, ok := err.(*PgError); ok {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AfterConnect != nil {
|
||||||
|
err := config.AfterConnect(ctx, pgConn)
|
||||||
|
if err != nil {
|
||||||
|
pgConn.conn.Close()
|
||||||
|
return nil, errors.Errorf("AfterConnect: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pgConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||||
@@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
}
|
}
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
pgConn.status = connStatusIdle
|
pgConn.status = connStatusIdle
|
||||||
if config.AfterConnect != nil {
|
if config.ValidateConnect != nil {
|
||||||
err := config.AfterConnect(ctx, pgConn)
|
err := config.ValidateConnect(ctx, pgConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, errors.Errorf("AfterConnect: %v", err)
|
return nil, errors.Errorf("ValidateConnect: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return pgConn, nil
|
return pgConn, nil
|
||||||
|
|||||||
+25
-4
@@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) {
|
|||||||
closeConn(t, conn)
|
closeConn(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectWithAfterConnect(t *testing.T) {
|
func TestConnectWithValidateConnect(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
@@ -200,7 +200,7 @@ func TestConnectWithAfterConnect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
acceptConnCount := 0
|
acceptConnCount := 0
|
||||||
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
||||||
acceptConnCount++
|
acceptConnCount++
|
||||||
if acceptConnCount < 2 {
|
if acceptConnCount < 2 {
|
||||||
return errors.New("reject first conn")
|
return errors.New("reject first conn")
|
||||||
@@ -226,13 +226,13 @@ func TestConnectWithAfterConnect(t *testing.T) {
|
|||||||
assert.True(t, acceptConnCount > 1)
|
assert.True(t, acceptConnCount > 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite
|
config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||||
config.RuntimeParams["default_transaction_read_only"] = "on"
|
config.RuntimeParams["default_transaction_read_only"] = "on"
|
||||||
|
|
||||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
@@ -241,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectWithAfterConnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
||||||
|
_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnPrepareSyntaxError(t *testing.T) {
|
func TestConnPrepareSyntaxError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user