diff --git a/msg_reader.go b/msg_reader.go index f7b497f7..1f4e67e9 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "io" + "net" ) // msgReader is a helper that reads values from a PostgreSQL message. @@ -35,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) { r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) } - _, err := r.reader.Discard(int(r.msgBytesRemaining)) + n, err := r.reader.Discard(int(r.msgBytesRemaining)) + r.msgBytesRemaining -= int32(n) if err != nil { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } } b, err := r.reader.Peek(5) if err != nil { - r.fatal(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } + msgType := b[0] - r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 + payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + + // Try to preload bufio.Reader with entire message + b, err = r.reader.Peek(5 + int(payloadSize)) + if err != nil && err != bufio.ErrBufferFull { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } + return 0, err + } + + r.msgBytesRemaining = payloadSize r.reader.Discard(5) + return msgType, nil } diff --git a/msg_reader_test.go b/msg_reader_test.go new file mode 100644 index 00000000..2bbd53c9 --- /dev/null +++ b/msg_reader_test.go @@ -0,0 +1,189 @@ +package pgx + +import ( + "bufio" + "net" + "testing" + "time" + + "github.com/jackc/pgmock/pgmsg" +) + +func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { + t.Parallel() + + tests := []struct { + msgType byte + payloadSize int32 + buffered bool + }{ + {1, 50, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 24000, false}, + {9, 4000, true}, + {1, 1500, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 14000, false}, + {9, 0, true}, + {1, 500, true}, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for _, tt := range tests { + _, err = conn.Write([]byte{tt.msgType}) + if err != nil { + t.Fatal(err) + } + + _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, int(tt.payloadSize)) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + for i, tt := range tests { + msgType, err := mr.rxMsg() + if err != nil { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + + if msgType != tt.msgType { + t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) + } + + if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { + t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) + } + } +} + +func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + testCount := 10000 + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for i := 0; i < testCount; i++ { + msgType := byte(i) + + _, err = conn.Write([]byte{msgType}) + if err != nil { + t.Fatal(err) + } + + msgSize := i % 4000 + + _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, msgSize) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + + i := 0 + for { + msgType, err := mr.rxMsg() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + continue + } else { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + } + + expectedMsgType := byte(i) + if msgType != expectedMsgType { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) + } + + expectedMsgSize := i % 4000 + payload := mr.readBytes(mr.msgBytesRemaining) + if mr.err != nil { + t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) + } + if len(payload) != expectedMsgSize { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) + } + + i++ + if i == testCount { + break + } + } +}