From 2e13f2fe7691a7c99f55a85fdb2e8934da7a9582 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Thu, 16 Apr 2020 20:59:07 +0100 Subject: [PATCH] Move lowlevel binary routines into own package --- binary/record.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++ convert.go | 22 ++++++-------- record.go | 61 ++++--------------------------------- 3 files changed, 93 insertions(+), 68 deletions(-) create mode 100644 binary/record.go diff --git a/binary/record.go b/binary/record.go new file mode 100644 index 00000000..72b688a8 --- /dev/null +++ b/binary/record.go @@ -0,0 +1,78 @@ +package binary + +import ( + "encoding/binary" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +type RecordFieldIter struct { + rp int + src []byte +} + +// NewRecordFieldIterator creates iterator over binary representation +// of record, aka ROW(), aka Composite +func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { + rp := 0 + if len(src[rp:]) < 4 { + return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) + } + + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + return RecordFieldIter{ + rp: rp, + src: src, + }, fieldCount, nil +} + +// Next returns next field decoded from record. eof is returned if no +// more fields left to decode. +func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { + if fi.rp == len(fi.src) { + eof = true + return + } + + if len(fi.src[fi.rp:]) < 8 { + err = errors.Errorf("Record incomplete %v", fi.src) + return + } + fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) + fi.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) + fi.rp += 4 + + if fieldLen >= 0 { + if len(fi.src[fi.rp:]) < fieldLen { + err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) + return + } + buf = fi.src[fi.rp : fi.rp+fieldLen] + fi.rp += fieldLen + } + + return +} + +// RecordStart adds record header to the buf +func RecordStart(buf []byte, fieldCount int) []byte { + return pgio.AppendUint32(buf, uint32(fieldCount)) +} + +// RecordAdd adds record field to the buf +func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { + buf = pgio.AppendUint32(buf, oid) + buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) + buf = append(buf, fieldBytes...) + return buf +} + +// RecordAddNull adds null value as a field to the buf +func RecordAddNull(buf []byte, oid uint32) []byte { + return pgio.AppendInt32(buf, int32(-1)) +} diff --git a/convert.go b/convert.go index 134e123d..6d5ea0c9 100644 --- a/convert.go +++ b/convert.go @@ -5,7 +5,7 @@ import ( "reflect" "time" - "github.com/jackc/pgio" + "github.com/jackc/pgtype/binary" errors "golang.org/x/xerrors" ) @@ -442,16 +442,16 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // Values must implement BinaryDecoder interface otherwise error is returned. // ScanRowValue takes ownership of src, caller MUST not use it after call func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { - fieldIter, err := newFieldIterator(src) + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err } - if len(dst) != fieldIter.fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldIter.fieldCount, len(dst)) + if len(dst) != fieldCount { + return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", fieldCount, len(dst)) } - _, fieldBytes, eof, err := fieldIter.next() + _, fieldBytes, eof, err := fieldIter.Next() for i := 0; !eof; i++ { if err != nil { return err @@ -466,7 +466,7 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { return err } - _, fieldBytes, eof, err = fieldIter.next() + _, fieldBytes, eof, err = fieldIter.Next() } return nil @@ -476,14 +476,12 @@ func ScanRowValue(ci *ConnInfo, src []byte, dst ...Value) error { func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { fieldBytes := make([]byte, 0, 128) - newBuf = pgio.AppendUint32(buf, uint32(len(fields))) + newBuf = binary.RecordStart(buf, len(fields)) for _, f := range fields { dt, ok := ci.DataTypeForValue(f) if !ok { return nil, errors.Errorf("Unknown OID for %s", f) } - newBuf = pgio.AppendUint32(newBuf, dt.OID) - if f.Get() != nil { binaryEncoder, ok := f.(BinaryEncoder) if !ok { @@ -493,11 +491,9 @@ func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err er if err != nil { return nil, err } - - newBuf = pgio.AppendUint32(newBuf, uint32(len(fieldBytes))) - newBuf = append(newBuf, fieldBytes...) + newBuf = binary.RecordAdd(newBuf, dt.OID, fieldBytes) } else { - newBuf = pgio.AppendInt32(newBuf, int32(-1)) + newBuf = binary.RecordAddNull(newBuf, dt.OID) } } diff --git a/record.go b/record.go index 08603140..4e39f92a 100644 --- a/record.go +++ b/record.go @@ -1,9 +1,10 @@ package pgtype import ( - "encoding/binary" "reflect" + "github.com/jackc/pgtype/binary" + errors "golang.org/x/xerrors" ) @@ -78,56 +79,6 @@ func (src *Record) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } -type fieldIter struct { - rp int - fieldCount int - src []byte -} - -func newFieldIterator(src []byte) (fieldIter, error) { - rp := 0 - if len(src[rp:]) < 4 { - return fieldIter{}, errors.Errorf("Record incomplete %v", src) - } - - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - return fieldIter{ - rp: rp, - fieldCount: fieldCount, - src: src, - }, nil -} - -func (fi *fieldIter) next() (fieldOID uint32, buf []byte, eof bool, err error) { - if fi.rp == len(fi.src) { - eof = true - return - } - - if len(fi.src[fi.rp:]) < 8 { - err = errors.Errorf("Record incomplete %v", fi.src) - return - } - fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) - fi.rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) - fi.rp += 4 - - if fieldLen >= 0 { - if len(fi.src[fi.rp:]) < fieldLen { - err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) - return - } - buf = fi.src[fi.rp : fi.rp+fieldLen] - fi.rp += fieldLen - } - - return -} - func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { var binaryDecoder BinaryDecoder @@ -153,13 +104,13 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - fieldIter, err := newFieldIterator(src) + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) if err != nil { return err } - fields := make([]Value, fieldIter.fieldCount) - fieldOID, fieldBytes, eof, err := fieldIter.next() + fields := make([]Value, fieldCount) + fieldOID, fieldBytes, eof, err := fieldIter.Next() for i := 0; !eof; i++ { if err != nil { @@ -175,7 +126,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return err } - fieldOID, fieldBytes, eof, err = fieldIter.next() + fieldOID, fieldBytes, eof, err = fieldIter.Next() } *dst = Record{Fields: fields, Status: Present}