From 11b82b3ca4bda887a7aca04e0ef1cc513798b744 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Feb 2017 20:41:58 -0600 Subject: [PATCH] msgReader implemented in terms of ChunkReader This should substantially reduce memory allocations and memory copies. It also means that PostgreSQL messages are always entirely buffered in memory before processing begins. This simplifies the message processing code. In particular, Conn.WaitForNotification is dramatically simplified by this change. --- conn.go | 115 +++++++------------------ conn_test.go | 28 +++--- msg_reader.go | 202 ++++++++++++++------------------------------ msg_reader_test.go | 189 ----------------------------------------- replication.go | 48 +++-------- replication_test.go | 10 +-- stress_test.go | 5 +- 7 files changed, 130 insertions(+), 467 deletions(-) delete mode 100644 msg_reader_test.go diff --git a/conn.go b/conn.go index 07422a32..a8b0b22c 100644 --- a/conn.go +++ b/conn.go @@ -1,7 +1,6 @@ package pgx import ( - "bufio" "crypto/md5" "crypto/tls" "encoding/binary" @@ -20,6 +19,8 @@ import ( "strings" "sync/atomic" "time" + + "github.com/jackc/pgx/chunkreader" ) const ( @@ -283,7 +284,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.mr.reader = bufio.NewReader(c.conn) + c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -844,9 +845,8 @@ func (c *Conn) Unlisten(channel string) error { return nil } -// WaitForNotification waits for a PostgreSQL notification for up to timeout. -// If the timeout occurs it returns pgx.ErrNotificationTimeout -func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) { +// WaitForNotification waits for a PostgreSQL notification. +func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) { // Return already received notification immediately if len(c.notifications) > 0 { notification := c.notifications[0] @@ -854,97 +854,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } - ctx, cancelFn := context.WithTimeout(context.Background(), timeout) - if err := c.waitForPreviousCancelQuery(ctx); err != nil { - cancelFn() + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { return nil, err } - cancelFn() + + err = c.initContext(ctx) + if err != nil { + return nil, err + } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return nil, err + } + defer func() { + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() if err := c.ensureConnectionReadyForQuery(); err != nil { return nil, err } - stopTime := time.Now().Add(timeout) - for { - now := time.Now() - - if now.After(stopTime) { - return nil, ErrNotificationTimeout - } - - // If there has been no activity on this connection for a while send a nop message just to ensure - // the connection is alive - nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second) - if nextEnsureAliveTime.Before(now) { - // If the server can't respond to a nop in 15 seconds, assume it's dead - err := c.conn.SetReadDeadline(now.Add(15 * time.Second)) - if err != nil { - return nil, err - } - - _, err = c.Exec("--;") - if err != nil { - return nil, err - } - - c.lastActivityTime = now - } - - var deadline time.Time - if stopTime.Before(nextEnsureAliveTime) { - deadline = stopTime - } else { - deadline = nextEnsureAliveTime - } - - notification, err := c.waitForNotification(deadline) - if err != ErrNotificationTimeout { - return notification, err - } - } -} - -func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { - var zeroTime time.Time - - for { - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err := c.conn.SetReadDeadline(deadline) + t, r, err := c.rxMsg() if err != nil { return nil, err } - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.mr.reader.Peek(1) + err = c.processContextFreeMsg(t, r) if err != nil { - c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = c.conn.SetReadDeadline(zeroTime) - if err != nil { - return nil, err - } - - var t byte - var r *msgReader - if t, r, err = c.rxMsg(); err == nil { - if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err - } - } else { return nil, err } @@ -1114,7 +1057,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { c.lastActivityTime = time.Now() if c.shouldLog(LogLevelTrace) { - c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) + c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody)) } return t, &c.mr, err @@ -1236,11 +1179,11 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { // wrong. So read the count, ignore it, and compute the proper value from // the size of the message. r.readInt16() - parameterCount := r.msgBytesRemaining / 4 + parameterCount := len(r.msgBody[r.rp:]) / 4 parameters = make([]OID, 0, parameterCount) - for i := int32(0); i < parameterCount; i++ { + for i := 0; i < parameterCount; i++ { parameters = append(parameters, r.readOID()) } return diff --git a/conn_test.go b/conn_test.go index a8398507..63b486a6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1084,7 +1084,7 @@ func TestListenNotify(t *testing.T) { mustExec(t, notifier, "notify chat") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1099,7 +1099,10 @@ func TestListenNotify(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(0) + + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() + notification, err = listener.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1108,8 +1111,9 @@ func TestListenNotify(t *testing.T) { } // when timeout occurs - notification, err = listener.WaitForNotification(time.Millisecond) - if err != pgx.ErrNotificationTimeout { + ctx, _ = context.WithTimeout(context.Background(), time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } if notification != nil { @@ -1118,7 +1122,7 @@ func TestListenNotify(t *testing.T) { // listener can listen again after a timeout mustExec(t, notifier, "notify chat") - notification, err = listener.WaitForNotification(time.Second) + notification, err = listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1143,7 +1147,7 @@ func TestUnlistenSpecificChannel(t *testing.T) { mustExec(t, notifier, "notify unlisten_test") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1163,8 +1167,10 @@ func TestUnlistenSpecificChannel(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(100 * time.Millisecond) - if err != pgx.ErrNotificationTimeout { + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } } @@ -1246,7 +1252,8 @@ func TestListenNotifySelfNotification(t *testing.T) { // Notify self and WaitForNotification immediately mustExec(t, conn, "notify self") - notification, err := conn.WaitForNotification(time.Second) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + notification, err := conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1263,7 +1270,8 @@ func TestListenNotifySelfNotification(t *testing.T) { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = conn.WaitForNotification(time.Second) + ctx, _ = context.WithTimeout(context.Background(), time.Second) + notification, err = conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } diff --git a/msg_reader.go b/msg_reader.go index f507c198..53e944bb 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -1,26 +1,29 @@ package pgx import ( - "bufio" + "bytes" "encoding/binary" "errors" - "io" "net" + + "github.com/jackc/pgx/chunkreader" ) // msgReader is a helper that reads values from a PostgreSQL message. type msgReader struct { - reader *bufio.Reader - msgBytesRemaining int32 - err error - log func(lvl int, msg string, ctx ...interface{}) - shouldLog func(lvl int) bool + cr *chunkreader.ChunkReader + msgType byte + msgBody []byte + rp int // read position + err error + log func(lvl int, msg string, ctx ...interface{}) + shouldLog func(lvl int) bool } // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp) } r.err = err } @@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) { return 0, r.err } - if r.msgBytesRemaining > 0 { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) - } - - n, err := r.reader.Discard(int(r.msgBytesRemaining)) - r.msgBytesRemaining -= int32(n) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - } - - b, err := r.reader.Peek(5) + header, err := r.cr.Next(5) if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) @@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) { return 0, err } - msgType := b[0] - payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + r.msgType = header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - // Try to preload bufio.Reader with entire message - b, err = r.reader.Peek(5 + int(payloadSize)) - if err != nil && err != bufio.ErrBufferFull { + r.msgBody, err = r.cr.Next(bodyLen) + if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) } return 0, err } - r.msgBytesRemaining = payloadSize - r.reader.Discard(5) + r.rp = 0 - return msgType, nil + return r.msgType, nil } func (r *msgReader) readByte() byte { @@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte { return 0 } - r.msgBytesRemaining-- - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 1 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.ReadByte() - if err != nil { - r.fatal(err) - return 0 - } + b := r.msgBody[r.rp] + r.rp++ if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp) } return b @@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 { return 0 } - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := int16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) + n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:])) + r.rp += 2 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 { return 0 } - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := int32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) + n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:])) + r.rp += 4 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 { return 0 } - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) + n := binary.BigEndian.Uint16(r.msgBody[r.rp:]) + r.rp += 2 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 { return 0 } - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) + n := binary.BigEndian.Uint32(r.msgBody[r.rp:]) + r.rp += 4 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 { return 0 } - r.msgBytesRemaining -= 8 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 8 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(8) - if err != nil { - r.fatal(err) - return 0 - } - - n := int64(binary.BigEndian.Uint64(b)) - - r.reader.Discard(8) + n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:])) + r.rp += 8 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -246,22 +188,17 @@ func (r *msgReader) readCString() string { return "" } - b, err := r.reader.ReadBytes(0) - if err != nil { - r.fatal(err) + nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0) + if nullIdx == -1 { + r.fatal(errors.New("null terminated string not found")) return "" } - r.msgBytesRemaining -= int32(len(b)) - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return "" - } - - s := string(b[0 : len(b)-1]) + s := string(r.msgBody[r.rp : r.rp+nullIdx]) + r.rp += nullIdx + 1 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s @@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string { return "" } - r.msgBytesRemaining -= countI32 - if r.msgBytesRemaining < 0 { + count := int(countI32) + + if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return "" } - count := int(countI32) - var s string - - if r.reader.Buffered() >= count { - buf, _ := r.reader.Peek(count) - s = string(buf) - r.reader.Discard(count) - } else { - buf := make([]byte, count) - _, err := io.ReadFull(r.reader, buf) - if err != nil { - r.fatal(err) - return "" - } - s = string(buf) - } + s := string(r.msgBody[r.rp : r.rp+count]) + r.rp += count if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s } // readBytes reads count bytes and returns as []byte -func (r *msgReader) readBytes(count int32) []byte { +func (r *msgReader) readBytes(countI32 int32) []byte { if r.err != nil { return nil } - r.msgBytesRemaining -= count - if r.msgBytesRemaining < 0 { + count := int(countI32) + + if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return nil } - b := make([]byte, int(count)) + b := r.msgBody[r.rp : r.rp+count] + r.rp += count - _, err := io.ReadFull(r.reader, b) - if err != nil { - r.fatal(err) - return nil - } + r.cr.KeepLast() if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp) } return b diff --git a/msg_reader_test.go b/msg_reader_test.go deleted file mode 100644 index 2bbd53c9..00000000 --- a/msg_reader_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package pgx - -import ( - "bufio" - "net" - "testing" - "time" - - "github.com/jackc/pgmock/pgmsg" -) - -func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { - t.Parallel() - - tests := []struct { - msgType byte - payloadSize int32 - buffered bool - }{ - {1, 50, true}, - {2, 0, true}, - {3, 500, true}, - {4, 1050, true}, - {5, 1500, true}, - {6, 1500, true}, - {7, 4000, true}, - {8, 24000, false}, - {9, 4000, true}, - {1, 1500, true}, - {2, 0, true}, - {3, 500, true}, - {4, 1050, true}, - {5, 1500, true}, - {6, 1500, true}, - {7, 4000, true}, - {8, 14000, false}, - {9, 0, true}, - {1, 500, true}, - } - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - go func() { - var bigEndian pgmsg.BigEndianBuf - - conn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - for _, tt := range tests { - _, err = conn.Write([]byte{tt.msgType}) - if err != nil { - t.Fatal(err) - } - - _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) - if err != nil { - t.Fatal(err) - } - - payload := make([]byte, int(tt.payloadSize)) - _, err = conn.Write(payload) - if err != nil { - t.Fatal(err) - } - } - }() - - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - mr := &msgReader{ - reader: bufio.NewReader(conn), - shouldLog: func(int) bool { return false }, - } - - for i, tt := range tests { - msgType, err := mr.rxMsg() - if err != nil { - t.Fatalf("%d. Unexpected error: %v", i, err) - } - - if msgType != tt.msgType { - t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) - } - - if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { - t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) - } - } -} - -func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { - t.Parallel() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - testCount := 10000 - - go func() { - var bigEndian pgmsg.BigEndianBuf - - conn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - for i := 0; i < testCount; i++ { - msgType := byte(i) - - _, err = conn.Write([]byte{msgType}) - if err != nil { - t.Fatal(err) - } - - msgSize := i % 4000 - - _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) - if err != nil { - t.Fatal(err) - } - - payload := make([]byte, msgSize) - _, err = conn.Write(payload) - if err != nil { - t.Fatal(err) - } - } - }() - - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - mr := &msgReader{ - reader: bufio.NewReader(conn), - shouldLog: func(int) bool { return false }, - } - - conn.SetReadDeadline(time.Now().Add(time.Millisecond)) - - i := 0 - for { - msgType, err := mr.rxMsg() - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - conn.SetReadDeadline(time.Now().Add(time.Millisecond)) - continue - } else { - t.Fatalf("%d. Unexpected error: %v", i, err) - } - } - - expectedMsgType := byte(i) - if msgType != expectedMsgType { - t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) - } - - expectedMsgSize := i % 4000 - payload := mr.readBytes(mr.msgBytesRemaining) - if mr.err != nil { - t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) - } - if len(payload) != expectedMsgSize { - t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) - } - - i++ - if i == testCount { - break - } - } -} diff --git a/replication.go b/replication.go index 0acc9df9..a3e58fa3 100644 --- a/replication.go +++ b/replication.go @@ -1,9 +1,9 @@ package pgx import ( + "context" "errors" "fmt" - "net" "time" ) @@ -234,7 +234,7 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err walStart := reader.readInt64() serverWalEnd := reader.readInt64() serverTime := reader.readInt64() - walData := reader.readBytes(reader.msgBytesRemaining) + walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp)) walMessage := WalMessage{WalStart: uint64(walStart), ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), @@ -261,47 +261,23 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err return } -// Wait for a single replication message up to timeout time. +// Wait for a single replication message. // // Properly using this requires some knowledge of the postgres replication mechanisms, // as the client can receive both WAL data (the ultimate payload) and server heartbeat // updates. The caller also must send standby status updates in order to keep the connection // alive and working. // -// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified -// duration. -func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { - var zeroTime time.Time - - deadline := time.Now().Add(timeout) - - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err = rc.c.conn.SetReadDeadline(deadline) - if err != nil { - return nil, err - } - - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.mr.reader.Peek(1) - if err != nil { - rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = rc.c.conn.SetReadDeadline(zeroTime) +// This returns the context error when there is no replication message before +// the context is canceled. +func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) { + err = rc.c.initContext(ctx) if err != nil { return nil, err } + defer func() { + err = rc.c.termContext(err) + }() return rc.readReplicationMessage() } @@ -401,12 +377,14 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti return } + ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) + // The first replication message that comes back here will be (in a success case) // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has // started. This call will either return nil, nil or if it returns an error // that indicates the start replication command failed var r *ReplicationMessage - r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) + r, err = rc.WaitForReplicationMessage(ctx) if err != nil && r != nil { if rc.c.shouldLog(LogLevelError) { rc.c.log(LogLevelError, "Unxpected replication message %v", r) diff --git a/replication_test.go b/replication_test.go index 4f810c78..2c2d0af5 100644 --- a/replication_test.go +++ b/replication_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "fmt" "github.com/jackc/pgx" "reflect" @@ -88,11 +89,10 @@ func TestSimpleReplicationConnection(t *testing.T) { for { var message *pgx.ReplicationMessage - message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second)) - if err != nil { - if err != pgx.ErrNotificationTimeout { - t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) - } + ctx, _ := context.WithTimeout(context.Background(), time.Second) + message, err = replicationConn.WaitForReplicationMessage(ctx) + if err != nil && err != context.DeadlineExceeded { + t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) } if message != nil { if message.WalMessage != nil { diff --git a/stress_test.go b/stress_test.go index 72d48a5c..82979fd6 100644 --- a/stress_test.go +++ b/stress_test.go @@ -244,8 +244,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { return err } - _, err = conn.WaitForNotification(100 * time.Millisecond) - if err == pgx.ErrNotificationTimeout { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + _, err = conn.WaitForNotification(ctx) + if err == context.DeadlineExceeded { return nil } return err