2
0

ValueTranscoder uses new interfaces

This commit is contained in:
Jack Christensen
2021-12-04 12:45:20 -06:00
parent 8f454e4cd6
commit e22675d20b
4 changed files with 123 additions and 48 deletions
+37 -10
View File
@@ -129,12 +129,44 @@ func (src *ArrayType) AssignTo(dst interface{}) error {
} }
} }
func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { func (ArrayType) BinaryFormatSupported() bool {
return true
}
func (ArrayType) TextFormatSupported() bool {
return true
}
func (ArrayType) PreferredFormat() int16 {
return TextFormatCode
}
func (dst *ArrayType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
if src == nil { if src == nil {
dst.setNil() dst.setNil()
return nil return nil
} }
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return dst.DecodeText(ci, src)
}
return fmt.Errorf("unknown format code %d", format)
}
func (src ArrayType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
switch format {
case BinaryFormatCode:
return src.EncodeBinary(ci, buf)
case TextFormatCode:
return src.EncodeText(ci, buf)
}
return nil, fmt.Errorf("unknown format code %d", format)
}
func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
uta, err := ParseUntypedTextArray(string(src)) uta, err := ParseUntypedTextArray(string(src))
if err != nil { if err != nil {
return err return err
@@ -151,7 +183,7 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
if s != "NULL" { if s != "NULL" {
elemSrc = []byte(s) elemSrc = []byte(s)
} }
err = elem.DecodeText(ci, elemSrc) err = elem.DecodeResult(ci, dst.elementOID, TextFormatCode, elemSrc)
if err != nil { if err != nil {
return err return err
} }
@@ -168,11 +200,6 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
} }
func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
dst.setNil()
return nil
}
var arrayHeader ArrayHeader var arrayHeader ArrayHeader
rp, err := arrayHeader.DecodeBinary(ci, src) rp, err := arrayHeader.DecodeBinary(ci, src)
if err != nil { if err != nil {
@@ -204,7 +231,7 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
elemSrc = src[rp : rp+elemLen] elemSrc = src[rp : rp+elemLen]
rp += elemLen rp += elemLen
} }
err = elem.DecodeBinary(ci, elemSrc) err = elem.DecodeResult(ci, dst.elementOID, BinaryFormatCode, elemSrc)
if err != nil { if err != nil {
return err return err
} }
@@ -253,7 +280,7 @@ func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
} }
} }
elemBuf, err := elem.EncodeText(ci, inElemBuf) elemBuf, err := elem.EncodeParam(ci, src.elementOID, TextFormatCode, inElemBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -296,7 +323,7 @@ func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
sp := len(buf) sp := len(buf)
buf = pgio.AppendInt32(buf, -1) buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.elements[i].EncodeBinary(ci, buf) elemBuf, err := src.elements[i].EncodeParam(ci, src.elementOID, BinaryFormatCode, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+6 -6
View File
@@ -59,8 +59,8 @@ func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
b := NewCompositeTextBuilder(ci, buf) b := NewCompositeTextBuilder(ci, buf)
for _, f := range cf { for _, f := range cf {
if textEncoder, ok := f.(TextEncoder); ok { if paramEncoder, ok := f.(ParamEncoder); ok {
b.AppendEncoder(textEncoder) b.AppendEncoder(paramEncoder)
} else { } else {
b.AppendValue(f) b.AppendValue(f)
} }
@@ -88,15 +88,15 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error)
return nil, fmt.Errorf("Unknown OID for %#v", f) return nil, fmt.Errorf("Unknown OID for %#v", f)
} }
if binaryEncoder, ok := f.(BinaryEncoder); ok { if paramEncoder, ok := f.(ParamEncoder); ok {
b.AppendEncoder(dt.OID, binaryEncoder) b.AppendEncoder(dt.OID, paramEncoder)
} else { } else {
err := dt.Value.Set(f) err := dt.Value.Set(f)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { if paramEncoder, ok := dt.Value.(ParamEncoder); ok {
b.AppendEncoder(dt.OID, binaryEncoder) b.AppendEncoder(dt.OID, paramEncoder)
} else { } else {
return nil, fmt.Errorf("Cannot encode binary format for %v", f) return nil, fmt.Errorf("Cannot encode binary format for %v", f)
} }
+70 -26
View File
@@ -91,9 +91,13 @@ func (ct *CompositeType) Fields() []CompositeTypeField {
return ct.fields return ct.fields
} }
func (dst *CompositeType) setNil() {
dst.valid = false
}
func (dst *CompositeType) Set(src interface{}) error { func (dst *CompositeType) Set(src interface{}) error {
if src == nil { if src == nil {
dst.valid = false dst.setNil()
return nil return nil
} }
@@ -110,7 +114,7 @@ func (dst *CompositeType) Set(src interface{}) error {
dst.valid = true dst.valid = true
case *[]interface{}: case *[]interface{}:
if value == nil { if value == nil {
dst.valid = false dst.setNil()
return nil return nil
} }
return dst.Set(*value) return dst.Set(*value)
@@ -213,6 +217,56 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
return true, nil return true, nil
} }
func (ct *CompositeType) BinaryFormatSupported() bool {
for _, vt := range ct.valueTranscoders {
if !vt.BinaryFormatSupported() {
return false
}
}
return true
}
func (ct *CompositeType) TextFormatSupported() bool {
for _, vt := range ct.valueTranscoders {
if !vt.TextFormatSupported() {
return false
}
}
return true
}
func (ct *CompositeType) PreferredFormat() int16 {
if ct.BinaryFormatSupported() {
return BinaryFormatCode
}
return TextFormatCode
}
func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
if src == nil {
dst.setNil()
return nil
}
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return dst.DecodeText(ci, src)
}
return fmt.Errorf("unknown format code %d", format)
}
func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
switch format {
case BinaryFormatCode:
return src.EncodeBinary(ci, buf)
case TextFormatCode:
return src.EncodeText(ci, buf)
}
return nil, fmt.Errorf("unknown format code %d", format)
}
func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
if !src.valid { if !src.valid {
return nil, nil return nil, nil
@@ -231,11 +285,6 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
// and decoding fails if SQL value can't be assigned due to // and decoding fails if SQL value can't be assigned due to
// type mismatch // type mismatch
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
if buf == nil {
dst.valid = false
return nil
}
scanner := NewCompositeBinaryScanner(ci, buf) scanner := NewCompositeBinaryScanner(ci, buf)
for _, f := range dst.valueTranscoders { for _, f := range dst.valueTranscoders {
@@ -252,11 +301,6 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
} }
func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
if buf == nil {
dst.valid = false
return nil
}
scanner := NewCompositeTextScanner(ci, buf) scanner := NewCompositeTextScanner(ci, buf)
for _, f := range dst.valueTranscoders { for _, f := range dst.valueTranscoders {
@@ -315,13 +359,13 @@ func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner
} }
// ScanDecoder calls Next and decodes the result with d. // ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) { func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) {
if cfs.err != nil { if cfs.err != nil {
return return
} }
if cfs.Next() { if cfs.Next() {
cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes) cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes)
} else { } else {
cfs.err = errors.New("read past end of composite") cfs.err = errors.New("read past end of composite")
} }
@@ -425,13 +469,13 @@ func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
} }
// ScanDecoder calls Next and decodes the result with d. // ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) { func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) {
if cfs.err != nil { if cfs.err != nil {
return return
} }
if cfs.Next() { if cfs.Next() {
cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes) cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes)
} else { } else {
cfs.err = errors.New("read past end of composite") cfs.err = errors.New("read past end of composite")
} }
@@ -547,16 +591,16 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
return return
} }
binaryEncoder, ok := dt.Value.(BinaryEncoder) paramEncoder, ok := dt.Value.(ParamEncoder)
if !ok { if !ok {
b.err = fmt.Errorf("unable to encode binary for OID: %d", oid) b.err = fmt.Errorf("unable to encode for OID: %d", oid)
return return
} }
b.AppendEncoder(oid, binaryEncoder) b.AppendEncoder(oid, paramEncoder)
} }
func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) { func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) {
if b.err != nil { if b.err != nil {
return return
} }
@@ -564,7 +608,7 @@ func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder)
b.buf = pgio.AppendUint32(b.buf, oid) b.buf = pgio.AppendUint32(b.buf, oid)
lengthPos := len(b.buf) lengthPos := len(b.buf)
b.buf = pgio.AppendInt32(b.buf, -1) b.buf = pgio.AppendInt32(b.buf, -1)
fieldBuf, err := field.EncodeBinary(b.ci, b.buf) fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf)
if err != nil { if err != nil {
b.err = err b.err = err
return return
@@ -622,21 +666,21 @@ func (b *CompositeTextBuilder) AppendValue(field interface{}) {
return return
} }
textEncoder, ok := dt.Value.(TextEncoder) paramEncoder, ok := dt.Value.(ParamEncoder)
if !ok { if !ok {
b.err = fmt.Errorf("unable to encode text for value: %v", field) b.err = fmt.Errorf("unable to encode for value: %v", field)
return return
} }
b.AppendEncoder(textEncoder) b.AppendEncoder(paramEncoder)
} }
func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) {
if b.err != nil { if b.err != nil {
return return
} }
fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0]) fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0])
if err != nil { if err != nil {
b.err = err b.err = err
return return
+10 -6
View File
@@ -147,10 +147,9 @@ type TypeValue interface {
// ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. // ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces.
type ValueTranscoder interface { type ValueTranscoder interface {
Value Value
TextEncoder FormatSupport
BinaryEncoder ParamEncoder
TextDecoder ResultDecoder
BinaryDecoder
} }
type FormatSupport interface { type FormatSupport interface {
@@ -160,12 +159,17 @@ type FormatSupport interface {
} }
type ParamEncoder interface { type ParamEncoder interface {
FormatSupport // EncodeParam should append the encoded value of self to buf. If self is the
// SQL value NULL then append nothing and return (nil, nil). The caller of
// EncodeText is responsible for writing the correct NULL value or the
// length of the data written.
EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error)
} }
type ResultDecoder interface { type ResultDecoder interface {
FormatSupport // DecodeResult decodes src into ResultDecoder. If src is nil then the
// original SQL value is NULL. ResultDecoder takes ownership of src. The
// caller MUST not use it again.
DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error
} }