diff --git a/conn.go b/conn.go index 26d99169..0b06a4aa 100644 --- a/conn.go +++ b/conn.go @@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -609,6 +609,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTextArray(wbuf, arguments[i], TextOid) case VarcharArrayOid: err = encodeTextArray(wbuf, arguments[i], VarcharOid) + case TimestampArrayOid: + err = encodeTimestampArray(wbuf, arguments[i], VarcharOid) case OidOid: err = encodeOid(wbuf, arguments[i]) default: diff --git a/conn_test.go b/conn_test.go index 9aed24e1..244b46c6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -531,3 +531,19 @@ func TestInsertBoolArray(t *testing.T) { t.Errorf("Unexpected results from Exec: %v", results) } } + +func TestInsertTimestampArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} diff --git a/values.go b/values.go index 33515b41..d199b859 100644 --- a/values.go +++ b/values.go @@ -12,27 +12,28 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - Float4Oid = 700 - Float8Oid = 701 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampTzOid = 1184 + BoolOid = 16 + ByteaOid = 17 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + Float4Oid = 700 + Float8Oid = 701 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + TimestampTzOid = 1184 ) // PostgreSQL format codes @@ -1597,3 +1598,28 @@ func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } + +func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { + slice, ok := value.([]time.Time) + if !ok { + return fmt.Errorf("Expected []time.Time, received %T", value) + } + + size := 20 + len(slice)*12 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(TimestampOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, t := range slice { + w.WriteInt32(8) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + w.WriteInt64(microsecSinceY2K) + } + + return nil +}