From d1b42d1c8eb6322eeba1befe076971c2f946e21e Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Sun, 21 Dec 2014 13:01:24 +0700 Subject: [PATCH 1/4] support inserting into bool[] --- conn.go | 4 +++- conn_test.go | 16 ++++++++++++++++ values.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 2c38fd2c..26d99169 100644 --- a/conn.go +++ b/conn.go @@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -593,6 +593,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampTz(wbuf, arguments[i]) case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) + case BoolArrayOid: + err = encodeBoolArray(wbuf, arguments[i]) case Int2ArrayOid: err = encodeInt2Array(wbuf, arguments[i]) case Int4ArrayOid: diff --git a/conn_test.go b/conn_test.go index c6264f77..9aed24e1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -515,3 +515,19 @@ func TestCommandTag(t *testing.T) { } } } + +func TestInsertBoolArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} diff --git a/values.go b/values.go index 00091513..33515b41 100644 --- a/values.go +++ b/values.go @@ -21,6 +21,7 @@ const ( OidOid = 26 Float4Oid = 700 Float8Oid = 701 + BoolArrayOid = 1000 Int2ArrayOid = 1005 Int4ArrayOid = 1007 TextArrayOid = 1009 @@ -1232,6 +1233,33 @@ func decodeInt2Array(vr *ValueReader) []int16 { return a } +func encodeBoolArray(w *WriteBuf, value interface{}) error { + slice, ok := value.([]bool) + if !ok { + return fmt.Errorf("Expected []bool, received %T", value) + } + + size := 20 + len(slice)*5 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(BoolOid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(1) + var b byte + if v { + b = 1 + } + w.WriteByte(b) + } + + return nil +} + func encodeInt2Array(w *WriteBuf, value interface{}) error { slice, ok := value.([]int16) if !ok { From 67292290cf55283ff8aaf84b5094a6aa7a399d40 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Sun, 21 Dec 2014 13:35:39 +0700 Subject: [PATCH 2/4] support for inserting []time.Time into timestamp[] columns --- conn.go | 4 +++- conn_test.go | 16 +++++++++++++ values.go | 68 ++++++++++++++++++++++++++++++++++++---------------- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/conn.go b/conn.go index 26d99169..0b06a4aa 100644 --- a/conn.go +++ b/conn.go @@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -609,6 +609,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTextArray(wbuf, arguments[i], TextOid) case VarcharArrayOid: err = encodeTextArray(wbuf, arguments[i], VarcharOid) + case TimestampArrayOid: + err = encodeTimestampArray(wbuf, arguments[i], VarcharOid) case OidOid: err = encodeOid(wbuf, arguments[i]) default: diff --git a/conn_test.go b/conn_test.go index 9aed24e1..244b46c6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -531,3 +531,19 @@ func TestInsertBoolArray(t *testing.T) { t.Errorf("Unexpected results from Exec: %v", results) } } + +func TestInsertTimestampArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} diff --git a/values.go b/values.go index 33515b41..d199b859 100644 --- a/values.go +++ b/values.go @@ -12,27 +12,28 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - Float4Oid = 700 - Float8Oid = 701 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampTzOid = 1184 + BoolOid = 16 + ByteaOid = 17 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + Float4Oid = 700 + Float8Oid = 701 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + TimestampTzOid = 1184 ) // PostgreSQL format codes @@ -1597,3 +1598,28 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } + +func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { + slice, ok := value.([]time.Time) + if !ok { + return fmt.Errorf("Expected []time.Time, received %T", value) + } + + size := 20 + len(slice)*12 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(TimestampOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, t := range slice { + w.WriteInt32(8) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + w.WriteInt64(microsecSinceY2K) + } + + return nil +} From be663f648c30e4d24551ea0c79b4c551b85d1f42 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Sun, 21 Dec 2014 13:40:45 +0700 Subject: [PATCH 3/4] refactor common code for encoding array header --- values.go | 82 +++++++++++-------------------------------------------- 1 file changed, 16 insertions(+), 66 deletions(-) diff --git a/values.go b/values.go index d199b859..38fc9b17 100644 --- a/values.go +++ b/values.go @@ -1240,15 +1240,7 @@ func encodeBoolArray(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []bool, received %T", value) } - size := 20 + len(slice)*5 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(BoolOid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, BoolOid, len(slice), 5) for _, v := range slice { w.WriteInt32(1) var b byte @@ -1267,15 +1259,7 @@ func encodeInt2Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []int16, received %T", value) } - size := 20 + len(slice)*6 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Int2Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Int2Oid, len(slice), 6) for _, v := range slice { w.WriteInt32(2) w.WriteInt16(v) @@ -1329,15 +1313,7 @@ func encodeInt4Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []int32, received %T", value) } - size := 20 + len(slice)*8 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Int4Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Int4Oid, len(slice), 8) for _, v := range slice { w.WriteInt32(4) w.WriteInt32(v) @@ -1391,15 +1367,7 @@ func encodeInt8Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []int64, received %T", value) } - size := 20 + len(slice)*12 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Int8Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Int8Oid, len(slice), 12) for _, v := range slice { w.WriteInt32(8) w.WriteInt64(v) @@ -1453,19 +1421,9 @@ func encodeFloat4Array(w *WriteBuf, value interface{}) error { if !ok { return fmt.Errorf("Expected []float32, received %T", value) } - - size := 20 + len(slice)*8 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Float4Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Float4Oid, len(slice), 8) for _, v := range slice { w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(v))) } @@ -1518,18 +1476,9 @@ func encodeFloat8Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []float64, received %T", value) } - size := 20 + len(slice)*12 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Float8Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Float8Oid, len(slice), 12) for _, v := range slice { w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(v))) } @@ -1605,15 +1554,7 @@ func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { return fmt.Errorf("Expected []time.Time, received %T", value) } - size := 20 + len(slice)*12 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(TimestampOid)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, TimestampOid, len(slice), 12) for _, t := range slice { w.WriteInt32(8) microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 @@ -1623,3 +1564,12 @@ func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } + +func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { + w.WriteInt32(int32(20 + length*sizePerItem)) + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(oid)) // type of elements + w.WriteInt32(int32(length)) // number of elements + w.WriteInt32(1) // index of first element +} From 109b55f9deda26e3fdff5783b7622814bb7f0442 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Sun, 21 Dec 2014 14:35:38 +0700 Subject: [PATCH 4/4] support decoding of []time.Time and []bool --- query.go | 6 +++ values.go | 122 +++++++++++++++++++++++++++++++++++++++++-------- values_test.go | 60 ++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 19 deletions(-) diff --git a/query.go b/query.go index 835829cc..890f9db1 100644 --- a/query.go +++ b/query.go @@ -243,6 +243,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat4(vr) case *float64: *d = decodeFloat8(vr) + case *[]bool: + *d = decodeBoolArray(vr) case *[]int16: *d = decodeInt2Array(vr) case *[]int32: @@ -255,6 +257,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat8Array(vr) case *[]string: *d = decodeTextArray(vr) + case *[]time.Time: + *d = decodeTimestampArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: @@ -324,6 +328,8 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8Oid: values = append(values, decodeFloat8(vr)) + case BoolArrayOid: + values = append(values, decodeBoolArray(vr)) case Int2ArrayOid: values = append(values, decodeInt2Array(vr)) case Int4ArrayOid: diff --git a/values.go b/values.go index 38fc9b17..79ea6c96 100644 --- a/values.go +++ b/values.go @@ -53,11 +53,13 @@ func init() { DefaultTypeFormats = make(map[string]int16) DefaultTypeFormats["_float4"] = BinaryFormatCode DefaultTypeFormats["_float8"] = BinaryFormatCode + DefaultTypeFormats["_bool"] = BinaryFormatCode DefaultTypeFormats["_int2"] = BinaryFormatCode DefaultTypeFormats["_int4"] = BinaryFormatCode DefaultTypeFormats["_int8"] = BinaryFormatCode DefaultTypeFormats["_text"] = BinaryFormatCode DefaultTypeFormats["_varchar"] = BinaryFormatCode + DefaultTypeFormats["_timestamp"] = BinaryFormatCode DefaultTypeFormats["bool"] = BinaryFormatCode DefaultTypeFormats["bytea"] = BinaryFormatCode DefaultTypeFormats["date"] = BinaryFormatCode @@ -1195,6 +1197,66 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } +func decodeBoolArray(vr *ValueReader) []bool { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != BoolArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]bool, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 1: + if vr.ReadByte() == 1 { + a[i] = true + } + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeBoolArray(w *WriteBuf, value interface{}) error { + slice, ok := value.([]bool) + if !ok { + return fmt.Errorf("Expected []bool, received %T", value) + } + + encodeArrayHeader(w, BoolOid, len(slice), 5) + for _, v := range slice { + w.WriteInt32(1) + var b byte + if v { + b = 1 + } + w.WriteByte(b) + } + + return nil +} + func decodeInt2Array(vr *ValueReader) []int16 { if vr.Len() == -1 { return nil @@ -1234,25 +1296,6 @@ func decodeInt2Array(vr *ValueReader) []int16 { return a } -func encodeBoolArray(w *WriteBuf, value interface{}) error { - slice, ok := value.([]bool) - if !ok { - return fmt.Errorf("Expected []bool, received %T", value) - } - - encodeArrayHeader(w, BoolOid, len(slice), 5) - for _, v := range slice { - w.WriteInt32(1) - var b byte - if v { - b = 1 - } - w.WriteByte(b) - } - - return nil -} - func encodeInt2Array(w *WriteBuf, value interface{}) error { slice, ok := value.([]int16) if !ok { @@ -1548,6 +1591,47 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } +func decodeTimestampArray(vr *ValueReader) []time.Time { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != TimestampArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]time.Time, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize))) + return nil + } + } + + return a +} + func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { slice, ok := value.([]time.Time) if !ok { diff --git a/values_test.go b/values_test.go index 9aba9fd4..1e91bd41 100644 --- a/values_test.go +++ b/values_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "fmt" "github.com/jackc/pgx" + "reflect" "strings" "testing" "time" @@ -159,6 +160,65 @@ func TestNullXMismatch(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::bool[]", []bool{true, false, true}, &[]bool{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]bool))) == false { + t.Errorf("failed to encode bool[]") + } + }, + }, + { + "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]int32))) == false { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]string))) == false { + t.Errorf("failed to encode text[]") + } + }, + }, + { + "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + t.Errorf("failed to encode time.Time[]") + } + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("ps%d", i) + mustPrepare(t, conn, psName, tt.sql) + + err := conn.QueryRow(psName, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } +} + +func TestArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + type allTypes struct { s pgx.NullString i16 pgx.NullInt16