diff --git a/array_type.go b/array_type.go index 1df1689f..c4f162af 100644 --- a/array_type.go +++ b/array_type.go @@ -129,12 +129,44 @@ func (src *ArrayType) AssignTo(dst interface{}) error { } } -func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { +func (ArrayType) BinaryFormatSupported() bool { + return true +} + +func (ArrayType) TextFormatSupported() bool { + return true +} + +func (ArrayType) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *ArrayType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { if src == nil { dst.setNil() return nil } + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src ArrayType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} + +func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err @@ -151,7 +183,7 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(ci, elemSrc) + err = elem.DecodeResult(ci, dst.elementOID, TextFormatCode, elemSrc) if err != nil { return err } @@ -168,11 +200,6 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { } func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - var arrayHeader ArrayHeader rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { @@ -204,7 +231,7 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elem.DecodeBinary(ci, elemSrc) + err = elem.DecodeResult(ci, dst.elementOID, BinaryFormatCode, elemSrc) if err != nil { return err } @@ -253,7 +280,7 @@ func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } } - elemBuf, err := elem.EncodeText(ci, inElemBuf) + elemBuf, err := elem.EncodeParam(ci, src.elementOID, TextFormatCode, inElemBuf) if err != nil { return nil, err } @@ -296,7 +323,7 @@ func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - elemBuf, err := src.elements[i].EncodeBinary(ci, buf) + elemBuf, err := src.elements[i].EncodeParam(ci, src.elementOID, BinaryFormatCode, buf) if err != nil { return nil, err } diff --git a/composite_fields.go b/composite_fields.go index b6d09fcf..e7ca89c7 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -59,8 +59,8 @@ func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { b := NewCompositeTextBuilder(ci, buf) for _, f := range cf { - if textEncoder, ok := f.(TextEncoder); ok { - b.AppendEncoder(textEncoder) + if paramEncoder, ok := f.(ParamEncoder); ok { + b.AppendEncoder(paramEncoder) } else { b.AppendValue(f) } @@ -88,15 +88,15 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) return nil, fmt.Errorf("Unknown OID for %#v", f) } - if binaryEncoder, ok := f.(BinaryEncoder); ok { - b.AppendEncoder(dt.OID, binaryEncoder) + if paramEncoder, ok := f.(ParamEncoder); ok { + b.AppendEncoder(dt.OID, paramEncoder) } else { err := dt.Value.Set(f) if err != nil { return nil, err } - if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { - b.AppendEncoder(dt.OID, binaryEncoder) + if paramEncoder, ok := dt.Value.(ParamEncoder); ok { + b.AppendEncoder(dt.OID, paramEncoder) } else { return nil, fmt.Errorf("Cannot encode binary format for %v", f) } diff --git a/composite_type.go b/composite_type.go index 90b7b6ff..85ab5910 100644 --- a/composite_type.go +++ b/composite_type.go @@ -91,9 +91,13 @@ func (ct *CompositeType) Fields() []CompositeTypeField { return ct.fields } +func (dst *CompositeType) setNil() { + dst.valid = false +} + func (dst *CompositeType) Set(src interface{}) error { if src == nil { - dst.valid = false + dst.setNil() return nil } @@ -110,7 +114,7 @@ func (dst *CompositeType) Set(src interface{}) error { dst.valid = true case *[]interface{}: if value == nil { - dst.valid = false + dst.setNil() return nil } return dst.Set(*value) @@ -213,6 +217,56 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { return true, nil } +func (ct *CompositeType) BinaryFormatSupported() bool { + for _, vt := range ct.valueTranscoders { + if !vt.BinaryFormatSupported() { + return false + } + } + return true +} + +func (ct *CompositeType) TextFormatSupported() bool { + for _, vt := range ct.valueTranscoders { + if !vt.TextFormatSupported() { + return false + } + } + return true +} + +func (ct *CompositeType) PreferredFormat() int16 { + if ct.BinaryFormatSupported() { + return BinaryFormatCode + } + return TextFormatCode +} + +func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} + func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { if !src.valid { return nil, nil @@ -231,11 +285,6 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, // and decoding fails if SQL value can't be assigned due to // type mismatch func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { - if buf == nil { - dst.valid = false - return nil - } - scanner := NewCompositeBinaryScanner(ci, buf) for _, f := range dst.valueTranscoders { @@ -252,11 +301,6 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { } func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { - if buf == nil { - dst.valid = false - return nil - } - scanner := NewCompositeTextScanner(ci, buf) for _, f := range dst.valueTranscoders { @@ -315,13 +359,13 @@ func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner } // ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) { +func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) { if cfs.err != nil { return } if cfs.Next() { - cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes) + cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes) } else { cfs.err = errors.New("read past end of composite") } @@ -425,13 +469,13 @@ func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { } // ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) { +func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) { if cfs.err != nil { return } if cfs.Next() { - cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes) + cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes) } else { cfs.err = errors.New("read past end of composite") } @@ -547,16 +591,16 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { return } - binaryEncoder, ok := dt.Value.(BinaryEncoder) + paramEncoder, ok := dt.Value.(ParamEncoder) if !ok { - b.err = fmt.Errorf("unable to encode binary for OID: %d", oid) + b.err = fmt.Errorf("unable to encode for OID: %d", oid) return } - b.AppendEncoder(oid, binaryEncoder) + b.AppendEncoder(oid, paramEncoder) } -func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) { +func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) { if b.err != nil { return } @@ -564,7 +608,7 @@ func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) b.buf = pgio.AppendUint32(b.buf, oid) lengthPos := len(b.buf) b.buf = pgio.AppendInt32(b.buf, -1) - fieldBuf, err := field.EncodeBinary(b.ci, b.buf) + fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf) if err != nil { b.err = err return @@ -622,21 +666,21 @@ func (b *CompositeTextBuilder) AppendValue(field interface{}) { return } - textEncoder, ok := dt.Value.(TextEncoder) + paramEncoder, ok := dt.Value.(ParamEncoder) if !ok { - b.err = fmt.Errorf("unable to encode text for value: %v", field) + b.err = fmt.Errorf("unable to encode for value: %v", field) return } - b.AppendEncoder(textEncoder) + b.AppendEncoder(paramEncoder) } -func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { +func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) { if b.err != nil { return } - fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0]) + fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0]) if err != nil { b.err = err return diff --git a/pgtype.go b/pgtype.go index b9067fab..1705ae41 100644 --- a/pgtype.go +++ b/pgtype.go @@ -147,10 +147,9 @@ type TypeValue interface { // ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. type ValueTranscoder interface { Value - TextEncoder - BinaryEncoder - TextDecoder - BinaryDecoder + FormatSupport + ParamEncoder + ResultDecoder } type FormatSupport interface { @@ -160,12 +159,17 @@ type FormatSupport interface { } type ParamEncoder interface { - FormatSupport + // EncodeParam should append the encoded value of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) } type ResultDecoder interface { - FormatSupport + // DecodeResult decodes src into ResultDecoder. If src is nil then the + // original SQL value is NULL. ResultDecoder takes ownership of src. The + // caller MUST not use it again. DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error }