committed by
Hashrocket Workstation
parent
9e321af35c
commit
b271dd5bf1
@@ -12,6 +12,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionParameters contains all the options used to establish a connection.
|
// ConnectionParameters contains all the options used to establish a connection.
|
||||||
@@ -36,6 +37,7 @@ type Connection struct {
|
|||||||
parameters ConnectionParameters // parameters used when establishing this connection
|
parameters ConnectionParameters // parameters used when establishing this connection
|
||||||
TxStatus byte
|
TxStatus byte
|
||||||
preparedStatements map[string]*preparedStatement
|
preparedStatements map[string]*preparedStatement
|
||||||
|
notifications []*Notification
|
||||||
}
|
}
|
||||||
|
|
||||||
type preparedStatement struct {
|
type preparedStatement struct {
|
||||||
@@ -44,6 +46,12 @@ type preparedStatement struct {
|
|||||||
ParameterOids []Oid
|
ParameterOids []Oid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Notification struct {
|
||||||
|
Pid int32 // backend pid that sent the notification
|
||||||
|
Channel string // channel from which notification was received
|
||||||
|
Payload string
|
||||||
|
}
|
||||||
|
|
||||||
// NotSingleRowError is returned when exactly 1 row is expected, but 0 or more than
|
// NotSingleRowError is returned when exactly 1 row is expected, but 0 or more than
|
||||||
// 1 row is returned
|
// 1 row is returned
|
||||||
type NotSingleRowError struct {
|
type NotSingleRowError struct {
|
||||||
@@ -336,6 +344,45 @@ func (c *Connection) Deallocate(name string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Listen establishes a PostgreSQL listen/notify to channel
|
||||||
|
func (c *Connection) Listen(channel string) (err error) {
|
||||||
|
_, err = c.Execute("listen " + channel)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForNotification waits for a PostgreSQL notification for up to timeout
|
||||||
|
func (c *Connection) WaitForNotification(timeout time.Duration) (notification *Notification, err error) {
|
||||||
|
err = c.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
var zeroTime time.Time
|
||||||
|
e := c.conn.SetReadDeadline(zeroTime)
|
||||||
|
if err == nil && e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if len(c.notifications) > 0 {
|
||||||
|
notification = c.notifications[0]
|
||||||
|
c.notifications = c.notifications[1:]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var t byte
|
||||||
|
var r *MessageReader
|
||||||
|
if t, r, err = c.rxMsg(); err == nil {
|
||||||
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) {
|
func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) {
|
||||||
if ps, present := c.preparedStatements[sql]; present {
|
if ps, present := c.preparedStatements[sql]; present {
|
||||||
return c.sendPreparedQuery(ps, arguments...)
|
return c.sendPreparedQuery(ps, arguments...)
|
||||||
@@ -525,6 +572,8 @@ func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error)
|
|||||||
return c.rxErrorResponse(r)
|
return c.rxErrorResponse(r)
|
||||||
case noticeResponse:
|
case noticeResponse:
|
||||||
return nil
|
return nil
|
||||||
|
case notificationResponse:
|
||||||
|
return c.rxNotificationResponse(r)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Received unknown message type: %c", t)
|
return fmt.Errorf("Received unknown message type: %c", t)
|
||||||
}
|
}
|
||||||
@@ -661,6 +710,15 @@ func (c *Connection) rxCommandComplete(r *MessageReader) string {
|
|||||||
return r.ReadString()
|
return r.ReadString()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) rxNotificationResponse(r *MessageReader) (err error) {
|
||||||
|
n := new(Notification)
|
||||||
|
n.Pid = r.ReadInt32()
|
||||||
|
n.Channel = r.ReadString()
|
||||||
|
n.Payload = r.ReadString()
|
||||||
|
c.notifications = append(c.notifications, n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {
|
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {
|
||||||
_, err = c.conn.Write(msg.Bytes())
|
_, err = c.conn.Write(msg.Bytes())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/JackC/pgx"
|
"github.com/JackC/pgx"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnect(t *testing.T) {
|
func TestConnect(t *testing.T) {
|
||||||
@@ -565,3 +567,47 @@ func TestTransactionIso(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListenNotify(t *testing.T) {
|
||||||
|
listener, err := pgx.Connect(*defaultConnectionParameters)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
if err := listener.Listen("chat"); err != nil {
|
||||||
|
t.Fatalf("Unable to start listening: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
notifier := getSharedConnection()
|
||||||
|
mustExecute(t, notifier, "notify chat")
|
||||||
|
|
||||||
|
// when notification is waiting on the socket to be read
|
||||||
|
notification, err := listener.WaitForNotification(time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||||
|
}
|
||||||
|
if notification.Channel != "chat" {
|
||||||
|
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// when notification has already been read during previous query
|
||||||
|
mustExecute(t, notifier, "notify chat")
|
||||||
|
mustSelectValue(t, listener, "select 1")
|
||||||
|
notification, err = listener.WaitForNotification(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error on WaitForNotification: %v", err)
|
||||||
|
}
|
||||||
|
if notification.Channel != "chat" {
|
||||||
|
t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// when timeout occurs
|
||||||
|
notification, err = listener.WaitForNotification(time.Millisecond)
|
||||||
|
if _, ok := err.(*net.OpError); !ok {
|
||||||
|
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
|
||||||
|
}
|
||||||
|
if notification != nil {
|
||||||
|
t.Errorf("WaitForNotification returned an unexpected notification: %v", notification)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ const (
|
|||||||
parseComplete = '1'
|
parseComplete = '1'
|
||||||
parameterDescription = 't'
|
parameterDescription = 't'
|
||||||
bindComplete = '2'
|
bindComplete = '2'
|
||||||
|
notificationResponse = 'A'
|
||||||
)
|
)
|
||||||
|
|
||||||
type startupMessage struct {
|
type startupMessage struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user