From fcb385dccbdd133189d6349c0e402f40d18c248e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:04:14 -0500 Subject: [PATCH] Add ScanDecoder and ScanValue to composite scanners. Rename Scan to Next to disambiguate. --- composite_bench_test.go | 24 +------- composite_fields.go | 35 ++--------- composite_type.go | 132 ++++++++++++++++++++++++++++------------ record.go | 7 +-- 4 files changed, 104 insertions(+), 94 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index 4858ccad..cff9d518 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -5,7 +5,6 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgtype" - errors "golang.org/x/xerrors" ) type MyCompositeRaw struct { @@ -35,26 +34,9 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { a := pgtype.Int4{} b := pgtype.Text{} - scanner, err := pgtype.NewCompositeBinaryScanner(src) - if err != nil { - return err - } - - if 2 != scanner.FieldCount() { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount()) - } - - if scanner.Scan() { - if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } - } - - if scanner.Scan() { - if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } - } + scanner := pgtype.NewCompositeBinaryScanner(ci, src) + scanner.ScanDecoder(&a) + scanner.ScanDecoder(&b) if scanner.Err() != nil { return scanner.Err() diff --git a/composite_fields.go b/composite_fields.go index b97506eb..b2d9f844 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -20,19 +20,10 @@ func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { return errors.Errorf("cannot decode unexpected null into CompositeFields") } - scanner, err := NewCompositeBinaryScanner(src) - if err != nil { - return err - } - if len(cf) != scanner.FieldCount() { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount()) - } + scanner := NewCompositeBinaryScanner(ci, src) - for i := 0; scanner.Scan(); i++ { - err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i]) - if err != nil { - return err - } + for _, f := range cf { + scanner.ScanValue(f) } if scanner.Err() != nil { @@ -51,30 +42,16 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { return errors.Errorf("cannot decode unexpected null into CompositeFields") } - scanner, err := NewCompositeTextScanner(src) - if err != nil { - return err - } + scanner := NewCompositeTextScanner(ci, src) - fieldCount := 0 - - for i := 0; scanner.Scan(); i++ { - err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i]) - if err != nil { - return err - } - - fieldCount += 1 + for _, f := range cf { + scanner.ScanValue(f) } if scanner.Err() != nil { return scanner.Err() } - if len(cf) != fieldCount { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount) - } - return nil } diff --git a/composite_type.go b/composite_type.go index 99f0189f..f01e8e64 100644 --- a/composite_type.go +++ b/composite_type.go @@ -12,7 +12,7 @@ type CompositeType struct { status Status typeName string - fields []Value + fields []ValueTranscoder } // NewCompositeType creates a Composite object, which acts as a "schema" for @@ -22,7 +22,7 @@ type CompositeType struct { // SetFields method // To read composite fields back pass result of Scan() method // to query Scan function. -func NewCompositeType(typeName string, fields ...Value) *CompositeType { +func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType { return &CompositeType{typeName: typeName, fields: fields} } @@ -44,11 +44,11 @@ func (src CompositeType) Get() interface{} { func (ct *CompositeType) NewTypeValue() Value { a := &CompositeType{ typeName: ct.typeName, - fields: make([]Value, len(ct.fields)), + fields: make([]ValueTranscoder, len(ct.fields)), } for i := range ct.fields { - a.fields[i] = NewValue(ct.fields[i]) + a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder) } return a @@ -138,36 +138,34 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, case Undefined: return nil, errUndefined } - return EncodeRow(ci, buf, src.fields...) + + b := NewCompositeBinaryBuilder(ci, buf) + for _, f := range src.fields { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("unknown oid") + } + + b.AppendEncoder(dt.OID, f) + } + + return b.Finish() } // DecodeBinary implements BinaryDecoder interface. // Opposite to Record, fields in a composite act as a "schema" // and decoding fails if SQL value can't be assigned due to // type mismatch -func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { +func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { if buf == nil { dst.status = Null return nil } - scanner, err := NewCompositeBinaryScanner(buf) - if err != nil { - return err - } - if len(dst.fields) != scanner.FieldCount() { - return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount()) - } + scanner := NewCompositeBinaryScanner(ci, buf) - for i := 0; scanner.Scan(); i++ { - binaryDecoder, ok := dst.fields[i].(BinaryDecoder) - if !ok { - return errors.New("Composite field doesn't support binary protocol") - } - - if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } + for _, f := range dst.fields { + scanner.ScanDecoder(f) } if scanner.Err() != nil { @@ -180,6 +178,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { } type CompositeBinaryScanner struct { + ci *ConnInfo rp int src []byte @@ -190,25 +189,52 @@ type CompositeBinaryScanner struct { } // NewCompositeBinaryScanner a scanner over a binary encoded composite balue. -func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) { +func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { rp := 0 if len(src[rp:]) < 4 { - return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src) + return &CompositeBinaryScanner{err: errors.Errorf("Record incomplete %v", src)} } fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 - return CompositeBinaryScanner{ + return &CompositeBinaryScanner{ + ci: ci, rp: rp, src: src, fieldCount: fieldCount, - }, nil + } } -// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Scan returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeBinaryScanner) Scan() bool { +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { if cfs.err != nil { return false } @@ -261,6 +287,7 @@ func (cfs *CompositeBinaryScanner) Err() error { } type CompositeTextScanner struct { + ci *ConnInfo rp int src []byte @@ -268,29 +295,56 @@ type CompositeTextScanner struct { err error } -// NewCompositeTextScanner a scanner over a text encoded composite balue. -func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) { +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { if len(src) < 2 { - return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src) + return &CompositeTextScanner{err: errors.Errorf("Record incomplete %v", src)} } if src[0] != '(' { - return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('") + return &CompositeTextScanner{err: errors.Errorf("composite text format must start with '('")} } if src[len(src)-1] != ')' { - return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'") + return &CompositeTextScanner{err: errors.Errorf("composite text format must end with ')'")} } - return CompositeTextScanner{ + return &CompositeTextScanner{ + ci: ci, rp: 1, src: src, - }, nil + } } -// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Scan returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeTextScanner) Scan() bool { +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeTextScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { if cfs.err != nil { return false } diff --git a/record.go b/record.go index 0d51ad4c..7899a881 100644 --- a/record.go +++ b/record.go @@ -102,14 +102,11 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - scanner, err := NewCompositeBinaryScanner(src) - if err != nil { - return err - } + scanner := NewCompositeBinaryScanner(ci, src) fields := make([]Value, scanner.FieldCount()) - for i := 0; scanner.Scan(); i++ { + for i := 0; scanner.Next(); i++ { binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) if err != nil { return err