diff --git a/msg_reader.go b/msg_reader.go deleted file mode 100644 index 1858037a..00000000 --- a/msg_reader.go +++ /dev/null @@ -1,249 +0,0 @@ -package pgx - -import ( - "bytes" - "encoding/binary" - "errors" - "net" - - "github.com/jackc/pgx/chunkreader" -) - -// msgReader is a helper that reads values from a PostgreSQL message. -type msgReader struct { - cr *chunkreader.ChunkReader - msgType byte - msgBody []byte - rp int // read position - err error - log func(lvl int, msg string, ctx ...interface{}) - shouldLog func(lvl int) bool -} - -// fatal tells rc that a Fatal error has occurred -func (r *msgReader) fatal(err error) { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp) - } - r.err = err -} - -// rxMsg reads the type and size of the next message. -func (r *msgReader) rxMsg() (byte, error) { - if r.err != nil { - return 0, r.err - } - - header, err := r.cr.Next(5) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - - r.msgType = header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - - r.msgBody, err = r.cr.Next(bodyLen) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - - r.rp = 0 - - return r.msgType, nil -} - -func (r *msgReader) readByte() byte { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 1 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b := r.msgBody[r.rp] - r.rp++ - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp) - } - - return b -} - -func (r *msgReader) readInt16() int16 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 2 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:])) - r.rp += 2 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readInt32() int32 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 4 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:])) - r.rp += 4 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readUint16() uint16 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 2 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := binary.BigEndian.Uint16(r.msgBody[r.rp:]) - r.rp += 2 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readUint32() uint32 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 4 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := binary.BigEndian.Uint32(r.msgBody[r.rp:]) - r.rp += 4 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readInt64() int64 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 8 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:])) - r.rp += 8 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -// readCString reads a null terminated string -func (r *msgReader) readCString() string { - if r.err != nil { - return "" - } - - nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0) - if nullIdx == -1 { - r.fatal(errors.New("null terminated string not found")) - return "" - } - - s := string(r.msgBody[r.rp : r.rp+nullIdx]) - r.rp += nullIdx + 1 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp) - } - - return s -} - -// readString reads count bytes and returns as string -func (r *msgReader) readString(countI32 int32) string { - if r.err != nil { - return "" - } - - count := int(countI32) - - if len(r.msgBody)-r.rp < count { - r.fatal(errors.New("read past end of message")) - return "" - } - - s := string(r.msgBody[r.rp : r.rp+count]) - r.rp += count - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp) - } - - return s -} - -// readBytes reads count bytes and returns as []byte -func (r *msgReader) readBytes(countI32 int32) []byte { - if r.err != nil { - return nil - } - - count := int(countI32) - - if len(r.msgBody)-r.rp < count { - r.fatal(errors.New("read past end of message")) - return nil - } - - b := r.msgBody[r.rp : r.rp+count] - r.rp += count - - r.cr.KeepLast() - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp) - } - - return b -}