Add support for integer, float and text arrays
Restructure internals a bit so pgx/stdlib can turn off binary encoding and receive text back for array types.
This commit is contained in:
@@ -70,9 +70,6 @@ if err != nil {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Prepared statements will use the binary transmission when possible. This can
|
|
||||||
substantially increase performance.
|
|
||||||
|
|
||||||
### Explicit Connection Pool
|
### Explicit Connection Pool
|
||||||
|
|
||||||
Connection pool usage is explicit and configurable. In pgx, a connection can
|
Connection pool usage is explicit and configurable. In pgx, a connection can
|
||||||
@@ -151,9 +148,17 @@ point type.
|
|||||||
pgx includes Null* types in a similar fashion to database/sql that implement the
|
pgx includes Null* types in a similar fashion to database/sql that implement the
|
||||||
necessary interfaces to be encoded and scanned.
|
necessary interfaces to be encoded and scanned.
|
||||||
|
|
||||||
|
### Array Mapping
|
||||||
|
|
||||||
|
pgx maps between int16, int32, int64, float32, float64, and string Go slices
|
||||||
|
and the equivalent PostgreSQL array type. Go slices of native types do not
|
||||||
|
support nulls, so if a PostgreSQL array that contains a slice is read into a
|
||||||
|
native Go slice an error will occur.
|
||||||
|
|
||||||
### Logging
|
### Logging
|
||||||
|
|
||||||
pgx connections optionally accept a logger from the [log15 package](http://gopkg.in/inconshreveable/log15.v2).
|
pgx connections optionally accept a logger from the [log15
|
||||||
|
package](http://gopkg.in/inconshreveable/log15.v2).
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
|
|||||||
@@ -294,10 +294,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||||||
case rowDescription:
|
case rowDescription:
|
||||||
ps.FieldDescriptions = c.rxRowDescription(r)
|
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
switch ps.FieldDescriptions[i].DataType {
|
ps.FieldDescriptions[i].FormatCode, _ = DefaultOidFormats[ps.FieldDescriptions[i].DataType]
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, DateOid, TimestampTzOid:
|
|
||||||
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case noData:
|
case noData:
|
||||||
case readyForQuery:
|
case readyForQuery:
|
||||||
@@ -474,7 +471,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
default:
|
default:
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid:
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
default:
|
default:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
@@ -518,6 +515,18 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
err = encodeTimestampTz(wbuf, arguments[i])
|
err = encodeTimestampTz(wbuf, arguments[i])
|
||||||
case TimestampOid:
|
case TimestampOid:
|
||||||
err = encodeTimestamp(wbuf, arguments[i])
|
err = encodeTimestamp(wbuf, arguments[i])
|
||||||
|
case Int2ArrayOid:
|
||||||
|
err = encodeInt2Array(wbuf, arguments[i])
|
||||||
|
case Int4ArrayOid:
|
||||||
|
err = encodeInt4Array(wbuf, arguments[i])
|
||||||
|
case Int8ArrayOid:
|
||||||
|
err = encodeInt8Array(wbuf, arguments[i])
|
||||||
|
case Float4ArrayOid:
|
||||||
|
err = encodeFloat4Array(wbuf, arguments[i])
|
||||||
|
case Float8ArrayOid:
|
||||||
|
err = encodeFloat8Array(wbuf, arguments[i])
|
||||||
|
case TextArrayOid:
|
||||||
|
err = encodeTextArray(wbuf, arguments[i])
|
||||||
default:
|
default:
|
||||||
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement Encoder", arg))
|
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement Encoder", arg))
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-2
@@ -20,10 +20,13 @@ func closeConn(t testing.TB, conn *pgx.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) {
|
func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) *pgx.PreparedStatement {
|
||||||
if _, err := conn.Prepare(name, sql); err != nil {
|
ps, err := conn.Prepare(name, sql)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Could not prepare %v: %v", name, err)
|
t.Fatalf("Could not prepare %v: %v", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ps
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
|
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
|
||||||
|
|||||||
@@ -214,6 +214,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||||||
*d = decodeFloat4(vr)
|
*d = decodeFloat4(vr)
|
||||||
case *float64:
|
case *float64:
|
||||||
*d = decodeFloat8(vr)
|
*d = decodeFloat8(vr)
|
||||||
|
case *[]int16:
|
||||||
|
*d = decodeInt2Array(vr)
|
||||||
|
case *[]int32:
|
||||||
|
*d = decodeInt4Array(vr)
|
||||||
|
case *[]int64:
|
||||||
|
*d = decodeInt8Array(vr)
|
||||||
|
case *[]float32:
|
||||||
|
*d = decodeFloat4Array(vr)
|
||||||
|
case *[]float64:
|
||||||
|
*d = decodeFloat8Array(vr)
|
||||||
|
case *[]string:
|
||||||
|
*d = decodeTextArray(vr)
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch vr.Type().DataType {
|
switch vr.Type().DataType {
|
||||||
case DateOid:
|
case DateOid:
|
||||||
@@ -263,39 +275,50 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch vr.Type().DataType {
|
switch vr.Type().FormatCode {
|
||||||
case BoolOid:
|
// All intrinsic types (except string) are encoded with binary
|
||||||
values = append(values, decodeBool(vr))
|
// encoding so anything else should be treated as a string
|
||||||
case ByteaOid:
|
case TextFormatCode:
|
||||||
values = append(values, decodeBytea(vr))
|
values = append(values, vr.ReadString(vr.Len()))
|
||||||
case Int8Oid:
|
case BinaryFormatCode:
|
||||||
values = append(values, decodeInt8(vr))
|
switch vr.Type().DataType {
|
||||||
case Int2Oid:
|
case BoolOid:
|
||||||
values = append(values, decodeInt2(vr))
|
values = append(values, decodeBool(vr))
|
||||||
case Int4Oid:
|
case ByteaOid:
|
||||||
values = append(values, decodeInt4(vr))
|
values = append(values, decodeBytea(vr))
|
||||||
case VarcharOid, TextOid:
|
case Int8Oid:
|
||||||
values = append(values, decodeText(vr))
|
values = append(values, decodeInt8(vr))
|
||||||
case Float4Oid:
|
case Int2Oid:
|
||||||
values = append(values, decodeFloat4(vr))
|
values = append(values, decodeInt2(vr))
|
||||||
case Float8Oid:
|
case Int4Oid:
|
||||||
values = append(values, decodeFloat8(vr))
|
values = append(values, decodeInt4(vr))
|
||||||
case DateOid:
|
case Float4Oid:
|
||||||
values = append(values, decodeDate(vr))
|
values = append(values, decodeFloat4(vr))
|
||||||
case TimestampTzOid:
|
case Float8Oid:
|
||||||
values = append(values, decodeTimestampTz(vr))
|
values = append(values, decodeFloat8(vr))
|
||||||
case TimestampOid:
|
case Int2ArrayOid:
|
||||||
values = append(values, decodeTimestamp(vr))
|
values = append(values, decodeInt2Array(vr))
|
||||||
default:
|
case Int4ArrayOid:
|
||||||
// if it is not an intrinsic type then return the text
|
values = append(values, decodeInt4Array(vr))
|
||||||
switch vr.Type().FormatCode {
|
case Int8ArrayOid:
|
||||||
case TextFormatCode:
|
values = append(values, decodeInt8Array(vr))
|
||||||
values = append(values, vr.ReadString(vr.Len()))
|
case Float4ArrayOid:
|
||||||
case BinaryFormatCode:
|
values = append(values, decodeFloat4Array(vr))
|
||||||
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
case Float8ArrayOid:
|
||||||
|
values = append(values, decodeFloat8Array(vr))
|
||||||
|
case TextArrayOid:
|
||||||
|
values = append(values, decodeTextArray(vr))
|
||||||
|
case DateOid:
|
||||||
|
values = append(values, decodeDate(vr))
|
||||||
|
case TimestampTzOid:
|
||||||
|
values = append(values, decodeTimestampTz(vr))
|
||||||
|
case TimestampOid:
|
||||||
|
values = append(values, decodeTimestamp(vr))
|
||||||
default:
|
default:
|
||||||
rows.Fatal(errors.New("Unknown format code"))
|
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
rows.Fatal(errors.New("Unknown format code"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if vr.Err() != nil {
|
if vr.Err() != nil {
|
||||||
|
|||||||
+284
@@ -376,6 +376,8 @@ func TestQueryRowCoreTypes(t *testing.T) {
|
|||||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql)
|
t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -486,3 +488,285 @@ func TestQueryRowNoResults(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreInt16Slice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []int16
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []int16
|
||||||
|
}{
|
||||||
|
{"select $1::int2[]", []int16{1, 2, 3, 4, 5}},
|
||||||
|
{"select $1::int2[]", []int16{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actual) != len(tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < len(actual); j++ {
|
||||||
|
if actual[j] != tt.expected[j] {
|
||||||
|
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||||
|
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreInt32Slice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []int32
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []int32
|
||||||
|
}{
|
||||||
|
{"select $1::int4[]", []int32{1, 2, 3, 4, 5}},
|
||||||
|
{"select $1::int4[]", []int32{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actual) != len(tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < len(actual); j++ {
|
||||||
|
if actual[j] != tt.expected[j] {
|
||||||
|
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||||
|
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreInt64Slice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []int64
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []int64
|
||||||
|
}{
|
||||||
|
{"select $1::int8[]", []int64{1, 2, 3, 4, 5}},
|
||||||
|
{"select $1::int8[]", []int64{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actual) != len(tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < len(actual); j++ {
|
||||||
|
if actual[j] != tt.expected[j] {
|
||||||
|
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||||
|
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreFloat32Slice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []float32
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []float32
|
||||||
|
}{
|
||||||
|
{"select $1::float4[]", []float32{1.5, 2.0, 3.5}},
|
||||||
|
{"select $1::float4[]", []float32{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actual) != len(tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < len(actual); j++ {
|
||||||
|
if actual[j] != tt.expected[j] {
|
||||||
|
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||||
|
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreFloat64Slice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []float64
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []float64
|
||||||
|
}{
|
||||||
|
{"select $1::float8[]", []float64{1.5, 2.0, 3.5}},
|
||||||
|
{"select $1::float8[]", []float64{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actual) != len(tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < len(actual); j++ {
|
||||||
|
if actual[j] != tt.expected[j] {
|
||||||
|
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||||
|
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreStringSlice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []string
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
|
||||||
|
{"select $1::text[]", []string{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actual) != len(tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < len(actual); j++ {
|
||||||
|
if actual[j] != tt.expected[j] {
|
||||||
|
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||||
|
err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|||||||
+46
-2
@@ -54,9 +54,24 @@ import (
|
|||||||
|
|
||||||
var openFromConnPoolCount int
|
var openFromConnPoolCount int
|
||||||
|
|
||||||
|
// oids that map to intrinsic database/sql types. These will be allowed to be
|
||||||
|
// binary, anything else will be forced to text format
|
||||||
|
var databaseSqlOids map[pgx.Oid]bool
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
d := &Driver{}
|
d := &Driver{}
|
||||||
sql.Register("pgx", d)
|
sql.Register("pgx", d)
|
||||||
|
|
||||||
|
databaseSqlOids = make(map[pgx.Oid]bool)
|
||||||
|
databaseSqlOids[pgx.BoolOid] = true
|
||||||
|
databaseSqlOids[pgx.ByteaOid] = true
|
||||||
|
databaseSqlOids[pgx.Int2Oid] = true
|
||||||
|
databaseSqlOids[pgx.Int4Oid] = true
|
||||||
|
databaseSqlOids[pgx.Int8Oid] = true
|
||||||
|
databaseSqlOids[pgx.Float4Oid] = true
|
||||||
|
databaseSqlOids[pgx.Float8Oid] = true
|
||||||
|
databaseSqlOids[pgx.DateOid] = true
|
||||||
|
databaseSqlOids[pgx.TimestampTzOid] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
type Driver struct {
|
type Driver struct {
|
||||||
@@ -136,6 +151,8 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
restrictBinaryToDatabaseSqlTypes(ps)
|
||||||
|
|
||||||
return &Stmt{ps: ps, conn: c}, nil
|
return &Stmt{ps: ps, conn: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,9 +193,24 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
|||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ps, err := c.conn.Prepare("", query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
restrictBinaryToDatabaseSqlTypes(ps)
|
||||||
|
|
||||||
|
return c.queryPrepared("", argsV)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||||
|
if !c.conn.IsAlive() {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
args := valueToInterface(argsV)
|
args := valueToInterface(argsV)
|
||||||
|
|
||||||
rows, err := c.conn.Query(query, args...)
|
rows, err := c.conn.Query(name, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -186,6 +218,18 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
|||||||
return &Rows{rows: rows}, nil
|
return &Rows{rows: rows}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Anything that isn't a database/sql compatible type needs to be forced to
|
||||||
|
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
||||||
|
// (e.g. []int32)
|
||||||
|
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
|
||||||
|
for i, _ := range ps.FieldDescriptions {
|
||||||
|
intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType]
|
||||||
|
if !intrinsic {
|
||||||
|
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Stmt struct {
|
type Stmt struct {
|
||||||
ps *pgx.PreparedStatement
|
ps *pgx.PreparedStatement
|
||||||
conn *Conn
|
conn *Conn
|
||||||
@@ -204,7 +248,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||||
return s.conn.Query(s.ps.Name, argsV)
|
return s.conn.queryPrepared(s.ps.Name, argsV)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - rename to avoid alloc
|
// TODO - rename to avoid alloc
|
||||||
|
|||||||
@@ -336,6 +336,28 @@ func TestConnQueryFailure(t *testing.T) {
|
|||||||
ensureConnValid(t, db)
|
ensureConnValid(t, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test type that pgx would handle natively in binary, but since it is not a
|
||||||
|
// database/sql native type should be passed through as a string
|
||||||
|
func TestConnQueryRowPgxBinary(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
sql := "select $1::int4[]"
|
||||||
|
expected := "{1,2,3}"
|
||||||
|
var actual string
|
||||||
|
|
||||||
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, db)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnQueryRowUnknownType(t *testing.T) {
|
func TestConnQueryRowUnknownType(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|||||||
@@ -19,6 +19,12 @@ const (
|
|||||||
TextOid = 25
|
TextOid = 25
|
||||||
Float4Oid = 700
|
Float4Oid = 700
|
||||||
Float8Oid = 701
|
Float8Oid = 701
|
||||||
|
Int2ArrayOid = 1005
|
||||||
|
Int4ArrayOid = 1007
|
||||||
|
TextArrayOid = 1009
|
||||||
|
Int8ArrayOid = 1016
|
||||||
|
Float4ArrayOid = 1021
|
||||||
|
Float8ArrayOid = 1022
|
||||||
VarcharOid = 1043
|
VarcharOid = 1043
|
||||||
DateOid = 1082
|
DateOid = 1082
|
||||||
TimestampOid = 1114
|
TimestampOid = 1114
|
||||||
@@ -31,6 +37,27 @@ const (
|
|||||||
BinaryFormatCode = 1
|
BinaryFormatCode = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var DefaultOidFormats map[Oid]int16
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
DefaultOidFormats = make(map[Oid]int16)
|
||||||
|
DefaultOidFormats[BoolOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[ByteaOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Int2Oid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Int4Oid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Int8Oid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Float4Oid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Float8Oid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[DateOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[TimestampTzOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Int2ArrayOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Int4ArrayOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Int8ArrayOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Float4ArrayOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode
|
||||||
|
DefaultOidFormats[TextArrayOid] = BinaryFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
type SerializationError string
|
type SerializationError string
|
||||||
|
|
||||||
func (e SerializationError) Error() string {
|
func (e SerializationError) Error() string {
|
||||||
@@ -945,3 +972,406 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
||||||
|
numDims := vr.ReadInt32()
|
||||||
|
if numDims == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if numDims != 1 {
|
||||||
|
return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims))
|
||||||
|
}
|
||||||
|
|
||||||
|
vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care
|
||||||
|
vr.ReadInt32() // element oid
|
||||||
|
|
||||||
|
length = vr.ReadInt32()
|
||||||
|
|
||||||
|
idxFirstElem := vr.ReadInt32()
|
||||||
|
if idxFirstElem != 1 {
|
||||||
|
return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem))
|
||||||
|
}
|
||||||
|
|
||||||
|
return length, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInt2Array(vr *ValueReader) []int16 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != Int2ArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2ArrayOid, vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]int16, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
switch elSize {
|
||||||
|
case 2:
|
||||||
|
a[i] = vr.ReadInt16()
|
||||||
|
case -1:
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt2Array(w *WriteBuf, value interface{}) error {
|
||||||
|
slice, ok := value.([]int16)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []int16, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := 20 + len(slice)*6
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(Int2Oid) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, v := range slice {
|
||||||
|
w.WriteInt32(2)
|
||||||
|
w.WriteInt16(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInt4Array(vr *ValueReader) []int32 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != Int4ArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4ArrayOid, vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]int32, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
switch elSize {
|
||||||
|
case 4:
|
||||||
|
a[i] = vr.ReadInt32()
|
||||||
|
case -1:
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt4Array(w *WriteBuf, value interface{}) error {
|
||||||
|
slice, ok := value.([]int32)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []int32, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := 20 + len(slice)*8
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(Int4Oid) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, v := range slice {
|
||||||
|
w.WriteInt32(4)
|
||||||
|
w.WriteInt32(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInt8Array(vr *ValueReader) []int64 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != Int8ArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8ArrayOid, vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]int64, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
switch elSize {
|
||||||
|
case 8:
|
||||||
|
a[i] = vr.ReadInt64()
|
||||||
|
case -1:
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt8Array(w *WriteBuf, value interface{}) error {
|
||||||
|
slice, ok := value.([]int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []int64, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := 20 + len(slice)*12
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(Int8Oid) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, v := range slice {
|
||||||
|
w.WriteInt32(8)
|
||||||
|
w.WriteInt64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeFloat4Array(vr *ValueReader) []float32 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != Float4ArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float4ArrayOid, vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]float32, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
switch elSize {
|
||||||
|
case 4:
|
||||||
|
n := vr.ReadInt32()
|
||||||
|
p := unsafe.Pointer(&n)
|
||||||
|
a[i] = *(*float32)(p)
|
||||||
|
case -1:
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeFloat4Array(w *WriteBuf, value interface{}) error {
|
||||||
|
slice, ok := value.([]float32)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []float32, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := 20 + len(slice)*8
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(Float4Oid) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, v := range slice {
|
||||||
|
w.WriteInt32(4)
|
||||||
|
|
||||||
|
p := unsafe.Pointer(&v)
|
||||||
|
w.WriteInt32(*(*int32)(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeFloat8Array(vr *ValueReader) []float64 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != Float8ArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float8ArrayOid, vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]float64, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
switch elSize {
|
||||||
|
case 8:
|
||||||
|
n := vr.ReadInt64()
|
||||||
|
p := unsafe.Pointer(&n)
|
||||||
|
a[i] = *(*float64)(p)
|
||||||
|
case -1:
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeFloat8Array(w *WriteBuf, value interface{}) error {
|
||||||
|
slice, ok := value.([]float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []float64, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := 20 + len(slice)*12
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(Float8Oid) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, v := range slice {
|
||||||
|
w.WriteInt32(8)
|
||||||
|
|
||||||
|
p := unsafe.Pointer(&v)
|
||||||
|
w.WriteInt64(*(*int64)(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeTextArray(vr *ValueReader) []string {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != TextArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TextArrayOid, vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]string, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
if elSize == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a[i] = vr.ReadString(elSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeTextArray(w *WriteBuf, value interface{}) error {
|
||||||
|
slice, ok := value.([]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []string, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalStringSize int
|
||||||
|
for _, v := range slice {
|
||||||
|
totalStringSize += len(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := 20 + len(slice)*4 + totalStringSize
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(TextOid) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, v := range slice {
|
||||||
|
w.WriteInt32(int32(len(v)))
|
||||||
|
w.WriteBytes([]byte(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user