diff --git a/config.go b/config.go index 40cbd0bb..fec1fedf 100644 --- a/config.go +++ b/config.go @@ -54,6 +54,9 @@ type Config struct { // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index 9277d4a8..b2ffe7ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -50,6 +50,13 @@ func (pe *PgError) Error() string { // LISTEN/NOTIFY notification. type Notice PgError +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) @@ -59,6 +66,12 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // notification. type NoticeHandler func(*PgConn, *Notice) +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") @@ -284,6 +297,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { if pgConn.Config.OnNotice != nil { pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) } + case *pgproto3.NotificationResponse: + if pgConn.Config.OnNotification != nil { + pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index 90f99325..ad538257 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -597,6 +597,38 @@ end$$;`) ensureConnValid(t, pgConn) } +func TestConnOnNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.Nil(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.Nil(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil {