msgReader implemented in terms of ChunkReader
This should substantially reduce memory allocations and memory copies. It also means that PostgreSQL messages are always entirely buffered in memory before processing begins. This simplifies the message processing code. In particular, Conn.WaitForNotification is dramatically simplified by this change.
This commit is contained in:
+62
-140
@@ -1,26 +1,29 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
)
|
||||
|
||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
type msgReader struct {
|
||||
reader *bufio.Reader
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
shouldLog func(lvl int) bool
|
||||
cr *chunkreader.ChunkReader
|
||||
msgType byte
|
||||
msgBody []byte
|
||||
rp int // read position
|
||||
err error
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
shouldLog func(lvl int) bool
|
||||
}
|
||||
|
||||
// fatal tells rc that a Fatal error has occurred
|
||||
func (r *msgReader) fatal(err error) {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp)
|
||||
}
|
||||
r.err = err
|
||||
}
|
||||
@@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
if r.msgBytesRemaining > 0 {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
n, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
r.msgBytesRemaining -= int32(n)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(5)
|
||||
header, err := r.cr.Next(5)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
@@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
msgType := b[0]
|
||||
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
r.msgType = header[0]
|
||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
|
||||
// Try to preload bufio.Reader with entire message
|
||||
b, err = r.reader.Peek(5 + int(payloadSize))
|
||||
if err != nil && err != bufio.ErrBufferFull {
|
||||
r.msgBody, err = r.cr.Next(bodyLen)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.msgBytesRemaining = payloadSize
|
||||
r.reader.Discard(5)
|
||||
r.rp = 0
|
||||
|
||||
return msgType, nil
|
||||
return r.msgType, nil
|
||||
}
|
||||
|
||||
func (r *msgReader) readByte() byte {
|
||||
@@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining--
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 1 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadByte()
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
b := r.msgBody[r.rp]
|
||||
r.rp++
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return b
|
||||
@@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 2 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:]))
|
||||
r.rp += 2
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
@@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 4 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:]))
|
||||
r.rp += 4
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
@@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 2
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 2 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(2)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint16(binary.BigEndian.Uint16(b))
|
||||
|
||||
r.reader.Discard(2)
|
||||
n := binary.BigEndian.Uint16(r.msgBody[r.rp:])
|
||||
r.rp += 2
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
@@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 4
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 4 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(4)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := uint32(binary.BigEndian.Uint32(b))
|
||||
|
||||
r.reader.Discard(4)
|
||||
n := binary.BigEndian.Uint32(r.msgBody[r.rp:])
|
||||
r.rp += 4
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
@@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= 8
|
||||
if r.msgBytesRemaining < 0 {
|
||||
if len(r.msgBody)-r.rp < 8 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(8)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
n := int64(binary.BigEndian.Uint64(b))
|
||||
|
||||
r.reader.Discard(8)
|
||||
n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:]))
|
||||
r.rp += 8
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return n
|
||||
@@ -246,22 +188,17 @@ func (r *msgReader) readCString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := r.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0)
|
||||
if nullIdx == -1 {
|
||||
r.fatal(errors.New("null terminated string not found"))
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= int32(len(b))
|
||||
if r.msgBytesRemaining < 0 {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
s := string(b[0 : len(b)-1])
|
||||
s := string(r.msgBody[r.rp : r.rp+nullIdx])
|
||||
r.rp += nullIdx + 1
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= countI32
|
||||
if r.msgBytesRemaining < 0 {
|
||||
count := int(countI32)
|
||||
|
||||
if len(r.msgBody)-r.rp < count {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return ""
|
||||
}
|
||||
|
||||
count := int(countI32)
|
||||
var s string
|
||||
|
||||
if r.reader.Buffered() >= count {
|
||||
buf, _ := r.reader.Peek(count)
|
||||
s = string(buf)
|
||||
r.reader.Discard(count)
|
||||
} else {
|
||||
buf := make([]byte, count)
|
||||
_, err := io.ReadFull(r.reader, buf)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return ""
|
||||
}
|
||||
s = string(buf)
|
||||
}
|
||||
s := string(r.msgBody[r.rp : r.rp+count])
|
||||
r.rp += count
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// readBytes reads count bytes and returns as []byte
|
||||
func (r *msgReader) readBytes(count int32) []byte {
|
||||
func (r *msgReader) readBytes(countI32 int32) []byte {
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.msgBytesRemaining -= count
|
||||
if r.msgBytesRemaining < 0 {
|
||||
count := int(countI32)
|
||||
|
||||
if len(r.msgBody)-r.rp < count {
|
||||
r.fatal(errors.New("read past end of message"))
|
||||
return nil
|
||||
}
|
||||
|
||||
b := make([]byte, int(count))
|
||||
b := r.msgBody[r.rp : r.rp+count]
|
||||
r.rp += count
|
||||
|
||||
_, err := io.ReadFull(r.reader, b)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
return nil
|
||||
}
|
||||
r.cr.KeepLast()
|
||||
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp)
|
||||
}
|
||||
|
||||
return b
|
||||
|
||||
Reference in New Issue
Block a user