2
0

Rename RecordFieldIter to CompositeBinaryScanner and adjust interface

Use interface similar to bufio.Scanner and pgx.Rows.
This commit is contained in:
Jack Christensen
2020-05-07 13:28:28 -05:00
parent ff9bc5d68d
commit 452511dfc5
4 changed files with 99 additions and 77 deletions
+69 -42
View File
@@ -84,31 +84,29 @@ func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) {
return nil return nil
} }
fieldIter, fieldCount, err := NewRecordFieldIterator(buf) scanner, err := NewCompositeBinaryScanner(buf)
if err != nil { if err != nil {
return err return err
} else if len(dst.fields) != fieldCount { }
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), fieldCount) if len(dst.fields) != scanner.FieldCount() {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount())
} }
_, fieldBytes, eof, err := fieldIter.Next() for i := 0; scanner.Scan(); i++ {
for i := 0; !eof; i++ {
if err != nil {
return err
}
binaryDecoder, ok := dst.fields[i].(BinaryDecoder) binaryDecoder, ok := dst.fields[i].(BinaryDecoder)
if !ok { if !ok {
return errors.New("Composite field doesn't support binary protocol") return errors.New("Composite field doesn't support binary protocol")
} }
if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err return err
} }
_, fieldBytes, eof, err = fieldIter.Next()
} }
if scanner.Err() != nil {
return scanner.Err()
}
dst.Status = Present dst.Status = Present
return nil return nil
@@ -154,56 +152,85 @@ func (dst *Composite) SetFields(values ...interface{}) error {
return nil return nil
} }
type RecordFieldIter struct { type CompositeBinaryScanner struct {
rp int rp int
src []byte src []byte
fieldCount int32
fieldBytes []byte
fieldOID uint32
err error
} }
// NewRecordFieldIterator creates iterator over binary representation // NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
// of record, aka ROW(), aka Composite func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) {
func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) {
rp := 0 rp := 0
if len(src[rp:]) < 4 { if len(src[rp:]) < 4 {
return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src)
} }
fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
rp += 4 rp += 4
return RecordFieldIter{ return CompositeBinaryScanner{
rp: rp, rp: rp,
src: src, src: src,
}, fieldCount, nil fieldCount: fieldCount,
}, nil
} }
// Next returns next field decoded from record. eof is returned if no // Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// more fields left to decode. // Scan returns false, the Err method can be called to check if any errors occurred.
func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { func (cfs *CompositeBinaryScanner) Scan() bool {
if fi.rp == len(fi.src) { if cfs.err != nil {
eof = true return false
return
} }
if len(fi.src[fi.rp:]) < 8 { if cfs.rp == len(cfs.src) {
err = errors.Errorf("Record incomplete %v", fi.src) return false
return
} }
fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:])
fi.rp += 4
fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) if len(cfs.src[cfs.rp:]) < 8 {
fi.rp += 4 cfs.err = errors.Errorf("Record incomplete %v", cfs.src)
return false
}
cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
cfs.rp += 4
fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
cfs.rp += 4
if fieldLen >= 0 { if fieldLen >= 0 {
if len(fi.src[fi.rp:]) < fieldLen { if len(cfs.src[cfs.rp:]) < fieldLen {
err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) cfs.err = errors.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
return return false
} }
buf = fi.src[fi.rp : fi.rp+fieldLen] cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
fi.rp += fieldLen cfs.rp += fieldLen
} else {
cfs.fieldBytes = nil
} }
return return true
}
func (cfs *CompositeBinaryScanner) FieldCount() int {
return int(cfs.fieldCount)
}
// Bytes returns the bytes of the field most recently read by Scan().
func (cfs *CompositeBinaryScanner) Bytes() []byte {
return cfs.fieldBytes
}
// OID returns the OID of the field most recently read by Scan().
func (cfs *CompositeBinaryScanner) OID() uint32 {
return cfs.fieldOID
}
// Err returns any error encountered by the scanner.
func (cfs *CompositeBinaryScanner) Err() error {
return cfs.err
} }
// RecordStart adds record header to the buf // RecordStart adds record header to the buf
+14 -14
View File
@@ -34,29 +34,29 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
a := pgtype.Int4{} a := pgtype.Int4{}
b := pgtype.Text{} b := pgtype.Text{}
fieldIter, fieldCount, err := pgtype.NewRecordFieldIterator(src) scanner, err := pgtype.NewCompositeBinaryScanner(src)
if err != nil { if err != nil {
return err return err
} }
if 2 != fieldCount { if 2 != scanner.FieldCount() {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", fieldCount) return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount())
} }
_, fieldBytes, eof, err := fieldIter.Next() if scanner.Scan() {
if eof || err != nil { if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil {
return errors.New("Bad record") return err
} }
if err = a.DecodeBinary(ci, fieldBytes); err != nil {
return err
} }
_, fieldBytes, eof, err = fieldIter.Next() if scanner.Scan() {
if eof || err != nil { if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil {
return errors.New("Bad record") return err
}
} }
if err = b.DecodeBinary(ci, fieldBytes); err != nil {
return err if scanner.Err() != nil {
return scanner.Err()
} }
dst.A = a.Int dst.A = a.Int
+8 -10
View File
@@ -442,26 +442,24 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
// //
// ScanRowValue takes ownership of src, caller MUST not use it after call // ScanRowValue takes ownership of src, caller MUST not use it after call
func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error {
fieldIter, fieldCount, err := NewRecordFieldIterator(src) scanner, err := NewCompositeBinaryScanner(src)
if err != nil { if err != nil {
return err return err
} }
if len(dst) != fieldCount { if len(dst) != scanner.FieldCount() {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst)) return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst))
} }
fieldOID, fieldBytes, eof, err := fieldIter.Next() for i := 0; scanner.Scan(); i++ {
for i := 0; !eof; i++ { err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i])
if err != nil { if err != nil {
return err return err
} }
}
if err = ci.Scan(fieldOID, BinaryFormatCode, fieldBytes, dst[i]); err != nil { if scanner.Err() != nil {
return err return scanner.Err()
}
fieldOID, fieldBytes, eof, err = fieldIter.Next()
} }
return nil return nil
+8 -11
View File
@@ -102,29 +102,26 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error {
return nil return nil
} }
fieldIter, fieldCount, err := NewRecordFieldIterator(src) scanner, err := NewCompositeBinaryScanner(src)
if err != nil { if err != nil {
return err return err
} }
fields := make([]Value, fieldCount) fields := make([]Value, scanner.FieldCount())
fieldOID, fieldBytes, eof, err := fieldIter.Next()
for i := 0; !eof; i++ { for i := 0; scanner.Scan(); i++ {
binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i])
if err != nil { if err != nil {
return err return err
} }
binaryDecoder, err := prepareNewBinaryDecoder(ci, fieldOID, &fields[i]) if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil {
if err != nil {
return err return err
} }
}
if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { if scanner.Err() != nil {
return err return scanner.Err()
}
fieldOID, fieldBytes, eof, err = fieldIter.Next()
} }
*dst = Record{Fields: fields, Status: Present} *dst = Record{Fields: fields, Status: Present}