2
0

Add timestamptz[] support

This commit is contained in:
Jack Christensen
2014-12-23 18:15:12 -06:00
parent 191c37dfa6
commit d77e599ce6
4 changed files with 40 additions and 28 deletions
+4 -2
View File
@@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(TextFormatCode) wbuf.WriteInt16(TextFormatCode)
default: default:
switch oid { switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
wbuf.WriteInt16(BinaryFormatCode) wbuf.WriteInt16(BinaryFormatCode)
default: default:
wbuf.WriteInt16(TextFormatCode) wbuf.WriteInt16(TextFormatCode)
@@ -610,7 +610,9 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
case VarcharArrayOid: case VarcharArrayOid:
err = encodeTextArray(wbuf, arguments[i], VarcharOid) err = encodeTextArray(wbuf, arguments[i], VarcharOid)
case TimestampArrayOid: case TimestampArrayOid:
err = encodeTimestampArray(wbuf, arguments[i], VarcharOid) err = encodeTimestampArray(wbuf, arguments[i], TimestampOid)
case TimestampTzArrayOid:
err = encodeTimestampArray(wbuf, arguments[i], TimestampTzOid)
case OidOid: case OidOid:
err = encodeOid(wbuf, arguments[i]) err = encodeOid(wbuf, arguments[i])
default: default:
+1 -1
View File
@@ -342,7 +342,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
values = append(values, decodeFloat8Array(vr)) values = append(values, decodeFloat8Array(vr))
case TextArrayOid, VarcharArrayOid: case TextArrayOid, VarcharArrayOid:
values = append(values, decodeTextArray(vr)) values = append(values, decodeTextArray(vr))
case TimestampArrayOid: case TimestampArrayOid, TimestampTzArrayOid:
values = append(values, decodeTimestampArray(vr)) values = append(values, decodeTimestampArray(vr))
case DateOid: case DateOid:
values = append(values, decodeDate(vr)) values = append(values, decodeDate(vr))
+26 -24
View File
@@ -12,28 +12,29 @@ import (
// PostgreSQL oids for common types // PostgreSQL oids for common types
const ( const (
BoolOid = 16 BoolOid = 16
ByteaOid = 17 ByteaOid = 17
Int8Oid = 20 Int8Oid = 20
Int2Oid = 21 Int2Oid = 21
Int4Oid = 23 Int4Oid = 23
TextOid = 25 TextOid = 25
OidOid = 26 OidOid = 26
Float4Oid = 700 Float4Oid = 700
Float8Oid = 701 Float8Oid = 701
BoolArrayOid = 1000 BoolArrayOid = 1000
Int2ArrayOid = 1005 Int2ArrayOid = 1005
Int4ArrayOid = 1007 Int4ArrayOid = 1007
TextArrayOid = 1009 TextArrayOid = 1009
VarcharArrayOid = 1015 VarcharArrayOid = 1015
Int8ArrayOid = 1016 Int8ArrayOid = 1016
Float4ArrayOid = 1021 Float4ArrayOid = 1021
Float8ArrayOid = 1022 Float8ArrayOid = 1022
VarcharOid = 1043 VarcharOid = 1043
DateOid = 1082 DateOid = 1082
TimestampOid = 1114 TimestampOid = 1114
TimestampArrayOid = 1115 TimestampArrayOid = 1115
TimestampTzOid = 1184 TimestampTzOid = 1184
TimestampTzArrayOid = 1185
) )
// PostgreSQL format codes // PostgreSQL format codes
@@ -60,6 +61,7 @@ func init() {
DefaultTypeFormats["_text"] = BinaryFormatCode DefaultTypeFormats["_text"] = BinaryFormatCode
DefaultTypeFormats["_varchar"] = BinaryFormatCode DefaultTypeFormats["_varchar"] = BinaryFormatCode
DefaultTypeFormats["_timestamp"] = BinaryFormatCode DefaultTypeFormats["_timestamp"] = BinaryFormatCode
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
DefaultTypeFormats["bool"] = BinaryFormatCode DefaultTypeFormats["bool"] = BinaryFormatCode
DefaultTypeFormats["bytea"] = BinaryFormatCode DefaultTypeFormats["bytea"] = BinaryFormatCode
DefaultTypeFormats["date"] = BinaryFormatCode DefaultTypeFormats["date"] = BinaryFormatCode
@@ -1596,7 +1598,7 @@ func decodeTimestampArray(vr *ValueReader) []time.Time {
return nil return nil
} }
if vr.Type().DataType != TimestampArrayOid { if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType)))
return nil return nil
} }
@@ -1638,7 +1640,7 @@ func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error {
return fmt.Errorf("Expected []time.Time, received %T", value) return fmt.Errorf("Expected []time.Time, received %T", value)
} }
encodeArrayHeader(w, TimestampOid, len(slice), 12) encodeArrayHeader(w, int(elOid), len(slice), 12)
for _, t := range slice { for _, t := range slice {
w.WriteInt32(8) w.WriteInt32(8)
microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000
+9 -1
View File
@@ -194,7 +194,15 @@ func TestArrayDecoding(t *testing.T) {
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
func(t *testing.T, query, scan interface{}) { func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
t.Errorf("failed to encode time.Time[]") t.Errorf("failed to encode time.Time[] to timestamp[]")
}
},
},
{
"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
func(t *testing.T, query, scan interface{}) {
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
t.Errorf("failed to encode time.Time[] to timestamptz[]")
} }
}, },
}, },