From e29574d4470eaab7297046133b11f77dc9dda707 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jul 2014 15:03:52 -0500 Subject: [PATCH] Add support for integer, float and text arrays Restructure internals a bit so pgx/stdlib can turn off binary encoding and receive text back for array types. --- README.md | 13 +- conn.go | 19 +- helper_test.go | 7 +- query.go | 85 +++++---- query_test.go | 284 ++++++++++++++++++++++++++++++ stdlib/sql.go | 48 ++++- stdlib/sql_test.go | 22 +++ values.go | 430 +++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 864 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 638e710b..3f79879c 100644 --- a/README.md +++ b/README.md @@ -70,9 +70,6 @@ if err != nil { } ``` -Prepared statements will use the binary transmission when possible. This can -substantially increase performance. - ### Explicit Connection Pool Connection pool usage is explicit and configurable. In pgx, a connection can @@ -151,9 +148,17 @@ point type. pgx includes Null* types in a similar fashion to database/sql that implement the necessary interfaces to be encoded and scanned. +### Array Mapping + +pgx maps between int16, int32, int64, float32, float64, and string Go slices +and the equivalent PostgreSQL array type. Go slices of native types do not +support nulls, so if a PostgreSQL array that contains a slice is read into a +native Go slice an error will occur. + ### Logging -pgx connections optionally accept a logger from the [log15 package](http://gopkg.in/inconshreveable/log15.v2). +pgx connections optionally accept a logger from the [log15 +package](http://gopkg.in/inconshreveable/log15.v2). ## Testing diff --git a/conn.go b/conn.go index 03a2221d..10c17345 100644 --- a/conn.go +++ b/conn.go @@ -294,10 +294,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { - switch ps.FieldDescriptions[i].DataType { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, DateOid, TimestampTzOid: - ps.FieldDescriptions[i].FormatCode = BinaryFormatCode - } + ps.FieldDescriptions[i].FormatCode, _ = DefaultOidFormats[ps.FieldDescriptions[i].DataType] } case noData: case readyForQuery: @@ -474,7 +471,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -518,6 +515,18 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampTz(wbuf, arguments[i]) case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) + case Int2ArrayOid: + err = encodeInt2Array(wbuf, arguments[i]) + case Int4ArrayOid: + err = encodeInt4Array(wbuf, arguments[i]) + case Int8ArrayOid: + err = encodeInt8Array(wbuf, arguments[i]) + case Float4ArrayOid: + err = encodeFloat4Array(wbuf, arguments[i]) + case Float8ArrayOid: + err = encodeFloat8Array(wbuf, arguments[i]) + case TextArrayOid: + err = encodeTextArray(wbuf, arguments[i]) default: return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement Encoder", arg)) } diff --git a/helper_test.go b/helper_test.go index 039b5811..570e3b8e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -20,10 +20,13 @@ func closeConn(t testing.TB, conn *pgx.Conn) { } } -func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) { - if _, err := conn.Prepare(name, sql); err != nil { +func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) *pgx.PreparedStatement { + ps, err := conn.Prepare(name, sql) + if err != nil { t.Fatalf("Could not prepare %v: %v", name, err) } + + return ps } func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) { diff --git a/query.go b/query.go index 96e6265a..033f0a2e 100644 --- a/query.go +++ b/query.go @@ -214,6 +214,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeFloat4(vr) case *float64: *d = decodeFloat8(vr) + case *[]int16: + *d = decodeInt2Array(vr) + case *[]int32: + *d = decodeInt4Array(vr) + case *[]int64: + *d = decodeInt8Array(vr) + case *[]float32: + *d = decodeFloat4Array(vr) + case *[]float64: + *d = decodeFloat8Array(vr) + case *[]string: + *d = decodeTextArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: @@ -263,39 +275,50 @@ func (rows *Rows) Values() ([]interface{}, error) { continue } - switch vr.Type().DataType { - case BoolOid: - values = append(values, decodeBool(vr)) - case ByteaOid: - values = append(values, decodeBytea(vr)) - case Int8Oid: - values = append(values, decodeInt8(vr)) - case Int2Oid: - values = append(values, decodeInt2(vr)) - case Int4Oid: - values = append(values, decodeInt4(vr)) - case VarcharOid, TextOid: - values = append(values, decodeText(vr)) - case Float4Oid: - values = append(values, decodeFloat4(vr)) - case Float8Oid: - values = append(values, decodeFloat8(vr)) - case DateOid: - values = append(values, decodeDate(vr)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(vr)) - case TimestampOid: - values = append(values, decodeTimestamp(vr)) - default: - // if it is not an intrinsic type then return the text - switch vr.Type().FormatCode { - case TextFormatCode: - values = append(values, vr.ReadString(vr.Len())) - case BinaryFormatCode: - rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + switch vr.Type().FormatCode { + // All intrinsic types (except string) are encoded with binary + // encoding so anything else should be treated as a string + case TextFormatCode: + values = append(values, vr.ReadString(vr.Len())) + case BinaryFormatCode: + switch vr.Type().DataType { + case BoolOid: + values = append(values, decodeBool(vr)) + case ByteaOid: + values = append(values, decodeBytea(vr)) + case Int8Oid: + values = append(values, decodeInt8(vr)) + case Int2Oid: + values = append(values, decodeInt2(vr)) + case Int4Oid: + values = append(values, decodeInt4(vr)) + case Float4Oid: + values = append(values, decodeFloat4(vr)) + case Float8Oid: + values = append(values, decodeFloat8(vr)) + case Int2ArrayOid: + values = append(values, decodeInt2Array(vr)) + case Int4ArrayOid: + values = append(values, decodeInt4Array(vr)) + case Int8ArrayOid: + values = append(values, decodeInt8Array(vr)) + case Float4ArrayOid: + values = append(values, decodeFloat4Array(vr)) + case Float8ArrayOid: + values = append(values, decodeFloat8Array(vr)) + case TextArrayOid: + values = append(values, decodeTextArray(vr)) + case DateOid: + values = append(values, decodeDate(vr)) + case TimestampTzOid: + values = append(values, decodeTimestampTz(vr)) + case TimestampOid: + values = append(values, decodeTimestamp(vr)) default: - rows.Fatal(errors.New("Unknown format code")) + rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) } + default: + rows.Fatal(errors.New("Unknown format code")) } if vr.Err() != nil { diff --git a/query_test.go b/query_test.go index e5a8fab4..44955ccd 100644 --- a/query_test.go +++ b/query_test.go @@ -376,6 +376,8 @@ func TestQueryRowCoreTypes(t *testing.T) { if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) } + + ensureConnValid(t, conn) } } @@ -486,3 +488,285 @@ func TestQueryRowNoResults(t *testing.T) { ensureConnValid(t, conn) } } + +func TestQueryRowCoreInt16Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []int16 + + tests := []struct { + sql string + expected []int16 + }{ + {"select $1::int2[]", []int16{1, 2, 3, 4, 5}}, + {"select $1::int2[]", []int16{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreInt32Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []int32 + + tests := []struct { + sql string + expected []int32 + }{ + {"select $1::int4[]", []int32{1, 2, 3, 4, 5}}, + {"select $1::int4[]", []int32{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreInt64Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []int64 + + tests := []struct { + sql string + expected []int64 + }{ + {"select $1::int8[]", []int64{1, 2, 3, 4, 5}}, + {"select $1::int8[]", []int64{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreFloat32Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []float32 + + tests := []struct { + sql string + expected []float32 + }{ + {"select $1::float4[]", []float32{1.5, 2.0, 3.5}}, + {"select $1::float4[]", []float32{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreFloat64Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []float64 + + tests := []struct { + sql string + expected []float64 + }{ + {"select $1::float8[]", []float64{1.5, 2.0, 3.5}}, + {"select $1::float8[]", []float64{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreStringSlice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []string + + tests := []struct { + sql string + expected []string + }{ + {"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}}, + {"select $1::text[]", []string{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} diff --git a/stdlib/sql.go b/stdlib/sql.go index a8ad5cf1..dcb2f9cd 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -54,9 +54,24 @@ import ( var openFromConnPoolCount int +// oids that map to intrinsic database/sql types. These will be allowed to be +// binary, anything else will be forced to text format +var databaseSqlOids map[pgx.Oid]bool + func init() { d := &Driver{} sql.Register("pgx", d) + + databaseSqlOids = make(map[pgx.Oid]bool) + databaseSqlOids[pgx.BoolOid] = true + databaseSqlOids[pgx.ByteaOid] = true + databaseSqlOids[pgx.Int2Oid] = true + databaseSqlOids[pgx.Int4Oid] = true + databaseSqlOids[pgx.Int8Oid] = true + databaseSqlOids[pgx.Float4Oid] = true + databaseSqlOids[pgx.Float8Oid] = true + databaseSqlOids[pgx.DateOid] = true + databaseSqlOids[pgx.TimestampTzOid] = true } type Driver struct { @@ -136,6 +151,8 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { return nil, err } + restrictBinaryToDatabaseSqlTypes(ps) + return &Stmt{ps: ps, conn: c}, nil } @@ -176,9 +193,24 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return nil, driver.ErrBadConn } + ps, err := c.conn.Prepare("", query) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPrepared("", argsV) +} + +func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + args := valueToInterface(argsV) - rows, err := c.conn.Query(query, args...) + rows, err := c.conn.Query(name, args...) if err != nil { return nil, err } @@ -186,6 +218,18 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return &Rows{rows: rows}, nil } +// Anything that isn't a database/sql compatible type needs to be forced to +// text format so that pgx.Rows.Values doesn't decode it into a native type +// (e.g. []int32) +func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { + for i, _ := range ps.FieldDescriptions { + intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] + if !intrinsic { + ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode + } + } +} + type Stmt struct { ps *pgx.PreparedStatement conn *Conn @@ -204,7 +248,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { - return s.conn.Query(s.ps.Name, argsV) + return s.conn.queryPrepared(s.ps.Name, argsV) } // TODO - rename to avoid alloc diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 56dd8859..17d9a9a9 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -336,6 +336,28 @@ func TestConnQueryFailure(t *testing.T) { ensureConnValid(t, db) } +// Test type that pgx would handle natively in binary, but since it is not a +// database/sql native type should be passed through as a string +func TestConnQueryRowPgxBinary(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + sql := "select $1::int4[]" + expected := "{1,2,3}" + var actual string + + err := db.QueryRow(sql, expected).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if actual != expected { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) + } + + ensureConnValid(t, db) +} + func TestConnQueryRowUnknownType(t *testing.T) { db := openDB(t) defer closeDB(t, db) diff --git a/values.go b/values.go index c7096bc9..57fa83e8 100644 --- a/values.go +++ b/values.go @@ -19,6 +19,12 @@ const ( TextOid = 25 Float4Oid = 700 Float8Oid = 701 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 VarcharOid = 1043 DateOid = 1082 TimestampOid = 1114 @@ -31,6 +37,27 @@ const ( BinaryFormatCode = 1 ) +var DefaultOidFormats map[Oid]int16 + +func init() { + DefaultOidFormats = make(map[Oid]int16) + DefaultOidFormats[BoolOid] = BinaryFormatCode + DefaultOidFormats[ByteaOid] = BinaryFormatCode + DefaultOidFormats[Int2Oid] = BinaryFormatCode + DefaultOidFormats[Int4Oid] = BinaryFormatCode + DefaultOidFormats[Int8Oid] = BinaryFormatCode + DefaultOidFormats[Float4Oid] = BinaryFormatCode + DefaultOidFormats[Float8Oid] = BinaryFormatCode + DefaultOidFormats[DateOid] = BinaryFormatCode + DefaultOidFormats[TimestampTzOid] = BinaryFormatCode + DefaultOidFormats[Int2ArrayOid] = BinaryFormatCode + DefaultOidFormats[Int4ArrayOid] = BinaryFormatCode + DefaultOidFormats[Int8ArrayOid] = BinaryFormatCode + DefaultOidFormats[Float4ArrayOid] = BinaryFormatCode + DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode + DefaultOidFormats[TextArrayOid] = BinaryFormatCode +} + type SerializationError string func (e SerializationError) Error() string { @@ -945,3 +972,406 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error { return nil } + +func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { + numDims := vr.ReadInt32() + if numDims == 0 { + return 0, nil + } + if numDims != 1 { + return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims)) + } + + vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care + vr.ReadInt32() // element oid + + length = vr.ReadInt32() + + idxFirstElem := vr.ReadInt32() + if idxFirstElem != 1 { + return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem)) + } + + return length, nil +} + +func decodeInt2Array(vr *ValueReader) []int16 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int2ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2ArrayOid, vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]int16, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 2: + a[i] = vr.ReadInt16() + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeInt2Array(w *WriteBuf, value interface{}) error { + slice, ok := value.([]int16) + if !ok { + return fmt.Errorf("Expected []int16, received %T", value) + } + + size := 20 + len(slice)*6 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(Int2Oid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(2) + w.WriteInt16(v) + } + + return nil +} + +func decodeInt4Array(vr *ValueReader) []int32 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int4ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4ArrayOid, vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]int32, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 4: + a[i] = vr.ReadInt32() + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeInt4Array(w *WriteBuf, value interface{}) error { + slice, ok := value.([]int32) + if !ok { + return fmt.Errorf("Expected []int32, received %T", value) + } + + size := 20 + len(slice)*8 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(Int4Oid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(4) + w.WriteInt32(v) + } + + return nil +} + +func decodeInt8Array(vr *ValueReader) []int64 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int8ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8ArrayOid, vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]int64, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + a[i] = vr.ReadInt64() + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeInt8Array(w *WriteBuf, value interface{}) error { + slice, ok := value.([]int64) + if !ok { + return fmt.Errorf("Expected []int64, received %T", value) + } + + size := 20 + len(slice)*12 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(Int8Oid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(8) + w.WriteInt64(v) + } + + return nil +} + +func decodeFloat4Array(vr *ValueReader) []float32 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Float4ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float4ArrayOid, vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]float32, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 4: + n := vr.ReadInt32() + p := unsafe.Pointer(&n) + a[i] = *(*float32)(p) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeFloat4Array(w *WriteBuf, value interface{}) error { + slice, ok := value.([]float32) + if !ok { + return fmt.Errorf("Expected []float32, received %T", value) + } + + size := 20 + len(slice)*8 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(Float4Oid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(4) + + p := unsafe.Pointer(&v) + w.WriteInt32(*(*int32)(p)) + } + + return nil +} + +func decodeFloat8Array(vr *ValueReader) []float64 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Float8ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float8ArrayOid, vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]float64, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + n := vr.ReadInt64() + p := unsafe.Pointer(&n) + a[i] = *(*float64)(p) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeFloat8Array(w *WriteBuf, value interface{}) error { + slice, ok := value.([]float64) + if !ok { + return fmt.Errorf("Expected []float64, received %T", value) + } + + size := 20 + len(slice)*12 + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(Float8Oid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(8) + + p := unsafe.Pointer(&v) + w.WriteInt64(*(*int64)(p)) + } + + return nil +} + +func decodeTextArray(vr *ValueReader) []string { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != TextArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TextArrayOid, vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]string, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + if elSize == -1 { + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + } + + a[i] = vr.ReadString(elSize) + } + + return a +} + +func encodeTextArray(w *WriteBuf, value interface{}) error { + slice, ok := value.([]string) + if !ok { + return fmt.Errorf("Expected []string, received %T", value) + } + + var totalStringSize int + for _, v := range slice { + totalStringSize += len(v) + } + + size := 20 + len(slice)*4 + totalStringSize + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(TextOid) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(int32(len(v))) + w.WriteBytes([]byte(v)) + } + + return nil +}