From 109b55f9deda26e3fdff5783b7622814bb7f0442 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Sun, 21 Dec 2014 14:35:38 +0700 Subject: [PATCH] support decoding of []time.Time and []bool --- query.go | 6 +++ values.go | 122 +++++++++++++++++++++++++++++++++++++++++-------- values_test.go | 60 ++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 19 deletions(-) diff --git a/query.go b/query.go index 835829cc..890f9db1 100644 --- a/query.go +++ b/query.go @@ -243,6 +243,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat4(vr) case *float64: *d = decodeFloat8(vr) + case *[]bool: + *d = decodeBoolArray(vr) case *[]int16: *d = decodeInt2Array(vr) case *[]int32: @@ -255,6 +257,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat8Array(vr) case *[]string: *d = decodeTextArray(vr) + case *[]time.Time: + *d = decodeTimestampArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: @@ -324,6 +328,8 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8Oid: values = append(values, decodeFloat8(vr)) + case BoolArrayOid: + values = append(values, decodeBoolArray(vr)) case Int2ArrayOid: values = append(values, decodeInt2Array(vr)) case Int4ArrayOid: diff --git a/values.go b/values.go index 38fc9b17..79ea6c96 100644 --- a/values.go +++ b/values.go @@ -53,11 +53,13 @@ func init() { DefaultTypeFormats = make(map[string]int16) DefaultTypeFormats["_float4"] = BinaryFormatCode DefaultTypeFormats["_float8"] = BinaryFormatCode + DefaultTypeFormats["_bool"] = BinaryFormatCode DefaultTypeFormats["_int2"] = BinaryFormatCode DefaultTypeFormats["_int4"] = BinaryFormatCode DefaultTypeFormats["_int8"] = BinaryFormatCode DefaultTypeFormats["_text"] = BinaryFormatCode DefaultTypeFormats["_varchar"] = BinaryFormatCode + DefaultTypeFormats["_timestamp"] = BinaryFormatCode DefaultTypeFormats["bool"] = BinaryFormatCode DefaultTypeFormats["bytea"] = BinaryFormatCode DefaultTypeFormats["date"] = BinaryFormatCode @@ -1195,6 +1197,66 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } +func decodeBoolArray(vr *ValueReader) []bool { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != BoolArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]bool, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 1: + if vr.ReadByte() == 1 { + a[i] = true + } + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeBoolArray(w *WriteBuf, value interface{}) error { + slice, ok := value.([]bool) + if !ok { + return fmt.Errorf("Expected []bool, received %T", value) + } + + encodeArrayHeader(w, BoolOid, len(slice), 5) + for _, v := range slice { + w.WriteInt32(1) + var b byte + if v { + b = 1 + } + w.WriteByte(b) + } + + return nil +} + func decodeInt2Array(vr *ValueReader) []int16 { if vr.Len() == -1 { return nil @@ -1234,25 +1296,6 @@ func decodeInt2Array(vr *ValueReader) []int16 { return a } -func encodeBoolArray(w *WriteBuf, value interface{}) error { - slice, ok := value.([]bool) - if !ok { - return fmt.Errorf("Expected []bool, received %T", value) - } - - encodeArrayHeader(w, BoolOid, len(slice), 5) - for _, v := range slice { - w.WriteInt32(1) - var b byte - if v { - b = 1 - } - w.WriteByte(b) - } - - return nil -} - func encodeInt2Array(w *WriteBuf, value interface{}) error { slice, ok := value.([]int16) if !ok { @@ -1548,6 +1591,47 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } +func decodeTimestampArray(vr *ValueReader) []time.Time { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != TimestampArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]time.Time, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize))) + return nil + } + } + + return a +} + func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { slice, ok := value.([]time.Time) if !ok { diff --git a/values_test.go b/values_test.go index 9aba9fd4..1e91bd41 100644 --- a/values_test.go +++ b/values_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "fmt" "github.com/jackc/pgx" + "reflect" "strings" "testing" "time" @@ -159,6 +160,65 @@ func TestNullXMismatch(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::bool[]", []bool{true, false, true}, &[]bool{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]bool))) == false { + t.Errorf("failed to encode bool[]") + } + }, + }, + { + "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]int32))) == false { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]string))) == false { + t.Errorf("failed to encode text[]") + } + }, + }, + { + "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + t.Errorf("failed to encode time.Time[]") + } + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("ps%d", i) + mustPrepare(t, conn, psName, tt.sql) + + err := conn.QueryRow(psName, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } +} + +func TestArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + type allTypes struct { s pgx.NullString i16 pgx.NullInt16