+2
-4
@@ -1,15 +1,13 @@
|
|||||||
# Unreleased
|
# Unreleased
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
* Add PrepareEx
|
|
||||||
|
|
||||||
## Fixes
|
## Fixes
|
||||||
|
|
||||||
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
|
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
|
* Add PrepareEx
|
||||||
|
* Add basic record to []interface{} decoding
|
||||||
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
||||||
* Decode inet/cidr to net.IP
|
* Decode inet/cidr to net.IP
|
||||||
* Encode/decode [][]byte to/from bytea[]
|
* Encode/decode [][]byte to/from bytea[]
|
||||||
|
|||||||
@@ -332,7 +332,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) loadPgTypes() error {
|
func (c *Conn) loadPgTypes() error {
|
||||||
rows, err := c.Query("select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where t.typtype='b' and (base_type.oid is null or base_type.typtype='b');")
|
rows, err := c.Query(`select t.oid, t.typname
|
||||||
|
from pg_type t
|
||||||
|
left join pg_type base_type on t.typelem=base_type.oid
|
||||||
|
where (
|
||||||
|
t.typtype='b'
|
||||||
|
and (base_type.oid is null or base_type.typtype='b')
|
||||||
|
)
|
||||||
|
or t.typname in('record');`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -910,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
default:
|
default:
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid:
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
default:
|
default:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ const (
|
|||||||
CidrArrayOid = 651
|
CidrArrayOid = 651
|
||||||
Float4Oid = 700
|
Float4Oid = 700
|
||||||
Float8Oid = 701
|
Float8Oid = 701
|
||||||
|
UnknownOid = 705
|
||||||
InetOid = 869
|
InetOid = 869
|
||||||
BoolArrayOid = 1000
|
BoolArrayOid = 1000
|
||||||
Int2ArrayOid = 1005
|
Int2ArrayOid = 1005
|
||||||
@@ -44,6 +45,7 @@ const (
|
|||||||
TimestampArrayOid = 1115
|
TimestampArrayOid = 1115
|
||||||
TimestampTzOid = 1184
|
TimestampTzOid = 1184
|
||||||
TimestampTzArrayOid = 1185
|
TimestampTzArrayOid = 1185
|
||||||
|
RecordOid = 2249
|
||||||
UuidOid = 2950
|
UuidOid = 2950
|
||||||
JsonbOid = 3802
|
JsonbOid = 3802
|
||||||
)
|
)
|
||||||
@@ -91,8 +93,11 @@ func init() {
|
|||||||
"int4": BinaryFormatCode,
|
"int4": BinaryFormatCode,
|
||||||
"int8": BinaryFormatCode,
|
"int8": BinaryFormatCode,
|
||||||
"oid": BinaryFormatCode,
|
"oid": BinaryFormatCode,
|
||||||
|
"record": BinaryFormatCode,
|
||||||
|
"text": BinaryFormatCode,
|
||||||
"timestamp": BinaryFormatCode,
|
"timestamp": BinaryFormatCode,
|
||||||
"timestamptz": BinaryFormatCode,
|
"timestamptz": BinaryFormatCode,
|
||||||
|
"varchar": BinaryFormatCode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -807,6 +812,8 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||||||
*v = decodeTimestampArray(vr)
|
*v = decodeTimestampArray(vr)
|
||||||
case *[][]byte:
|
case *[][]byte:
|
||||||
*v = decodeByteaArray(vr)
|
*v = decodeByteaArray(vr)
|
||||||
|
case *[]interface{}:
|
||||||
|
*v = decodeRecord(vr)
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch vr.Type().DataType {
|
switch vr.Type().DataType {
|
||||||
case DateOid:
|
case DateOid:
|
||||||
@@ -1613,6 +1620,77 @@ func encodeIP(w *WriteBuf, oid Oid, value net.IP) error {
|
|||||||
return encodeIPNet(w, oid, ipnet)
|
return encodeIPNet(w, oid, ipnet)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeRecord(vr *ValueReader) []interface{} {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != RecordOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valueCount := vr.ReadInt32()
|
||||||
|
record := make([]interface{}, 0, int(valueCount))
|
||||||
|
|
||||||
|
for i := int32(0); i < valueCount; i++ {
|
||||||
|
fd := FieldDescription{FormatCode: BinaryFormatCode}
|
||||||
|
fieldVR := ValueReader{mr: vr.mr, fd: &fd}
|
||||||
|
fd.DataType = vr.ReadOid()
|
||||||
|
fieldVR.valueBytesRemaining = vr.ReadInt32()
|
||||||
|
vr.valueBytesRemaining -= fieldVR.valueBytesRemaining
|
||||||
|
|
||||||
|
switch fd.DataType {
|
||||||
|
case BoolOid:
|
||||||
|
record = append(record, decodeBool(&fieldVR))
|
||||||
|
case ByteaOid:
|
||||||
|
record = append(record, decodeBytea(&fieldVR))
|
||||||
|
case Int8Oid:
|
||||||
|
record = append(record, decodeInt8(&fieldVR))
|
||||||
|
case Int2Oid:
|
||||||
|
record = append(record, decodeInt2(&fieldVR))
|
||||||
|
case Int4Oid:
|
||||||
|
record = append(record, decodeInt4(&fieldVR))
|
||||||
|
case OidOid:
|
||||||
|
record = append(record, decodeOid(&fieldVR))
|
||||||
|
case Float4Oid:
|
||||||
|
record = append(record, decodeFloat4(&fieldVR))
|
||||||
|
case Float8Oid:
|
||||||
|
record = append(record, decodeFloat8(&fieldVR))
|
||||||
|
case DateOid:
|
||||||
|
record = append(record, decodeDate(&fieldVR))
|
||||||
|
case TimestampTzOid:
|
||||||
|
record = append(record, decodeTimestampTz(&fieldVR))
|
||||||
|
case TimestampOid:
|
||||||
|
record = append(record, decodeTimestamp(&fieldVR))
|
||||||
|
case InetOid, CidrOid:
|
||||||
|
record = append(record, decodeInet(&fieldVR))
|
||||||
|
case TextOid, VarcharOid, UnknownOid:
|
||||||
|
record = append(record, decodeText(&fieldVR))
|
||||||
|
default:
|
||||||
|
vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume any remaining data
|
||||||
|
if fieldVR.Len() > 0 {
|
||||||
|
fieldVR.ReadBytes(fieldVR.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
if fieldVR.Err() != nil {
|
||||||
|
vr.Fatal(fieldVR.Err())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return record
|
||||||
|
}
|
||||||
|
|
||||||
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
||||||
numDims := vr.ReadInt32()
|
numDims := vr.ReadInt32()
|
||||||
if numDims > 1 {
|
if numDims > 1 {
|
||||||
|
|||||||
@@ -959,3 +959,40 @@ func TestPointerPointerNonZero(t *testing.T) {
|
|||||||
t.Errorf("Expected dest to be nil, got %#v", dest)
|
t.Errorf("Expected dest to be nil, got %#v", dest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRowDecode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)",
|
||||||
|
[]interface{}{
|
||||||
|
int32(1),
|
||||||
|
"cat",
|
||||||
|
time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
var actual []interface{}
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user