2
0

Add ScanDecoder and ScanValue to composite scanners.

Rename Scan to Next to disambiguate.
This commit is contained in:
Jack Christensen
2020-05-12 15:04:14 -05:00
parent e51cb1ef09
commit fcb385dccb
4 changed files with 104 additions and 94 deletions
+3 -21
View File
@@ -5,7 +5,6 @@ import (
"github.com/jackc/pgio"
"github.com/jackc/pgtype"
errors "golang.org/x/xerrors"
)
type MyCompositeRaw struct {
@@ -35,26 +34,9 @@ func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
a := pgtype.Int4{}
b := pgtype.Text{}
scanner, err := pgtype.NewCompositeBinaryScanner(src)
if err != nil {
return err
}
if 2 != scanner.FieldCount() {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", scanner.FieldCount())
}
if scanner.Scan() {
if err = a.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
}
if scanner.Scan() {
if err = b.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
}
scanner := pgtype.NewCompositeBinaryScanner(ci, src)
scanner.ScanDecoder(&a)
scanner.ScanDecoder(&b)
if scanner.Err() != nil {
return scanner.Err()
+6 -29
View File
@@ -20,19 +20,10 @@ func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
return errors.Errorf("cannot decode unexpected null into CompositeFields")
}
scanner, err := NewCompositeBinaryScanner(src)
if err != nil {
return err
}
if len(cf) != scanner.FieldCount() {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount())
}
scanner := NewCompositeBinaryScanner(ci, src)
for i := 0; scanner.Scan(); i++ {
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i])
if err != nil {
return err
}
for _, f := range cf {
scanner.ScanValue(f)
}
if scanner.Err() != nil {
@@ -51,30 +42,16 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
return errors.Errorf("cannot decode unexpected null into CompositeFields")
}
scanner, err := NewCompositeTextScanner(src)
if err != nil {
return err
}
scanner := NewCompositeTextScanner(ci, src)
fieldCount := 0
for i := 0; scanner.Scan(); i++ {
err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i])
if err != nil {
return err
}
fieldCount += 1
for _, f := range cf {
scanner.ScanValue(f)
}
if scanner.Err() != nil {
return scanner.Err()
}
if len(cf) != fieldCount {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount)
}
return nil
}
+93 -39
View File
@@ -12,7 +12,7 @@ type CompositeType struct {
status Status
typeName string
fields []Value
fields []ValueTranscoder
}
// NewCompositeType creates a Composite object, which acts as a "schema" for
@@ -22,7 +22,7 @@ type CompositeType struct {
// SetFields method
// To read composite fields back pass result of Scan() method
// to query Scan function.
func NewCompositeType(typeName string, fields ...Value) *CompositeType {
func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType {
return &CompositeType{typeName: typeName, fields: fields}
}
@@ -44,11 +44,11 @@ func (src CompositeType) Get() interface{} {
func (ct *CompositeType) NewTypeValue() Value {
a := &CompositeType{
typeName: ct.typeName,
fields: make([]Value, len(ct.fields)),
fields: make([]ValueTranscoder, len(ct.fields)),
}
for i := range ct.fields {
a.fields[i] = NewValue(ct.fields[i])
a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder)
}
return a
@@ -138,36 +138,34 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
case Undefined:
return nil, errUndefined
}
return EncodeRow(ci, buf, src.fields...)
b := NewCompositeBinaryBuilder(ci, buf)
for _, f := range src.fields {
dt, ok := ci.DataTypeForValue(f)
if !ok {
return nil, errors.Errorf("unknown oid")
}
b.AppendEncoder(dt.OID, f)
}
return b.Finish()
}
// DecodeBinary implements BinaryDecoder interface.
// Opposite to Record, fields in a composite act as a "schema"
// and decoding fails if SQL value can't be assigned due to
// type mismatch
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) {
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
if buf == nil {
dst.status = Null
return nil
}
scanner, err := NewCompositeBinaryScanner(buf)
if err != nil {
return err
}
if len(dst.fields) != scanner.FieldCount() {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), scanner.FieldCount())
}
scanner := NewCompositeBinaryScanner(ci, buf)
for i := 0; scanner.Scan(); i++ {
binaryDecoder, ok := dst.fields[i].(BinaryDecoder)
if !ok {
return errors.New("Composite field doesn't support binary protocol")
}
if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil {
return err
}
for _, f := range dst.fields {
scanner.ScanDecoder(f)
}
if scanner.Err() != nil {
@@ -180,6 +178,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) {
}
type CompositeBinaryScanner struct {
ci *ConnInfo
rp int
src []byte
@@ -190,25 +189,52 @@ type CompositeBinaryScanner struct {
}
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
func NewCompositeBinaryScanner(src []byte) (CompositeBinaryScanner, error) {
func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
rp := 0
if len(src[rp:]) < 4 {
return CompositeBinaryScanner{}, errors.Errorf("Record incomplete %v", src)
return &CompositeBinaryScanner{err: errors.Errorf("Record incomplete %v", src)}
}
fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
rp += 4
return CompositeBinaryScanner{
return &CompositeBinaryScanner{
ci: ci,
rp: rp,
src: src,
fieldCount: fieldCount,
}, nil
}
}
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Scan returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeBinaryScanner) Scan() bool {
// ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// ScanDecoder calls Next and scans the result into d.
func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeBinaryScanner) Next() bool {
if cfs.err != nil {
return false
}
@@ -261,6 +287,7 @@ func (cfs *CompositeBinaryScanner) Err() error {
}
type CompositeTextScanner struct {
ci *ConnInfo
rp int
src []byte
@@ -268,29 +295,56 @@ type CompositeTextScanner struct {
err error
}
// NewCompositeTextScanner a scanner over a text encoded composite balue.
func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) {
// NewCompositeTextScanner a scanner over a text encoded composite value.
func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
if len(src) < 2 {
return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src)
return &CompositeTextScanner{err: errors.Errorf("Record incomplete %v", src)}
}
if src[0] != '(' {
return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('")
return &CompositeTextScanner{err: errors.Errorf("composite text format must start with '('")}
}
if src[len(src)-1] != ')' {
return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'")
return &CompositeTextScanner{err: errors.Errorf("composite text format must end with ')'")}
}
return CompositeTextScanner{
return &CompositeTextScanner{
ci: ci,
rp: 1,
src: src,
}, nil
}
}
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Scan returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeTextScanner) Scan() bool {
// ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// ScanDecoder calls Next and scans the result into d.
func (cfs *CompositeTextScanner) ScanValue(d interface{}) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d)
} else {
cfs.err = errors.New("read past end of composite")
}
}
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeTextScanner) Next() bool {
if cfs.err != nil {
return false
}
+2 -5
View File
@@ -102,14 +102,11 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error {
return nil
}
scanner, err := NewCompositeBinaryScanner(src)
if err != nil {
return err
}
scanner := NewCompositeBinaryScanner(ci, src)
fields := make([]Value, scanner.FieldCount())
for i := 0; scanner.Scan(); i++ {
for i := 0; scanner.Next(); i++ {
binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i])
if err != nil {
return err