2
0

msgReader implemented in terms of ChunkReader

This should substantially reduce memory allocations and memory copies.

It also means that PostgreSQL messages are always entirely buffered in memory
before processing begins. This simplifies the message processing code.

In particular, Conn.WaitForNotification is dramatically simplified by this
change.
This commit is contained in:
Jack Christensen
2017-02-13 20:41:58 -06:00
parent 84802ece05
commit 11b82b3ca4
7 changed files with 130 additions and 467 deletions
+29 -86
View File
@@ -1,7 +1,6 @@
package pgx package pgx
import ( import (
"bufio"
"crypto/md5" "crypto/md5"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
@@ -20,6 +19,8 @@ import (
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/jackc/pgx/chunkreader"
) )
const ( const (
@@ -283,7 +284,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
} }
} }
c.mr.reader = bufio.NewReader(c.conn) c.mr.cr = chunkreader.NewChunkReader(c.conn)
msg := newStartupMessage() msg := newStartupMessage()
@@ -844,9 +845,8 @@ func (c *Conn) Unlisten(channel string) error {
return nil return nil
} }
// WaitForNotification waits for a PostgreSQL notification for up to timeout. // WaitForNotification waits for a PostgreSQL notification.
// If the timeout occurs it returns pgx.ErrNotificationTimeout func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) {
// Return already received notification immediately // Return already received notification immediately
if len(c.notifications) > 0 { if len(c.notifications) > 0 {
notification := c.notifications[0] notification := c.notifications[0]
@@ -854,97 +854,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
return notification, nil return notification, nil
} }
ctx, cancelFn := context.WithTimeout(context.Background(), timeout) err = c.waitForPreviousCancelQuery(ctx)
if err := c.waitForPreviousCancelQuery(ctx); err != nil { if err != nil {
cancelFn()
return nil, err return nil, err
} }
cancelFn()
err = c.initContext(ctx)
if err != nil {
return nil, err
}
defer func() {
err = c.termContext(err)
}()
if err = c.lock(); err != nil {
return nil, err
}
defer func() {
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
}()
if err := c.ensureConnectionReadyForQuery(); err != nil { if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err return nil, err
} }
stopTime := time.Now().Add(timeout)
for { for {
now := time.Now() t, r, err := c.rxMsg()
if now.After(stopTime) {
return nil, ErrNotificationTimeout
}
// If there has been no activity on this connection for a while send a nop message just to ensure
// the connection is alive
nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second)
if nextEnsureAliveTime.Before(now) {
// If the server can't respond to a nop in 15 seconds, assume it's dead
err := c.conn.SetReadDeadline(now.Add(15 * time.Second))
if err != nil {
return nil, err
}
_, err = c.Exec("--;")
if err != nil {
return nil, err
}
c.lastActivityTime = now
}
var deadline time.Time
if stopTime.Before(nextEnsureAliveTime) {
deadline = stopTime
} else {
deadline = nextEnsureAliveTime
}
notification, err := c.waitForNotification(deadline)
if err != ErrNotificationTimeout {
return notification, err
}
}
}
func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
var zeroTime time.Time
for {
// Use SetReadDeadline to implement the timeout. SetReadDeadline will
// cause operations to fail with a *net.OpError that has a Timeout()
// of true. Because the normal pgx rxMsg path considers any error to
// have potentially corrupted the state of the connection, it dies
// on any errors. So to avoid timeout errors in rxMsg we set the
// deadline and peek into the reader. If a timeout error occurs there
// we don't break the pgx connection. If the Peek returns that data
// is available then we turn off the read deadline before the rxMsg.
err := c.conn.SetReadDeadline(deadline)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Wait until there is a byte available before continuing onto the normal msg reading path err = c.processContextFreeMsg(t, r)
_, err = c.mr.reader.Peek(1)
if err != nil { if err != nil {
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
if err, ok := err.(*net.OpError); ok && err.Timeout() {
return nil, ErrNotificationTimeout
}
return nil, err
}
err = c.conn.SetReadDeadline(zeroTime)
if err != nil {
return nil, err
}
var t byte
var r *msgReader
if t, r, err = c.rxMsg(); err == nil {
if err = c.processContextFreeMsg(t, r); err != nil {
return nil, err
}
} else {
return nil, err return nil, err
} }
@@ -1114,7 +1057,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
if c.shouldLog(LogLevelTrace) { if c.shouldLog(LogLevelTrace) {
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody))
} }
return t, &c.mr, err return t, &c.mr, err
@@ -1236,11 +1179,11 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) {
// wrong. So read the count, ignore it, and compute the proper value from // wrong. So read the count, ignore it, and compute the proper value from
// the size of the message. // the size of the message.
r.readInt16() r.readInt16()
parameterCount := r.msgBytesRemaining / 4 parameterCount := len(r.msgBody[r.rp:]) / 4
parameters = make([]OID, 0, parameterCount) parameters = make([]OID, 0, parameterCount)
for i := int32(0); i < parameterCount; i++ { for i := 0; i < parameterCount; i++ {
parameters = append(parameters, r.readOID()) parameters = append(parameters, r.readOID())
} }
return return
+18 -10
View File
@@ -1084,7 +1084,7 @@ func TestListenNotify(t *testing.T) {
mustExec(t, notifier, "notify chat") mustExec(t, notifier, "notify chat")
// when notification is waiting on the socket to be read // when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(time.Second) notification, err := listener.WaitForNotification(context.Background())
if err != nil { if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err) t.Fatalf("Unexpected error on WaitForNotification: %v", err)
} }
@@ -1099,7 +1099,10 @@ func TestListenNotify(t *testing.T) {
if rows.Err() != nil { if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err()) t.Fatalf("Unexpected error on Query: %v", rows.Err())
} }
notification, err = listener.WaitForNotification(0)
ctx, cancelFn := context.WithCancel(context.Background())
cancelFn()
notification, err = listener.WaitForNotification(ctx)
if err != nil { if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err) t.Fatalf("Unexpected error on WaitForNotification: %v", err)
} }
@@ -1108,8 +1111,9 @@ func TestListenNotify(t *testing.T) {
} }
// when timeout occurs // when timeout occurs
notification, err = listener.WaitForNotification(time.Millisecond) ctx, _ = context.WithTimeout(context.Background(), time.Millisecond)
if err != pgx.ErrNotificationTimeout { notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
} }
if notification != nil { if notification != nil {
@@ -1118,7 +1122,7 @@ func TestListenNotify(t *testing.T) {
// listener can listen again after a timeout // listener can listen again after a timeout
mustExec(t, notifier, "notify chat") mustExec(t, notifier, "notify chat")
notification, err = listener.WaitForNotification(time.Second) notification, err = listener.WaitForNotification(context.Background())
if err != nil { if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err) t.Fatalf("Unexpected error on WaitForNotification: %v", err)
} }
@@ -1143,7 +1147,7 @@ func TestUnlistenSpecificChannel(t *testing.T) {
mustExec(t, notifier, "notify unlisten_test") mustExec(t, notifier, "notify unlisten_test")
// when notification is waiting on the socket to be read // when notification is waiting on the socket to be read
notification, err := listener.WaitForNotification(time.Second) notification, err := listener.WaitForNotification(context.Background())
if err != nil { if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err) t.Fatalf("Unexpected error on WaitForNotification: %v", err)
} }
@@ -1163,8 +1167,10 @@ func TestUnlistenSpecificChannel(t *testing.T) {
if rows.Err() != nil { if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", rows.Err()) t.Fatalf("Unexpected error on Query: %v", rows.Err())
} }
notification, err = listener.WaitForNotification(100 * time.Millisecond)
if err != pgx.ErrNotificationTimeout { ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
notification, err = listener.WaitForNotification(ctx)
if err != context.DeadlineExceeded {
t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
} }
} }
@@ -1246,7 +1252,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
// Notify self and WaitForNotification immediately // Notify self and WaitForNotification immediately
mustExec(t, conn, "notify self") mustExec(t, conn, "notify self")
notification, err := conn.WaitForNotification(time.Second) ctx, _ := context.WithTimeout(context.Background(), time.Second)
notification, err := conn.WaitForNotification(ctx)
if err != nil { if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err) t.Fatalf("Unexpected error on WaitForNotification: %v", err)
} }
@@ -1263,7 +1270,8 @@ func TestListenNotifySelfNotification(t *testing.T) {
t.Fatalf("Unexpected error on Query: %v", rows.Err()) t.Fatalf("Unexpected error on Query: %v", rows.Err())
} }
notification, err = conn.WaitForNotification(time.Second) ctx, _ = context.WithTimeout(context.Background(), time.Second)
notification, err = conn.WaitForNotification(ctx)
if err != nil { if err != nil {
t.Fatalf("Unexpected error on WaitForNotification: %v", err) t.Fatalf("Unexpected error on WaitForNotification: %v", err)
} }
+62 -140
View File
@@ -1,26 +1,29 @@
package pgx package pgx
import ( import (
"bufio" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io"
"net" "net"
"github.com/jackc/pgx/chunkreader"
) )
// msgReader is a helper that reads values from a PostgreSQL message. // msgReader is a helper that reads values from a PostgreSQL message.
type msgReader struct { type msgReader struct {
reader *bufio.Reader cr *chunkreader.ChunkReader
msgBytesRemaining int32 msgType byte
err error msgBody []byte
log func(lvl int, msg string, ctx ...interface{}) rp int // read position
shouldLog func(lvl int) bool err error
log func(lvl int, msg string, ctx ...interface{})
shouldLog func(lvl int) bool
} }
// fatal tells rc that a Fatal error has occurred // fatal tells rc that a Fatal error has occurred
func (r *msgReader) fatal(err error) { func (r *msgReader) fatal(err error) {
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp)
} }
r.err = err r.err = err
} }
@@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) {
return 0, r.err return 0, r.err
} }
if r.msgBytesRemaining > 0 { header, err := r.cr.Next(5)
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
}
n, err := r.reader.Discard(int(r.msgBytesRemaining))
r.msgBytesRemaining -= int32(n)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
}
b, err := r.reader.Peek(5)
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err) r.fatal(err)
@@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) {
return 0, err return 0, err
} }
msgType := b[0] r.msgType = header[0]
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
// Try to preload bufio.Reader with entire message r.msgBody, err = r.cr.Next(bodyLen)
b, err = r.reader.Peek(5 + int(payloadSize)) if err != nil {
if err != nil && err != bufio.ErrBufferFull {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err) r.fatal(err)
} }
return 0, err return 0, err
} }
r.msgBytesRemaining = payloadSize r.rp = 0
r.reader.Discard(5)
return msgType, nil return r.msgType, nil
} }
func (r *msgReader) readByte() byte { func (r *msgReader) readByte() byte {
@@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte {
return 0 return 0
} }
r.msgBytesRemaining-- if len(r.msgBody)-r.rp < 1 {
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.ReadByte() b := r.msgBody[r.rp]
if err != nil { r.rp++
r.fatal(err)
return 0
}
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp)
} }
return b return b
@@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 {
return 0 return 0
} }
r.msgBytesRemaining -= 2 if len(r.msgBody)-r.rp < 2 {
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.Peek(2) n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:]))
if err != nil { r.rp += 2
r.fatal(err)
return 0
}
n := int16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp)
} }
return n return n
@@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 {
return 0 return 0
} }
r.msgBytesRemaining -= 4 if len(r.msgBody)-r.rp < 4 {
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.Peek(4) n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:]))
if err != nil { r.rp += 4
r.fatal(err)
return 0
}
n := int32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp)
} }
return n return n
@@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 {
return 0 return 0
} }
r.msgBytesRemaining -= 2 if len(r.msgBody)-r.rp < 2 {
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.Peek(2) n := binary.BigEndian.Uint16(r.msgBody[r.rp:])
if err != nil { r.rp += 2
r.fatal(err)
return 0
}
n := uint16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp)
} }
return n return n
@@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 {
return 0 return 0
} }
r.msgBytesRemaining -= 4 if len(r.msgBody)-r.rp < 4 {
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.Peek(4) n := binary.BigEndian.Uint32(r.msgBody[r.rp:])
if err != nil { r.rp += 4
r.fatal(err)
return 0
}
n := uint32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp)
} }
return n return n
@@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 {
return 0 return 0
} }
r.msgBytesRemaining -= 8 if len(r.msgBody)-r.rp < 8 {
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return 0 return 0
} }
b, err := r.reader.Peek(8) n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:]))
if err != nil { r.rp += 8
r.fatal(err)
return 0
}
n := int64(binary.BigEndian.Uint64(b))
r.reader.Discard(8)
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp)
} }
return n return n
@@ -246,22 +188,17 @@ func (r *msgReader) readCString() string {
return "" return ""
} }
b, err := r.reader.ReadBytes(0) nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0)
if err != nil { if nullIdx == -1 {
r.fatal(err) r.fatal(errors.New("null terminated string not found"))
return "" return ""
} }
r.msgBytesRemaining -= int32(len(b)) s := string(r.msgBody[r.rp : r.rp+nullIdx])
if r.msgBytesRemaining < 0 { r.rp += nullIdx + 1
r.fatal(errors.New("read past end of message"))
return ""
}
s := string(b[0 : len(b)-1])
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp)
} }
return s return s
@@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string {
return "" return ""
} }
r.msgBytesRemaining -= countI32 count := int(countI32)
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < count {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return "" return ""
} }
count := int(countI32) s := string(r.msgBody[r.rp : r.rp+count])
var s string r.rp += count
if r.reader.Buffered() >= count {
buf, _ := r.reader.Peek(count)
s = string(buf)
r.reader.Discard(count)
} else {
buf := make([]byte, count)
_, err := io.ReadFull(r.reader, buf)
if err != nil {
r.fatal(err)
return ""
}
s = string(buf)
}
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp)
} }
return s return s
} }
// readBytes reads count bytes and returns as []byte // readBytes reads count bytes and returns as []byte
func (r *msgReader) readBytes(count int32) []byte { func (r *msgReader) readBytes(countI32 int32) []byte {
if r.err != nil { if r.err != nil {
return nil return nil
} }
r.msgBytesRemaining -= count count := int(countI32)
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < count {
r.fatal(errors.New("read past end of message")) r.fatal(errors.New("read past end of message"))
return nil return nil
} }
b := make([]byte, int(count)) b := r.msgBody[r.rp : r.rp+count]
r.rp += count
_, err := io.ReadFull(r.reader, b) r.cr.KeepLast()
if err != nil {
r.fatal(err)
return nil
}
if r.shouldLog(LogLevelTrace) { if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp)
} }
return b return b
-189
View File
@@ -1,189 +0,0 @@
package pgx
import (
"bufio"
"net"
"testing"
"time"
"github.com/jackc/pgmock/pgmsg"
)
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
t.Parallel()
tests := []struct {
msgType byte
payloadSize int32
buffered bool
}{
{1, 50, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 24000, false},
{9, 4000, true},
{1, 1500, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 14000, false},
{9, 0, true},
{1, 500, true},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for _, tt := range tests {
_, err = conn.Write([]byte{tt.msgType})
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, int(tt.payloadSize))
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
for i, tt := range tests {
msgType, err := mr.rxMsg()
if err != nil {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
if msgType != tt.msgType {
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
}
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
}
}
}
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
testCount := 10000
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for i := 0; i < testCount; i++ {
msgType := byte(i)
_, err = conn.Write([]byte{msgType})
if err != nil {
t.Fatal(err)
}
msgSize := i % 4000
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, msgSize)
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
i := 0
for {
msgType, err := mr.rxMsg()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
continue
} else {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
}
expectedMsgType := byte(i)
if msgType != expectedMsgType {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
}
expectedMsgSize := i % 4000
payload := mr.readBytes(mr.msgBytesRemaining)
if mr.err != nil {
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
}
if len(payload) != expectedMsgSize {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
}
i++
if i == testCount {
break
}
}
}
+13 -35
View File
@@ -1,9 +1,9 @@
package pgx package pgx
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net"
"time" "time"
) )
@@ -234,7 +234,7 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
walStart := reader.readInt64() walStart := reader.readInt64()
serverWalEnd := reader.readInt64() serverWalEnd := reader.readInt64()
serverTime := reader.readInt64() serverTime := reader.readInt64()
walData := reader.readBytes(reader.msgBytesRemaining) walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp))
walMessage := WalMessage{WalStart: uint64(walStart), walMessage := WalMessage{WalStart: uint64(walStart),
ServerWalEnd: uint64(serverWalEnd), ServerWalEnd: uint64(serverWalEnd),
ServerTime: uint64(serverTime), ServerTime: uint64(serverTime),
@@ -261,47 +261,23 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
return return
} }
// Wait for a single replication message up to timeout time. // Wait for a single replication message.
// //
// Properly using this requires some knowledge of the postgres replication mechanisms, // Properly using this requires some knowledge of the postgres replication mechanisms,
// as the client can receive both WAL data (the ultimate payload) and server heartbeat // as the client can receive both WAL data (the ultimate payload) and server heartbeat
// updates. The caller also must send standby status updates in order to keep the connection // updates. The caller also must send standby status updates in order to keep the connection
// alive and working. // alive and working.
// //
// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified // This returns the context error when there is no replication message before
// duration. // the context is canceled.
func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) {
var zeroTime time.Time err = rc.c.initContext(ctx)
deadline := time.Now().Add(timeout)
// Use SetReadDeadline to implement the timeout. SetReadDeadline will
// cause operations to fail with a *net.OpError that has a Timeout()
// of true. Because the normal pgx rxMsg path considers any error to
// have potentially corrupted the state of the connection, it dies
// on any errors. So to avoid timeout errors in rxMsg we set the
// deadline and peek into the reader. If a timeout error occurs there
// we don't break the pgx connection. If the Peek returns that data
// is available then we turn off the read deadline before the rxMsg.
err = rc.c.conn.SetReadDeadline(deadline)
if err != nil {
return nil, err
}
// Wait until there is a byte available before continuing onto the normal msg reading path
_, err = rc.c.mr.reader.Peek(1)
if err != nil {
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
if err, ok := err.(*net.OpError); ok && err.Timeout() {
return nil, ErrNotificationTimeout
}
return nil, err
}
err = rc.c.conn.SetReadDeadline(zeroTime)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
err = rc.c.termContext(err)
}()
return rc.readReplicationMessage() return rc.readReplicationMessage()
} }
@@ -401,12 +377,14 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
return return
} }
ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout)
// The first replication message that comes back here will be (in a success case) // The first replication message that comes back here will be (in a success case)
// a empty CopyBoth that is (apparently) sent as the confirmation that the replication has // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has
// started. This call will either return nil, nil or if it returns an error // started. This call will either return nil, nil or if it returns an error
// that indicates the start replication command failed // that indicates the start replication command failed
var r *ReplicationMessage var r *ReplicationMessage
r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) r, err = rc.WaitForReplicationMessage(ctx)
if err != nil && r != nil { if err != nil && r != nil {
if rc.c.shouldLog(LogLevelError) { if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unxpected replication message %v", r) rc.c.log(LogLevelError, "Unxpected replication message %v", r)
+5 -5
View File
@@ -1,6 +1,7 @@
package pgx_test package pgx_test
import ( import (
"context"
"fmt" "fmt"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"reflect" "reflect"
@@ -88,11 +89,10 @@ func TestSimpleReplicationConnection(t *testing.T) {
for { for {
var message *pgx.ReplicationMessage var message *pgx.ReplicationMessage
message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second)) ctx, _ := context.WithTimeout(context.Background(), time.Second)
if err != nil { message, err = replicationConn.WaitForReplicationMessage(ctx)
if err != pgx.ErrNotificationTimeout { if err != nil && err != context.DeadlineExceeded {
t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err))
}
} }
if message != nil { if message != nil {
if message.WalMessage != nil { if message.WalMessage != nil {
+3 -2
View File
@@ -244,8 +244,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
return err return err
} }
_, err = conn.WaitForNotification(100 * time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
if err == pgx.ErrNotificationTimeout { _, err = conn.WaitForNotification(ctx)
if err == context.DeadlineExceeded {
return nil return nil
} }
return err return err