From 7323d3f5a793f528f86ef7e24b4ceec8b1742c6a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Apr 2016 19:07:29 -0500 Subject: [PATCH] Encode/decode [][]byte to/from bytea[] fixes #139 --- CHANGELOG.md | 1 + conn.go | 2 +- values.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 19 ++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f70e20cc..8cd4bb5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Encode and decode between all Go and PostgreSQL integer types with bounds checking * Decode inet/cidr to net.IP +* Encode/decode [][]byte to/from bytea[] ## Performance diff --git a/conn.go b/conn.go index a8c6cbc5..389d0ccc 100644 --- a/conn.go +++ b/conn.go @@ -857,7 +857,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) diff --git a/values.go b/values.go index 5499d8c7..f46ce1df 100644 --- a/values.go +++ b/values.go @@ -32,6 +32,7 @@ const ( Int2ArrayOid = 1005 Int4ArrayOid = 1007 TextArrayOid = 1009 + ByteaArrayOid = 1001 VarcharArrayOid = 1015 Int8ArrayOid = 1016 Float4ArrayOid = 1021 @@ -67,6 +68,7 @@ var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ "_bool": BinaryFormatCode, + "_bytea": BinaryFormatCode, "_cidr": BinaryFormatCode, "_float4": BinaryFormatCode, "_float8": BinaryFormatCode, @@ -604,6 +606,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeString(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) + case [][]byte: + return encodeByteSliceSlice(wbuf, oid, arg) } if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr { @@ -801,6 +805,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeTextArray(vr) case *[]time.Time: *v = decodeTimestampArray(vr) + case *[][]byte: + *v = decodeByteaArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: @@ -1683,6 +1689,67 @@ func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error { return nil } +func decodeByteaArray(vr *ValueReader) [][]byte { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != ByteaArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([][]byte, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + a[i] = vr.ReadBytes(elSize) + } + } + + return a +} + +func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error { + if oid != ByteaArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid) + } + + size := 20 // array header size + for _, el := range value { + size += 4 + len(el) + } + + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(ByteaOid)) // type of elements + w.WriteInt32(int32(len(value))) // number of elements + w.WriteInt32(1) // index of first element + + for _, el := range value { + encodeByteSlice(w, ByteaOid, el) + } + + return nil +} + func decodeInt2Array(vr *ValueReader) []int16 { if vr.Len() == -1 { return nil diff --git a/values_test.go b/values_test.go index 91e2173f..14a8aa17 100644 --- a/values_test.go +++ b/values_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "bytes" "net" "reflect" "strings" @@ -706,12 +707,30 @@ func TestArrayDecoding(t *testing.T) { } }, }, + { + "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, + func(t *testing.T, query, scan interface{}) { + queryBytesSliceSlice := query.([][]byte) + scanBytesSliceSlice := *(scan.(*[][]byte)) + if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) + } + for i := range queryBytesSliceSlice { + qb := queryBytesSliceSlice[i] + sb := scanBytesSliceSlice[i] + if bytes.Compare(qb, sb) != 0 { + t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) + } + } + }, + }, } for i, tt := range tests { err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan) if err != nil { t.Errorf(`%d. error reading array: %v`, i, err) + continue } tt.assert(t, tt.query, tt.scan) ensureConnValid(t, conn)