2
0

msgReader implemented in terms of ChunkReader

This should substantially reduce memory allocations and memory copies.

It also means that PostgreSQL messages are always entirely buffered in memory
before processing begins. This simplifies the message processing code.

In particular, Conn.WaitForNotification is dramatically simplified by this
change.
This commit is contained in:
Jack Christensen
2017-02-13 20:41:58 -06:00
parent 84802ece05
commit 11b82b3ca4
7 changed files with 130 additions and 467 deletions
+62 -140
View File
@@ -1,26 +1,29 @@
package pgx
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"github.com/jackc/pgx/chunkreader"
)
// msgReader is a helper that reads values from a PostgreSQL message.
type msgReader struct {
reader *bufio.Reader
msgBytesRemaining int32
err error
log func(lvl int, msg string, ctx ...interface{})
shouldLog func(lvl int) bool
cr *chunkreader.ChunkReader
msgType byte
msgBody []byte
rp int // read position
err error
log func(lvl int, msg string, ctx ...interface{})
shouldLog func(lvl int) bool
}
// fatal tells rc that a Fatal error has occurred
func (r *msgReader) fatal(err error) {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp)
}
r.err = err
}
@@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) {
return 0, r.err
}
if r.msgBytesRemaining > 0 {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
}
n, err := r.reader.Discard(int(r.msgBytesRemaining))
r.msgBytesRemaining -= int32(n)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
}
b, err := r.reader.Peek(5)
header, err := r.cr.Next(5)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
@@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) {
return 0, err
}
msgType := b[0]
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
r.msgType = header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
// Try to preload bufio.Reader with entire message
b, err = r.reader.Peek(5 + int(payloadSize))
if err != nil && err != bufio.ErrBufferFull {
r.msgBody, err = r.cr.Next(bodyLen)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
r.msgBytesRemaining = payloadSize
r.reader.Discard(5)
r.rp = 0
return msgType, nil
return r.msgType, nil
}
func (r *msgReader) readByte() byte {
@@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte {
return 0
}
r.msgBytesRemaining--
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 1 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.ReadByte()
if err != nil {
r.fatal(err)
return 0
}
b := r.msgBody[r.rp]
r.rp++
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp)
}
return b
@@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 2 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
}
n := int16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:]))
r.rp += 2
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 4 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
}
n := int32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:]))
r.rp += 4
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 {
return 0
}
r.msgBytesRemaining -= 2
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 2 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(2)
if err != nil {
r.fatal(err)
return 0
}
n := uint16(binary.BigEndian.Uint16(b))
r.reader.Discard(2)
n := binary.BigEndian.Uint16(r.msgBody[r.rp:])
r.rp += 2
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 {
return 0
}
r.msgBytesRemaining -= 4
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 4 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(4)
if err != nil {
r.fatal(err)
return 0
}
n := uint32(binary.BigEndian.Uint32(b))
r.reader.Discard(4)
n := binary.BigEndian.Uint32(r.msgBody[r.rp:])
r.rp += 4
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 {
return 0
}
r.msgBytesRemaining -= 8
if r.msgBytesRemaining < 0 {
if len(r.msgBody)-r.rp < 8 {
r.fatal(errors.New("read past end of message"))
return 0
}
b, err := r.reader.Peek(8)
if err != nil {
r.fatal(err)
return 0
}
n := int64(binary.BigEndian.Uint64(b))
r.reader.Discard(8)
n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:]))
r.rp += 8
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp)
}
return n
@@ -246,22 +188,17 @@ func (r *msgReader) readCString() string {
return ""
}
b, err := r.reader.ReadBytes(0)
if err != nil {
r.fatal(err)
nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0)
if nullIdx == -1 {
r.fatal(errors.New("null terminated string not found"))
return ""
}
r.msgBytesRemaining -= int32(len(b))
if r.msgBytesRemaining < 0 {
r.fatal(errors.New("read past end of message"))
return ""
}
s := string(b[0 : len(b)-1])
s := string(r.msgBody[r.rp : r.rp+nullIdx])
r.rp += nullIdx + 1
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp)
}
return s
@@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string {
return ""
}
r.msgBytesRemaining -= countI32
if r.msgBytesRemaining < 0 {
count := int(countI32)
if len(r.msgBody)-r.rp < count {
r.fatal(errors.New("read past end of message"))
return ""
}
count := int(countI32)
var s string
if r.reader.Buffered() >= count {
buf, _ := r.reader.Peek(count)
s = string(buf)
r.reader.Discard(count)
} else {
buf := make([]byte, count)
_, err := io.ReadFull(r.reader, buf)
if err != nil {
r.fatal(err)
return ""
}
s = string(buf)
}
s := string(r.msgBody[r.rp : r.rp+count])
r.rp += count
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp)
}
return s
}
// readBytes reads count bytes and returns as []byte
func (r *msgReader) readBytes(count int32) []byte {
func (r *msgReader) readBytes(countI32 int32) []byte {
if r.err != nil {
return nil
}
r.msgBytesRemaining -= count
if r.msgBytesRemaining < 0 {
count := int(countI32)
if len(r.msgBody)-r.rp < count {
r.fatal(errors.New("read past end of message"))
return nil
}
b := make([]byte, int(count))
b := r.msgBody[r.rp : r.rp+count]
r.rp += count
_, err := io.ReadFull(r.reader, b)
if err != nil {
r.fatal(err)
return nil
}
r.cr.KeepLast()
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp)
}
return b