diff --git a/conn.go b/conn.go index 2c38fd2c..0b06a4aa 100644 --- a/conn.go +++ b/conn.go @@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -593,6 +593,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampTz(wbuf, arguments[i]) case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) + case BoolArrayOid: + err = encodeBoolArray(wbuf, arguments[i]) case Int2ArrayOid: err = encodeInt2Array(wbuf, arguments[i]) case Int4ArrayOid: @@ -607,6 +609,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTextArray(wbuf, arguments[i], TextOid) case VarcharArrayOid: err = encodeTextArray(wbuf, arguments[i], VarcharOid) + case TimestampArrayOid: + err = encodeTimestampArray(wbuf, arguments[i], VarcharOid) case OidOid: err = encodeOid(wbuf, arguments[i]) default: diff --git a/conn_test.go b/conn_test.go index c6264f77..244b46c6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -515,3 +515,35 @@ func TestCommandTag(t *testing.T) { } } } + +func TestInsertBoolArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} + +func TestInsertTimestampArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} diff --git a/query.go b/query.go index 835829cc..890f9db1 100644 --- a/query.go +++ b/query.go @@ -243,6 +243,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat4(vr) case *float64: *d = decodeFloat8(vr) + case *[]bool: + *d = decodeBoolArray(vr) case *[]int16: *d = decodeInt2Array(vr) case *[]int32: @@ -255,6 +257,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat8Array(vr) case *[]string: *d = decodeTextArray(vr) + case *[]time.Time: + *d = decodeTimestampArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: @@ -324,6 +328,8 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8Oid: values = append(values, decodeFloat8(vr)) + case BoolArrayOid: + values = append(values, decodeBoolArray(vr)) case Int2ArrayOid: values = append(values, decodeInt2Array(vr)) case Int4ArrayOid: diff --git a/values.go b/values.go index 00091513..79ea6c96 100644 --- a/values.go +++ b/values.go @@ -12,26 +12,28 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - Float4Oid = 700 - Float8Oid = 701 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampTzOid = 1184 + BoolOid = 16 + ByteaOid = 17 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + Float4Oid = 700 + Float8Oid = 701 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + TimestampTzOid = 1184 ) // PostgreSQL format codes @@ -51,11 +53,13 @@ func init() { DefaultTypeFormats = make(map[string]int16) DefaultTypeFormats["_float4"] = BinaryFormatCode DefaultTypeFormats["_float8"] = BinaryFormatCode + DefaultTypeFormats["_bool"] = BinaryFormatCode DefaultTypeFormats["_int2"] = BinaryFormatCode DefaultTypeFormats["_int4"] = BinaryFormatCode DefaultTypeFormats["_int8"] = BinaryFormatCode DefaultTypeFormats["_text"] = BinaryFormatCode DefaultTypeFormats["_varchar"] = BinaryFormatCode + DefaultTypeFormats["_timestamp"] = BinaryFormatCode DefaultTypeFormats["bool"] = BinaryFormatCode DefaultTypeFormats["bytea"] = BinaryFormatCode DefaultTypeFormats["date"] = BinaryFormatCode @@ -1193,6 +1197,66 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } +func decodeBoolArray(vr *ValueReader) []bool { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != BoolArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]bool, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 1: + if vr.ReadByte() == 1 { + a[i] = true + } + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeBoolArray(w *WriteBuf, value interface{}) error { + slice, ok := value.([]bool) + if !ok { + return fmt.Errorf("Expected []bool, received %T", value) + } + + encodeArrayHeader(w, BoolOid, len(slice), 5) + for _, v := range slice { + w.WriteInt32(1) + var b byte + if v { + b = 1 + } + w.WriteByte(b) + } + + return nil +} + func decodeInt2Array(vr *ValueReader) []int16 { if vr.Len() == -1 { return nil @@ -1238,15 +1302,7 @@ func encodeInt2Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []int16, received %T", value) } - size := 20 + len(slice)*6 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Int2Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Int2Oid, len(slice), 6) for _, v := range slice { w.WriteInt32(2) w.WriteInt16(v) @@ -1300,15 +1356,7 @@ func encodeInt4Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []int32, received %T", value) } - size := 20 + len(slice)*8 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Int4Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Int4Oid, len(slice), 8) for _, v := range slice { w.WriteInt32(4) w.WriteInt32(v) @@ -1362,15 +1410,7 @@ func encodeInt8Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []int64, received %T", value) } - size := 20 + len(slice)*12 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Int8Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Int8Oid, len(slice), 12) for _, v := range slice { w.WriteInt32(8) w.WriteInt64(v) @@ -1424,19 +1464,9 @@ func encodeFloat4Array(w *WriteBuf, value interface{}) error { if !ok { return fmt.Errorf("Expected []float32, received %T", value) } - - size := 20 + len(slice)*8 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Float4Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Float4Oid, len(slice), 8) for _, v := range slice { w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(v))) } @@ -1489,18 +1519,9 @@ func encodeFloat8Array(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected []float64, received %T", value) } - size := 20 + len(slice)*12 - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(Float8Oid) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - + encodeArrayHeader(w, Float8Oid, len(slice), 12) for _, v := range slice { w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(v))) } @@ -1569,3 +1590,70 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } + +func decodeTimestampArray(vr *ValueReader) []time.Time { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != TimestampArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]time.Time, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { + slice, ok := value.([]time.Time) + if !ok { + return fmt.Errorf("Expected []time.Time, received %T", value) + } + + encodeArrayHeader(w, TimestampOid, len(slice), 12) + for _, t := range slice { + w.WriteInt32(8) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + w.WriteInt64(microsecSinceY2K) + } + + return nil +} + +func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { + w.WriteInt32(int32(20 + length*sizePerItem)) + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(oid)) // type of elements + w.WriteInt32(int32(length)) // number of elements + w.WriteInt32(1) // index of first element +} diff --git a/values_test.go b/values_test.go index 9aba9fd4..1e91bd41 100644 --- a/values_test.go +++ b/values_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "fmt" "github.com/jackc/pgx" + "reflect" "strings" "testing" "time" @@ -159,6 +160,65 @@ func TestNullXMismatch(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::bool[]", []bool{true, false, true}, &[]bool{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]bool))) == false { + t.Errorf("failed to encode bool[]") + } + }, + }, + { + "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]int32))) == false { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]string))) == false { + t.Errorf("failed to encode text[]") + } + }, + }, + { + "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + t.Errorf("failed to encode time.Time[]") + } + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("ps%d", i) + mustPrepare(t, conn, psName, tt.sql) + + err := conn.QueryRow(psName, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } +} + +func TestArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + type allTypes struct { s pgx.NullString i16 pgx.NullInt16