2
0

Add ScanDecoder and ScanValue to composite scanners.

Rename Scan to Next to disambiguate.
This commit is contained in:
Jack Christensen
2020-05-12 15:04:14 -05:00
parent e51cb1ef09
commit fcb385dccb
4 changed files with 104 additions and 94 deletions
+3 -21
View File
@@ -5,7 +5,6 @@ import (
"github.com/jackc/pgio" "github.com/jackc/pgio"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
errors "golang.org/x/xerrors"
) )
type MyCompositeRaw struct { type MyCompositeRaw struct {
@@ -35,26 +34,9 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
a := pgtype.Int4{} a := pgtype.Int4{}
b := pgtype.Text{} b := pgtype.Text{}
scanner, err := pgtype.NewCompositeBinaryScanner(src) scanner := pgtype.NewCompositeBinaryScanner(ci, src)
if err != nil { scanner.ScanDecoder(&a)
return err scanner.ScanDecoder(&b)
}
if 2 != scanner.FieldCount() {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount())
}
if scanner.Scan() {
if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
}
if scanner.Scan() {
if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
}
if scanner.Err() != nil { if scanner.Err() != nil {
return scanner.Err() return scanner.Err()
+6 -29
View File
@@ -20,19 +20,10 @@ func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
return errors.Errorf("cannot decode unexpected null into CompositeFields") return errors.Errorf("cannot decode unexpected null into CompositeFields")
} }
scanner, err := NewCompositeBinaryScanner(src) scanner := NewCompositeBinaryScanner(ci, src)
if err != nil {
return err
}
if len(cf) != scanner.FieldCount() {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount())
}
for i := 0; scanner.Scan(); i++ { for _, f := range cf {
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i]) scanner.ScanValue(f)
if err != nil {
return err
}
} }
if scanner.Err() != nil { if scanner.Err() != nil {
@@ -51,30 +42,16 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
return errors.Errorf("cannot decode unexpected null into CompositeFields") return errors.Errorf("cannot decode unexpected null into CompositeFields")
} }
scanner, err := NewCompositeTextScanner(src) scanner := NewCompositeTextScanner(ci, src)
if err != nil {
return err
}
fieldCount := 0 for _, f := range cf {
scanner.ScanValue(f)
for i := 0; scanner.Scan(); i++ {
err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i])
if err != nil {
return err
}
fieldCount += 1
} }
if scanner.Err() != nil { if scanner.Err() != nil {
return scanner.Err() return scanner.Err()
} }
if len(cf) != fieldCount {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount)
}
return nil return nil
} }
+93 -39
View File
@@ -12,7 +12,7 @@ type CompositeType struct {
status Status status Status
typeName string typeName string
fields []Value fields []ValueTranscoder
} }
// NewCompositeType creates a Composite object, which acts as a "schema" for // NewCompositeType creates a Composite object, which acts as a "schema" for
@@ -22,7 +22,7 @@ type CompositeType struct {
// SetFields method // SetFields method
// To read composite fields back pass result of Scan() method // To read composite fields back pass result of Scan() method
// to query Scan function. // to query Scan function.
func NewCompositeType(typeName string, fields ...Value) *CompositeType { func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType {
return &CompositeType{typeName: typeName, fields: fields} return &CompositeType{typeName: typeName, fields: fields}
} }
@@ -44,11 +44,11 @@ func (src CompositeType) Get() interface{} {
func (ct *CompositeType) NewTypeValue() Value { func (ct *CompositeType) NewTypeValue() Value {
a := &CompositeType{ a := &CompositeType{
typeName: ct.typeName, typeName: ct.typeName,
fields: make([]Value, len(ct.fields)), fields: make([]ValueTranscoder, len(ct.fields)),
} }
for i := range ct.fields { for i := range ct.fields {
a.fields[i] = NewValue(ct.fields[i]) a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder)
} }
return a return a
@@ -138,36 +138,34 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
case Undefined: case Undefined:
return nil, errUndefined return nil, errUndefined
} }
return EncodeRow(ci, buf, src.fields...)
b := NewCompositeBinaryBuilder(ci, buf)
for _, f := range src.fields {
dt, ok := ci.DataTypeForValue(f)
if !ok {
return nil, errors.Errorf("unknown oid")
}
b.AppendEncoder(dt.OID, f)
}
return b.Finish()
} }
// DecodeBinary implements BinaryDecoder interface. // DecodeBinary implements BinaryDecoder interface.
// Opposite to Record, fields in a composite act as a "schema" // Opposite to Record, fields in a composite act as a "schema"
// and decoding fails if SQL value can't be assigned due to // and decoding fails if SQL value can't be assigned due to
// type mismatch // type mismatch
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
if buf == nil { if buf == nil {
dst.status = Null dst.status = Null
return nil return nil
} }
scanner, err := NewCompositeBinaryScanner(buf) scanner := NewCompositeBinaryScanner(ci, buf)
if err != nil {
return err
}
if len(dst.fields) != scanner.FieldCount() {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount())
}
for i := 0; scanner.Scan(); i++ { for _, f := range dst.fields {
binaryDecoder, ok := dst.fields[i].(BinaryDecoder) scanner.ScanDecoder(f)
if !ok {
return errors.New("Composite field doesn't support binary protocol")
}
if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
} }
if scanner.Err() != nil { if scanner.Err() != nil {
@@ -180,6 +178,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) {
} }
type CompositeBinaryScanner struct { type CompositeBinaryScanner struct {
ci *ConnInfo
rp int rp int
src []byte src []byte
@@ -190,25 +189,52 @@ type CompositeBinaryScanner struct {
} }
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. // NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) { func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
rp := 0 rp := 0
if len(src[rp:]) < 4 { if len(src[rp:]) < 4 {
return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src) return &CompositeBinaryScanner{err: errors.Errorf("Record incomplete %v", src)}
} }
fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
rp += 4 rp += 4
return CompositeBinaryScanner{ return &CompositeBinaryScanner{
ci: ci,
rp: rp, rp: rp,
src: src, src: src,
fieldCount: fieldCount, fieldCount: fieldCount,
}, nil }
} }
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After // ScanDecoder calls Next and decodes the result with d.
// Scan returns false, the Err method can be called to check if any errors occurred. func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) {
func (cfs *CompositeBinaryScanner) Scan() bool { if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// ScanDecoder calls Next and scans the result into d.
func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeBinaryScanner) Next() bool {
if cfs.err != nil { if cfs.err != nil {
return false return false
} }
@@ -261,6 +287,7 @@ func (cfs *CompositeBinaryScanner) Err() error {
} }
type CompositeTextScanner struct { type CompositeTextScanner struct {
ci *ConnInfo
rp int rp int
src []byte src []byte
@@ -268,29 +295,56 @@ type CompositeTextScanner struct {
err error err error
} }
// NewCompositeTextScanner a scanner over a text encoded composite balue. // NewCompositeTextScanner a scanner over a text encoded composite value.
func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) { func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
if len(src) < 2 { if len(src) < 2 {
return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src) return &CompositeTextScanner{err: errors.Errorf("Record incomplete %v", src)}
} }
if src[0] != '(' { if src[0] != '(' {
return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('") return &CompositeTextScanner{err: errors.Errorf("composite text format must start with '('")}
} }
if src[len(src)-1] != ')' { if src[len(src)-1] != ')' {
return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'") return &CompositeTextScanner{err: errors.Errorf("composite text format must end with ')'")}
} }
return CompositeTextScanner{ return &CompositeTextScanner{
ci: ci,
rp: 1, rp: 1,
src: src, src: src,
}, nil }
} }
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After // ScanDecoder calls Next and decodes the result with d.
// Scan returns false, the Err method can be called to check if any errors occurred. func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) {
func (cfs *CompositeTextScanner) Scan() bool { if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// ScanDecoder calls Next and scans the result into d.
func (cfs *CompositeTextScanner) ScanValue(d interface{}) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeTextScanner) Next() bool {
if cfs.err != nil { if cfs.err != nil {
return false return false
} }
+2 -5
View File
@@ -102,14 +102,11 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error {
return nil return nil
} }
scanner, err := NewCompositeBinaryScanner(src) scanner := NewCompositeBinaryScanner(ci, src)
if err != nil {
return err
}
fields := make([]Value, scanner.FieldCount()) fields := make([]Value, scanner.FieldCount())
for i := 0; scanner.Scan(); i++ { for i := 0; scanner.Next(); i++ {
binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i])
if err != nil { if err != nil {
return err return err