From b1631e8e35b681452a9a8787757e5dbeebf987b7 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Tue, 5 Dec 2023 11:28:06 -0600 Subject: [PATCH] pgconn: add OnPGError to Config for error handling OnPGError is called on every error response received from Postgres and can be used to close connections on specific errors. Defaults to closing on FATAL-severity errors. Fixes #1803 --- pgconn/config.go | 12 ++++++++++++ pgconn/pgconn.go | 11 +++++++++-- pgconn/pgconn_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/pgconn/config.go b/pgconn/config.go index db0170e0..157b8098 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -60,6 +60,11 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // OnPGError is a callback function called when a Postgres error is received by the server. The default handler will close + // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure + // that you close on FATAL errors by returning false. + OnPGError ErrorPGHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -261,6 +266,13 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + OnPGError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 1ccdc4db..71d8e50e 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// ErrorPGHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type ErrorPGHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPGError != nil && !pgConn.config.OnPGError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 80621930..c1d9ae18 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3148,6 +3148,41 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { require.EqualError(t, err, "pipeline has unsynced requests") } +func TestConnOnPGError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.OnPGError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { + require.NotNil(t, c) + require.NotNil(t, pgErr) + // close connection on undefined tables only + if pgErr.Code == "42P01" { + return false + } + return true + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() + assert.Error(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() + assert.Error(t, err) + assert.True(t, pgConn.IsClosed()) +} + func Example() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel()