diff --git a/aclitem.go b/aclitem.go index b8a1549e..f9faab20 100644 --- a/aclitem.go +++ b/aclitem.go @@ -90,7 +90,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return nil } -func (dst *Aclitem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Aclitem{Status: Null} return nil @@ -100,7 +100,7 @@ func (dst *Aclitem) DecodeText(src []byte) error { return nil } -func (src Aclitem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/aclitem_array.go b/aclitem_array.go index 5e3647b7..f02d339e 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -82,7 +82,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return nil } -func (dst *AclitemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = AclitemArray{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,7 +118,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { return nil } -func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -165,7 +165,7 @@ func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/array.go b/array.go index dff0fe81..9561afe5 100644 --- a/array.go +++ b/array.go @@ -27,7 +27,7 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } @@ -60,7 +60,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(w io.Writer) error { +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err diff --git a/bool.go b/bool.go index a8e9b8e1..87316381 100644 --- a/bool.go +++ b/bool.go @@ -79,7 +79,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(src []byte) error { +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Bool) DecodeText(src []byte) error { return nil } -func (dst *Bool) DecodeBinary(src []byte) error { +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) (bool, error) { +func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -126,7 +126,7 @@ func (src Bool) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(w io.Writer) (bool, error) { +func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/bool_array.go b/bool_array.go index 4c5fc563..1cb46cf6 100644 --- a/bool_array.go +++ b/bool_array.go @@ -83,7 +83,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(src []byte) error { +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *BoolArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *BoolArray) DecodeText(src []byte) error { return nil } -func (dst *BoolArray) DecodeBinary(src []byte) error { +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOid) +func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/bytea.go b/bytea.go index 5df05360..dc1e9c07 100644 --- a/bytea.go +++ b/bytea.go @@ -78,7 +78,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(src []byte) error { +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -98,7 +98,7 @@ func (dst *Bytea) DecodeText(src []byte) error { return nil } -func (dst *Bytea) DecodeBinary(src []byte) error { +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) (bool, error) { +func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -128,7 +128,7 @@ func (src Bytea) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { +func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/bytea_array.go b/bytea_array.go index c6f676a4..30405509 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -83,7 +83,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return nil } -func (dst *ByteaArray) DecodeText(src []byte) error { +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *ByteaArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *ByteaArray) DecodeText(src []byte) error { return nil } -func (dst *ByteaArray) DecodeBinary(src []byte) error { +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { return nil } -func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOid) +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/cid.go b/cid.go index 20957f36..d86e8063 100644 --- a/cid.go +++ b/cid.go @@ -34,18 +34,18 @@ func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/cidr.go b/cidr.go new file mode 100644 index 00000000..463b279d --- /dev/null +++ b/cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Cidr Inet + +func (dst *Cidr) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst *Cidr) Get() interface{} { + return (*Inet)(dst).Get() +} + +func (src *Cidr) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeText(ci, w) +} + +func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeBinary(ci, w) +} diff --git a/cidr_array.go b/cidr_array.go index c30c53d3..32d2e7bf 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -1,35 +1,328 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + "net" + + "github.com/jackc/pgx/pgio" ) -type CidrArray InetArray +type CidrArray struct { + Elements []Cidr + Dimensions []ArrayDimension + Status Status +} func (dst *CidrArray) Set(src interface{}) error { - return (*InetArray)(dst).Set(src) + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Cidr", value) + } + + return nil } func (dst *CidrArray) Get() interface{} { - return (*InetArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *CidrArray) AssignTo(dst interface{}) error { - return (*InetArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *CidrArray) DecodeText(src []byte) error { - return (*InetArray)(dst).DecodeText(src) +func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Cidr + + if len(uta.Elements) > 0 { + elements = make([]Cidr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Cidr + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *CidrArray) DecodeBinary(src []byte) error { - return (*InetArray)(dst).DecodeBinary(src) +func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Cidr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { - return (*InetArray)(src).EncodeText(w) +func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOid) +func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, CidrOid) +} + +func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/cidr_array_test.go b/cidr_array_test.go new file mode 100644 index 00000000..ec105914 --- /dev/null +++ b/cidr_array_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCidrArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CidrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{Status: pgtype.Null}, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCidrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CidrArray + }{ + { + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CidrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCidrArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.CidrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/database_sql.go b/database_sql.go new file mode 100644 index 00000000..969d6542 --- /dev/null +++ b/database_sql.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "bytes" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + switch src := src.(type) { + case *Bool: + return src.Bool, nil + case *Bytea: + return src.Bytes, nil + case *Date: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Float4: + return float64(src.Float), nil + case *Float8: + return src.Float, nil + case *GenericBinary: + return src.Bytes, nil + case *GenericText: + return src.String, nil + case *Int2: + return int64(src.Int), nil + case *Int4: + return int64(src.Int), nil + case *Int8: + return int64(src.Int), nil + case *Text: + return src.String, nil + case *Timestamp: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Timestamptz: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Unknown: + return src.String, nil + case *Varchar: + return src.String, nil + } + + buf := &bytes.Buffer{} + if textEncoder, ok := src.(TextEncoder); ok { + _, err := textEncoder.EncodeText(ci, buf) + if err != nil { + return nil, err + } + return buf.String(), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + _, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} diff --git a/date.go b/date.go index d0481637..b6cc8329 100644 --- a/date.go +++ b/date.go @@ -38,6 +38,9 @@ func (dst *Date) Set(src interface{}) error { func (dst *Date) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst.Time case Null: return nil @@ -76,7 +79,7 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(src []byte) error { +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -100,7 +103,7 @@ func (dst *Date) DecodeText(src []byte) error { return nil } -func (dst *Date) DecodeBinary(src []byte) error { +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -125,7 +128,7 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) (bool, error) { +func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +151,7 @@ func (src Date) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(w io.Writer) (bool, error) { +func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/date_array.go b/date_array.go index 7f602d83..ba68d561 100644 --- a/date_array.go +++ b/date_array.go @@ -84,7 +84,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(src []byte) error { +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *DateArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *DateArray) DecodeText(src []byte) error { return nil } -func (dst *DateArray) DecodeBinary(src []byte) error { +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOid) +func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/float4.go b/float4.go index 053af44b..94b7b7a1 100644 --- a/float4.go +++ b/float4.go @@ -102,7 +102,7 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(src []byte) error { +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Float4) DecodeText(src []byte) error { return nil } -func (dst *Float4) DecodeBinary(src []byte) error { +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -133,7 +133,7 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) (bool, error) { +func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -145,7 +145,7 @@ func (src Float4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(w io.Writer) (bool, error) { +func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/float4_array.go b/float4_array.go index 0e815e0b..40152bcf 100644 --- a/float4_array.go +++ b/float4_array.go @@ -83,7 +83,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(src []byte) error { +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float4Array) DecodeText(src []byte) error { return nil } -func (dst *Float4Array) DecodeBinary(src []byte) error { +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4Oid) +func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/float8.go b/float8.go index 635b7a09..dd2d592d 100644 --- a/float8.go +++ b/float8.go @@ -92,7 +92,7 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(src []byte) error { +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Float8) DecodeText(src []byte) error { return nil } -func (dst *Float8) DecodeBinary(src []byte) error { +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -123,7 +123,7 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) (bool, error) { +func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -135,7 +135,7 @@ func (src Float8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(w io.Writer) (bool, error) { +func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/float8_array.go b/float8_array.go index 811c5a1f..d0ee0d70 100644 --- a/float8_array.go +++ b/float8_array.go @@ -83,7 +83,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(src []byte) error { +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float8Array) DecodeText(src []byte) error { return nil } -func (dst *Float8Array) DecodeBinary(src []byte) error { +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8Oid) +func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/generic_binary.go b/generic_binary.go index ac35ea60..aa28bb62 100644 --- a/generic_binary.go +++ b/generic_binary.go @@ -20,10 +20,10 @@ func (src *GenericBinary) AssignTo(dst interface{}) error { return (*Bytea)(src).AssignTo(dst) } -func (dst *GenericBinary) DecodeBinary(src []byte) error { - return (*Bytea)(dst).DecodeBinary(src) +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(w) +func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(ci, w) } diff --git a/generic_text.go b/generic_text.go index 19f41059..bd75e0d0 100644 --- a/generic_text.go +++ b/generic_text.go @@ -20,10 +20,10 @@ func (src *GenericText) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *GenericText) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } diff --git a/hstore.go b/hstore.go index c48ae6da..d771d6e6 100644 --- a/hstore.go +++ b/hstore.go @@ -70,7 +70,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return nil } -func (dst *Hstore) DecodeText(src []byte) error { +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -90,7 +90,7 @@ func (dst *Hstore) DecodeText(src []byte) error { return nil } -func (dst *Hstore) DecodeBinary(src []byte) error { +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -132,7 +132,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { rp += valueLen var value Text - err := value.DecodeBinary(valueBuf) + err := value.DecodeBinary(ci, valueBuf) if err != nil { return err } @@ -144,7 +144,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { return nil } -func (src Hstore) EncodeText(w io.Writer) (bool, error) { +func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -175,7 +175,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -196,7 +196,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { +func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,7 +220,7 @@ func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { return false, err } - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/inet.go b/inet.go index 87d675f9..b83bd1c9 100644 --- a/inet.go +++ b/inet.go @@ -100,7 +100,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(src []byte) error { +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Inet) DecodeText(src []byte) error { return nil } -func (dst *Inet) DecodeBinary(src []byte) error { +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -153,7 +153,7 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) (bool, error) { +func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Inet) EncodeText(w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) (bool, error) { +func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/inet_array.go b/inet_array.go index 1d1cf3fd..6cad82e7 100644 --- a/inet_array.go +++ b/inet_array.go @@ -115,7 +115,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(src []byte) error { +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil @@ -137,7 +137,7 @@ func (dst *InetArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -151,14 +151,14 @@ func (dst *InetArray) DecodeText(src []byte) error { return nil } -func (dst *InetArray) DecodeBinary(src []byte) error { +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -183,7 +183,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -193,7 +193,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -269,11 +269,11 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOid) +func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -293,7 +293,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -303,7 +303,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/int2.go b/int2.go index 62e1bc69..6996cd4f 100644 --- a/int2.go +++ b/int2.go @@ -98,7 +98,7 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(src []byte) error { +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -113,7 +113,7 @@ func (dst *Int2) DecodeText(src []byte) error { return nil } -func (dst *Int2) DecodeBinary(src []byte) error { +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) (bool, error) { +func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -140,7 +140,7 @@ func (src Int2) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(w io.Writer) (bool, error) { +func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/int2_array.go b/int2_array.go index 3d06c018..2bf1c237 100644 --- a/int2_array.go +++ b/int2_array.go @@ -114,7 +114,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(src []byte) error { +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int2Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int2Array) DecodeText(src []byte) error { return nil } -func (dst *Int2Array) DecodeBinary(src []byte) error { +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2Oid) +func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/int4.go b/int4.go index 8eaf5094..62ee366f 100644 --- a/int4.go +++ b/int4.go @@ -89,7 +89,7 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(src []byte) error { +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *Int4) DecodeText(src []byte) error { return nil } -func (dst *Int4) DecodeBinary(src []byte) error { +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -119,7 +119,7 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) (bool, error) { +func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -131,7 +131,7 @@ func (src Int4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(w io.Writer) (bool, error) { +func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/int4_array.go b/int4_array.go index 5cd91c04..dda88eaf 100644 --- a/int4_array.go +++ b/int4_array.go @@ -114,7 +114,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(src []byte) error { +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int4Array) DecodeText(src []byte) error { return nil } -func (dst *Int4Array) DecodeBinary(src []byte) error { +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4Oid) +func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/int8.go b/int8.go index 2416500d..7ed54f8e 100644 --- a/int8.go +++ b/int8.go @@ -80,7 +80,7 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(src []byte) error { +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -95,7 +95,7 @@ func (dst *Int8) DecodeText(src []byte) error { return nil } -func (dst *Int8) DecodeBinary(src []byte) error { +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) (bool, error) { +func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -123,7 +123,7 @@ func (src Int8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(w io.Writer) (bool, error) { +func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/int8_array.go b/int8_array.go index 5efc0f45..468c126b 100644 --- a/int8_array.go +++ b/int8_array.go @@ -114,7 +114,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(src []byte) error { +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int8Array) DecodeText(src []byte) error { return nil } -func (dst *Int8Array) DecodeBinary(src []byte) error { +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8Oid) +func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/json.go b/json.go index ecdb3dab..bfffae14 100644 --- a/json.go +++ b/json.go @@ -84,7 +84,7 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(src []byte) error { +func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Json{Status: Null} return nil @@ -97,11 +97,11 @@ func (dst *Json) DecodeText(src []byte) error { return nil } -func (dst *Json) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Json) EncodeText(w io.Writer) (bool, error) { +func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,6 +113,6 @@ func (src Json) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/jsonb.go b/jsonb.go index 13062e8e..e44f3c41 100644 --- a/jsonb.go +++ b/jsonb.go @@ -19,11 +19,11 @@ func (src *Jsonb) AssignTo(dst interface{}) error { return (*Json)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(src []byte) error { - return (*Json)(dst).DecodeText(src) +func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { + return (*Json)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(src []byte) error { +func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Jsonb{Status: Null} return nil @@ -46,11 +46,11 @@ func (dst *Jsonb) DecodeBinary(src []byte) error { } -func (src Jsonb) EncodeText(w io.Writer) (bool, error) { - return (Json)(src).EncodeText(w) +func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { +func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/name.go b/name.go index 9eb12ece..9ebf63d3 100644 --- a/name.go +++ b/name.go @@ -31,18 +31,18 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (dst *Name) DecodeBinary(src []byte) error { - return (*Text)(dst).DecodeBinary(src) +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(w) +func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) } diff --git a/oid.go b/oid.go index eab1fbcb..3edd7f3c 100644 --- a/oid.go +++ b/oid.go @@ -18,7 +18,7 @@ import ( // allow for NULL Oids use OidValue. type Oid uint32 -func (dst *Oid) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -32,7 +32,7 @@ func (dst *Oid) DecodeText(src []byte) error { return nil } -func (dst *Oid) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -46,12 +46,12 @@ func (dst *Oid) DecodeBinary(src []byte) error { return nil } -func (src Oid) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) return false, err } -func (src Oid) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } diff --git a/oid_value.go b/oid_value.go index a2b2dcbe..1bce6e11 100644 --- a/oid_value.go +++ b/oid_value.go @@ -28,18 +28,18 @@ func (src *OidValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype.go b/pgtype.go index 7b1470b7..674c0db7 100644 --- a/pgtype.go +++ b/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "errors" "io" + "reflect" ) // PostgreSQL oids for common types @@ -83,14 +84,14 @@ type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder MUST not retain a reference to // src. It MUST make a copy if it needs to retain the raw bytes. - DecodeBinary(src []byte) error + DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST // make a copy if it needs to retain the raw bytes. - DecodeText(src []byte) error + DecodeText(ci *ConnInfo, src []byte) error } // BinaryEncoder is implemented by types that can encode themselves into the @@ -100,7 +101,7 @@ type BinaryEncoder interface { // SQL value NULL then write nothing and return (true, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) } // TextEncoder is implemented by types that can encode themselves into the @@ -110,7 +111,127 @@ type TextEncoder interface { // value NULL then write nothing and return (true, nil). The caller of // EncodeText is responsible for writing the correct NULL value or the length // of the data written. - EncodeText(w io.Writer) (null bool, err error) + EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") + +type DataType struct { + Value Value + Name string + Oid Oid +} + +type ConnInfo struct { + oidToDataType map[Oid]*DataType + nameToDataType map[string]*DataType + reflectTypeToDataType map[reflect.Type]*DataType +} + +func NewConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, 256), + nameToDataType: make(map[string]*DataType, 256), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + } +} + +func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { + for name, oid := range nameOids { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + ci.oidToDataType[t.Oid] = &t + ci.nameToDataType[t.Name] = &t + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t +} + +func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + return dt, ok +} + +// DeepCopy makes a deep copy of the ConnInfo. +func (ci *ConnInfo) DeepCopy() *ConnInfo { + ci2 := &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), + reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + } + + for _, dt := range ci.oidToDataType { + ci2.RegisterDataType(DataType{ + Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Name: dt.Name, + Oid: dt.Oid, + }) + } + + return ci2 +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &AclitemArray{}, + "_bool": &BoolArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CidrArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_varchar": &VarcharArray{}, + "aclitem": &Aclitem{}, + "bool": &Bool{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &Cid{}, + "cidr": &Cidr{}, + "date": &Date{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int8": &Int8{}, + "json": &Json{}, + "jsonb": &Jsonb{}, + "name": &Name{}, + "oid": &OidValue{}, + "record": &Record{}, + "text": &Text{}, + "tid": &Tid{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "unknown": &Unknown{}, + "varchar": &Varchar{}, + "xid": &Xid{}, + } +} diff --git a/pgtype_test.go b/pgtype_test.go index f9b6f56d..391fed57 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -60,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { - return f.e.EncodeText(w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { - return f.e.EncodeBinary(w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) } func forceEncoder(e interface{}, formatCode int16) interface{} { @@ -114,7 +114,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%v does not implement %v", fc.name) + t.Logf("%#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer diff --git a/pguint32.go b/pguint32.go index 05c79c0e..3f9e7bf7 100644 --- a/pguint32.go +++ b/pguint32.go @@ -63,7 +63,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(src []byte) error { +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -78,7 +78,7 @@ func (dst *pguint32) DecodeText(src []byte) error { return nil } -func (dst *pguint32) DecodeBinary(src []byte) error { +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) (bool, error) { +func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src pguint32) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { +func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/qchar.go b/qchar.go index d46e716d..4b32ee4a 100644 --- a/qchar.go +++ b/qchar.go @@ -115,7 +115,7 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(src []byte) error { +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = QChar{Status: Null} return nil @@ -129,7 +129,7 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) (bool, error) { +func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/record.go b/record.go new file mode 100644 index 00000000..1bfd05b9 --- /dev/null +++ b/record.go @@ -0,0 +1,123 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst *Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]Value: + switch src.Status { + case Present: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + case *[]interface{}: + switch src.Status { + case Present: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + fields := make([]Value, fieldCount) + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 8 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { + return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + } + } + + var fieldBytes []byte + if fieldLen >= 0 { + if len(src[rp:]) < fieldLen { + return fmt.Errorf("Record incomplete %v", src) + } + fieldBytes = src[rp : rp+fieldLen] + rp += fieldLen + } + + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fields[i] = binaryDecoder.(Value) + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 00000000..bc6e5893 --- /dev/null +++ b/record_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestRecordTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + tests := []struct { + sql string + expected pgtype.Record + }{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + + var result pgtype.Record + if err := conn.QueryRow(psName).Scan(&result); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + } + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/text.go b/text.go index 3dd082c9..f1a76b6e 100644 --- a/text.go +++ b/text.go @@ -78,7 +78,7 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(src []byte) error { +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} return nil @@ -88,11 +88,11 @@ func (dst *Text) DecodeText(src []byte) error { return nil } -func (dst *Text) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Text) EncodeText(w io.Writer) (bool, error) { +func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -104,6 +104,6 @@ func (src Text) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/text_array.go b/text_array.go index 1e6677a9..6e89708f 100644 --- a/text_array.go +++ b/text_array.go @@ -83,7 +83,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(src []byte) error { +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *TextArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *TextArray) DecodeText(src []byte) error { return nil } -func (dst *TextArray) DecodeBinary(src []byte) error { +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOid) +func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/tid.go b/tid.go index 20d962df..b91711d3 100644 --- a/tid.go +++ b/tid.go @@ -46,7 +46,7 @@ func (src *Tid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -75,7 +75,7 @@ func (dst *Tid) DecodeText(src []byte) error { return nil } -func (dst *Tid) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(src []byte) error { return nil } -func (src Tid) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src Tid) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/timestamp.go b/timestamp.go index 3bb8f080..9a9e74ea 100644 --- a/timestamp.go +++ b/timestamp.go @@ -85,7 +85,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(src []byte) error { +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Timestamp) DecodeText(src []byte) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(src []byte) error { +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -139,7 +139,7 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) (bool, error) { +func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -167,7 +167,7 @@ func (src Timestamp) EncodeText(w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/timestamp_array.go b/timestamp_array.go index c955dc42..064ad483 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -84,7 +84,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(src []byte) error { +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestampArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestampArray) DecodeText(src []byte) error { return nil } -func (dst *TimestampArray) DecodeBinary(src []byte) error { +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOid) +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/timestamptz.go b/timestamptz.go index 5b9f5038..7f57f4b7 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -84,7 +84,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(src []byte) error { +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Timestamptz) DecodeText(src []byte) error { return nil } -func (dst *Timestamptz) DecodeBinary(src []byte) error { +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -143,7 +143,7 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/timestamptz_array.go b/timestamptz_array.go index cd63e02e..4af1460b 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -84,7 +84,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(src []byte) error { +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(src []byte) error { +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOid) +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/typed_array.go.erb b/typed_array.go.erb index a56097c0..2a46a658 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -82,7 +82,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,14 +118,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -150,7 +150,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { elemSrc = src[rp:rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -160,7 +160,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -236,11 +236,11 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, <%= element_oid %>) +func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -260,7 +260,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -270,7 +270,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 41c1313f..5fde32aa 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -8,6 +8,8 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/unknown.go b/unknown.go new file mode 100644 index 00000000..b951ad99 --- /dev/null +++ b/unknown.go @@ -0,0 +1,32 @@ +package pgtype + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Unknown) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} diff --git a/varchar.go b/varchar.go new file mode 100644 index 00000000..adda6c49 --- /dev/null +++ b/varchar.go @@ -0,0 +1,40 @@ +package pgtype + +import ( + "io" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Varchar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) +} diff --git a/varchar_array.go b/varchar_array.go index 693b9a61..21e9ccff 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -1,35 +1,296 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + + "github.com/jackc/pgx/pgio" ) -type VarcharArray TextArray +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} func (dst *VarcharArray) Set(src interface{}) error { - return (*TextArray)(dst).Set(src) + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Varchar", value) + } + + return nil } func (dst *VarcharArray) Get() interface{} { - return (*TextArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *VarcharArray) AssignTo(dst interface{}) error { - return (*TextArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *VarcharArray) DecodeText(src []byte) error { - return (*TextArray)(dst).DecodeText(src) +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *VarcharArray) DecodeBinary(src []byte) error { - return (*TextArray)(dst).DecodeBinary(src) +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { - return (*TextArray)(src).EncodeText(w) +func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOid) +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, VarcharOid) +} + +func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/varchar_array_test.go b/varchar_array_test.go new file mode 100644 index 00000000..4a8b09b8 --- /dev/null +++ b/varchar_array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar ", Status: pgtype.Present}, + pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, + pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Varchar{String: "", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + pgtype.Varchar{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar", Status: pgtype.Present}, + pgtype.Varchar{String: "baz", Status: pgtype.Present}, + pgtype.Varchar{String: "quz", Status: pgtype.Present}, + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/xid.go b/xid.go index a53120de..c76548a4 100644 --- a/xid.go +++ b/xid.go @@ -37,18 +37,18 @@ func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) }