From 452511dfc51d2f5948062f96c905849fbe1f4053 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 7 May 2020 13:28:28 -0500 Subject: [PATCH] Rename RecordFieldIter to CompositeBinaryScanner and adjust interface Use interface similar to bufio.Scanner and pgx.Rows. --- composite.go | 111 +++++++++++++++++++++++++--------------- composite_bench_test.go | 28 +++++----- convert.go | 18 +++---- record.go | 19 +++---- 4 files changed, 99 insertions(+), 77 deletions(-) diff --git a/composite.go b/composite.go index 6ffe9acf..4e6b68ca 100644 --- a/composite.go +++ b/composite.go @@ -84,31 +84,29 @@ func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return nil } - fieldIter, fieldCount, err := NewRecordFieldIterator(buf) + scanner, err := NewCompositeBinaryScanner(buf) if err != nil { return err - } else if len(dst.fields) != fieldCount { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), fieldCount) + } + if len(dst.fields) != scanner.FieldCount() { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount()) } - _, fieldBytes, eof, err := fieldIter.Next() - - for i := 0; !eof; i++ { - if err != nil { - return err - } - + for i := 0; scanner.Scan(); i++ { binaryDecoder, ok := dst.fields[i].(BinaryDecoder) if !ok { return errors.New("Composite field doesn't support binary protocol") } - if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { return err } - - _, fieldBytes, eof, err = fieldIter.Next() } + + if scanner.Err() != nil { + return scanner.Err() + } + dst.Status = Present return nil @@ -154,56 +152,85 @@ func (dst *Composite) SetFields(values ...interface{}) error { return nil } -type RecordFieldIter struct { +type CompositeBinaryScanner struct { rp int src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error } -// NewRecordFieldIterator creates iterator over binary representation -// of record, aka ROW(), aka Composite -func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) { rp := 0 if len(src[rp:]) < 4 { - return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) + return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src) } - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 - return RecordFieldIter{ - rp: rp, - src: src, - }, fieldCount, nil + return CompositeBinaryScanner{ + rp: rp, + src: src, + fieldCount: fieldCount, + }, nil } -// Next returns next field decoded from record. eof is returned if no -// more fields left to decode. -func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { - if fi.rp == len(fi.src) { - eof = true - return +// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Scan returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Scan() bool { + if cfs.err != nil { + return false } - if len(fi.src[fi.rp:]) < 8 { - err = errors.Errorf("Record incomplete %v", fi.src) - return + if cfs.rp == len(cfs.src) { + return false } - fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) - fi.rp += 4 - fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) - fi.rp += 4 + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = errors.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 if fieldLen >= 0 { - if len(fi.src[fi.rp:]) < fieldLen { - err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) - return + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = errors.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false } - buf = fi.src[fi.rp : fi.rp+fieldLen] - fi.rp += fieldLen + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil } - return + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err } // RecordStart adds record header to the buf diff --git a/composite_bench_test.go b/composite_bench_test.go index fd31e8ea..fa0f9f61 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -34,29 +34,29 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { a := pgtype.Int4{} b := pgtype.Text{} - fieldIter, fieldCount, err := pgtype.NewRecordFieldIterator(src) + scanner, err := pgtype.NewCompositeBinaryScanner(src) if err != nil { return err } - if 2 != fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", fieldCount) + if 2 != scanner.FieldCount() { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount()) } - _, fieldBytes, eof, err := fieldIter.Next() - if eof || err != nil { - return errors.New("Bad record") - } - if err = a.DecodeBinary(ci, fieldBytes); err != nil { - return err + if scanner.Scan() { + if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil { + return err + } } - _, fieldBytes, eof, err = fieldIter.Next() - if eof || err != nil { - return errors.New("Bad record") + if scanner.Scan() { + if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil { + return err + } } - if err = b.DecodeBinary(ci, fieldBytes); err != nil { - return err + + if scanner.Err() != nil { + return scanner.Err() } dst.A = a.Int diff --git a/convert.go b/convert.go index 91a32a60..4fe659b3 100644 --- a/convert.go +++ b/convert.go @@ -442,26 +442,24 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // // ScanRowValue takes ownership of src, caller MUST not use it after call func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { - fieldIter, fieldCount, err := NewRecordFieldIterator(src) + scanner, err := NewCompositeBinaryScanner(src) if err != nil { return err } - if len(dst) != fieldCount { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst)) + if len(dst) != scanner.FieldCount() { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst)) } - fieldOID, fieldBytes, eof, err := fieldIter.Next() - for i := 0; !eof; i++ { + for i := 0; scanner.Scan(); i++ { + err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i]) if err != nil { return err } + } - if err = ci.Scan(fieldOID, BinaryFormatCode, fieldBytes, dst[i]); err != nil { - return err - } - - fieldOID, fieldBytes, eof, err = fieldIter.Next() + if scanner.Err() != nil { + return scanner.Err() } return nil diff --git a/record.go b/record.go index b0c47185..0d51ad4c 100644 --- a/record.go +++ b/record.go @@ -102,29 +102,26 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - fieldIter, fieldCount, err := NewRecordFieldIterator(src) + scanner, err := NewCompositeBinaryScanner(src) if err != nil { return err } - fields := make([]Value, fieldCount) - fieldOID, fieldBytes, eof, err := fieldIter.Next() + fields := make([]Value, scanner.FieldCount()) - for i := 0; !eof; i++ { + for i := 0; scanner.Scan(); i++ { + binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) if err != nil { return err } - binaryDecoder, err := prepareNewBinaryDecoder(ci, fieldOID, &fields[i]) - if err != nil { + if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { return err } + } - if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { - return err - } - - fieldOID, fieldBytes, eof, err = fieldIter.Next() + if scanner.Err() != nil { + return scanner.Err() } *dst = Record{Fields: fields, Status: Present}