Add support for integer, float and text arrays
Restructure internals a bit so pgx/stdlib can turn off binary encoding and receive text back for array types.
This commit is contained in:
@@ -70,9 +70,6 @@ if err != nil {
|
||||
}
|
||||
```
|
||||
|
||||
Prepared statements will use the binary transmission when possible. This can
|
||||
substantially increase performance.
|
||||
|
||||
### Explicit Connection Pool
|
||||
|
||||
Connection pool usage is explicit and configurable. In pgx, a connection can
|
||||
@@ -151,9 +148,17 @@ point type.
|
||||
pgx includes Null* types in a similar fashion to database/sql that implement the
|
||||
necessary interfaces to be encoded and scanned.
|
||||
|
||||
### Array Mapping
|
||||
|
||||
pgx maps between int16, int32, int64, float32, float64, and string Go slices
|
||||
and the equivalent PostgreSQL array type. Go slices of native types do not
|
||||
support nulls, so if a PostgreSQL array that contains a slice is read into a
|
||||
native Go slice an error will occur.
|
||||
|
||||
### Logging
|
||||
|
||||
pgx connections optionally accept a logger from the [log15 package](http://gopkg.in/inconshreveable/log15.v2).
|
||||
pgx connections optionally accept a logger from the [log15
|
||||
package](http://gopkg.in/inconshreveable/log15.v2).
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
@@ -294,10 +294,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
||||
case rowDescription:
|
||||
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||
for i := range ps.FieldDescriptions {
|
||||
switch ps.FieldDescriptions[i].DataType {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, DateOid, TimestampTzOid:
|
||||
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
|
||||
}
|
||||
ps.FieldDescriptions[i].FormatCode, _ = DefaultOidFormats[ps.FieldDescriptions[i].DataType]
|
||||
}
|
||||
case noData:
|
||||
case readyForQuery:
|
||||
@@ -474,7 +471,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid:
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
@@ -518,6 +515,18 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||
err = encodeTimestampTz(wbuf, arguments[i])
|
||||
case TimestampOid:
|
||||
err = encodeTimestamp(wbuf, arguments[i])
|
||||
case Int2ArrayOid:
|
||||
err = encodeInt2Array(wbuf, arguments[i])
|
||||
case Int4ArrayOid:
|
||||
err = encodeInt4Array(wbuf, arguments[i])
|
||||
case Int8ArrayOid:
|
||||
err = encodeInt8Array(wbuf, arguments[i])
|
||||
case Float4ArrayOid:
|
||||
err = encodeFloat4Array(wbuf, arguments[i])
|
||||
case Float8ArrayOid:
|
||||
err = encodeFloat8Array(wbuf, arguments[i])
|
||||
case TextArrayOid:
|
||||
err = encodeTextArray(wbuf, arguments[i])
|
||||
default:
|
||||
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement Encoder", arg))
|
||||
}
|
||||
|
||||
+5
-2
@@ -20,10 +20,13 @@ func closeConn(t testing.TB, conn *pgx.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) {
|
||||
if _, err := conn.Prepare(name, sql); err != nil {
|
||||
func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) *pgx.PreparedStatement {
|
||||
ps, err := conn.Prepare(name, sql)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not prepare %v: %v", name, err)
|
||||
}
|
||||
|
||||
return ps
|
||||
}
|
||||
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
|
||||
|
||||
@@ -214,6 +214,18 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
||||
*d = decodeFloat4(vr)
|
||||
case *float64:
|
||||
*d = decodeFloat8(vr)
|
||||
case *[]int16:
|
||||
*d = decodeInt2Array(vr)
|
||||
case *[]int32:
|
||||
*d = decodeInt4Array(vr)
|
||||
case *[]int64:
|
||||
*d = decodeInt8Array(vr)
|
||||
case *[]float32:
|
||||
*d = decodeFloat4Array(vr)
|
||||
case *[]float64:
|
||||
*d = decodeFloat8Array(vr)
|
||||
case *[]string:
|
||||
*d = decodeTextArray(vr)
|
||||
case *time.Time:
|
||||
switch vr.Type().DataType {
|
||||
case DateOid:
|
||||
@@ -263,39 +275,50 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch vr.Type().DataType {
|
||||
case BoolOid:
|
||||
values = append(values, decodeBool(vr))
|
||||
case ByteaOid:
|
||||
values = append(values, decodeBytea(vr))
|
||||
case Int8Oid:
|
||||
values = append(values, decodeInt8(vr))
|
||||
case Int2Oid:
|
||||
values = append(values, decodeInt2(vr))
|
||||
case Int4Oid:
|
||||
values = append(values, decodeInt4(vr))
|
||||
case VarcharOid, TextOid:
|
||||
values = append(values, decodeText(vr))
|
||||
case Float4Oid:
|
||||
values = append(values, decodeFloat4(vr))
|
||||
case Float8Oid:
|
||||
values = append(values, decodeFloat8(vr))
|
||||
case DateOid:
|
||||
values = append(values, decodeDate(vr))
|
||||
case TimestampTzOid:
|
||||
values = append(values, decodeTimestampTz(vr))
|
||||
case TimestampOid:
|
||||
values = append(values, decodeTimestamp(vr))
|
||||
default:
|
||||
// if it is not an intrinsic type then return the text
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
values = append(values, vr.ReadString(vr.Len()))
|
||||
case BinaryFormatCode:
|
||||
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
||||
switch vr.Type().FormatCode {
|
||||
// All intrinsic types (except string) are encoded with binary
|
||||
// encoding so anything else should be treated as a string
|
||||
case TextFormatCode:
|
||||
values = append(values, vr.ReadString(vr.Len()))
|
||||
case BinaryFormatCode:
|
||||
switch vr.Type().DataType {
|
||||
case BoolOid:
|
||||
values = append(values, decodeBool(vr))
|
||||
case ByteaOid:
|
||||
values = append(values, decodeBytea(vr))
|
||||
case Int8Oid:
|
||||
values = append(values, decodeInt8(vr))
|
||||
case Int2Oid:
|
||||
values = append(values, decodeInt2(vr))
|
||||
case Int4Oid:
|
||||
values = append(values, decodeInt4(vr))
|
||||
case Float4Oid:
|
||||
values = append(values, decodeFloat4(vr))
|
||||
case Float8Oid:
|
||||
values = append(values, decodeFloat8(vr))
|
||||
case Int2ArrayOid:
|
||||
values = append(values, decodeInt2Array(vr))
|
||||
case Int4ArrayOid:
|
||||
values = append(values, decodeInt4Array(vr))
|
||||
case Int8ArrayOid:
|
||||
values = append(values, decodeInt8Array(vr))
|
||||
case Float4ArrayOid:
|
||||
values = append(values, decodeFloat4Array(vr))
|
||||
case Float8ArrayOid:
|
||||
values = append(values, decodeFloat8Array(vr))
|
||||
case TextArrayOid:
|
||||
values = append(values, decodeTextArray(vr))
|
||||
case DateOid:
|
||||
values = append(values, decodeDate(vr))
|
||||
case TimestampTzOid:
|
||||
values = append(values, decodeTimestampTz(vr))
|
||||
case TimestampOid:
|
||||
values = append(values, decodeTimestamp(vr))
|
||||
default:
|
||||
rows.Fatal(errors.New("Unknown format code"))
|
||||
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
||||
}
|
||||
default:
|
||||
rows.Fatal(errors.New("Unknown format code"))
|
||||
}
|
||||
|
||||
if vr.Err() != nil {
|
||||
|
||||
+284
@@ -376,6 +376,8 @@ func TestQueryRowCoreTypes(t *testing.T) {
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -486,3 +488,285 @@ func TestQueryRowNoResults(t *testing.T) {
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRowCoreInt16Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []int16
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []int16
|
||||
}{
|
||||
{"select $1::int2[]", []int16{1, 2, 3, 4, 5}},
|
||||
{"select $1::int2[]", []int16{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreInt32Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []int32
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []int32
|
||||
}{
|
||||
{"select $1::int4[]", []int32{1, 2, 3, 4, 5}},
|
||||
{"select $1::int4[]", []int32{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreInt64Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []int64
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []int64
|
||||
}{
|
||||
{"select $1::int8[]", []int64{1, 2, 3, 4, 5}},
|
||||
{"select $1::int8[]", []int64{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreFloat32Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []float32
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []float32
|
||||
}{
|
||||
{"select $1::float4[]", []float32{1.5, 2.0, 3.5}},
|
||||
{"select $1::float4[]", []float32{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreFloat64Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []float64
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []float64
|
||||
}{
|
||||
{"select $1::float8[]", []float64{1.5, 2.0, 3.5}},
|
||||
{"select $1::float8[]", []float64{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreStringSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []string
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []string
|
||||
}{
|
||||
{"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
|
||||
{"select $1::text[]", []string{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
+46
-2
@@ -54,9 +54,24 @@ import (
|
||||
|
||||
var openFromConnPoolCount int
|
||||
|
||||
// oids that map to intrinsic database/sql types. These will be allowed to be
|
||||
// binary, anything else will be forced to text format
|
||||
var databaseSqlOids map[pgx.Oid]bool
|
||||
|
||||
func init() {
|
||||
d := &Driver{}
|
||||
sql.Register("pgx", d)
|
||||
|
||||
databaseSqlOids = make(map[pgx.Oid]bool)
|
||||
databaseSqlOids[pgx.BoolOid] = true
|
||||
databaseSqlOids[pgx.ByteaOid] = true
|
||||
databaseSqlOids[pgx.Int2Oid] = true
|
||||
databaseSqlOids[pgx.Int4Oid] = true
|
||||
databaseSqlOids[pgx.Int8Oid] = true
|
||||
databaseSqlOids[pgx.Float4Oid] = true
|
||||
databaseSqlOids[pgx.Float8Oid] = true
|
||||
databaseSqlOids[pgx.DateOid] = true
|
||||
databaseSqlOids[pgx.TimestampTzOid] = true
|
||||
}
|
||||
|
||||
type Driver struct {
|
||||
@@ -136,6 +151,8 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return &Stmt{ps: ps, conn: c}, nil
|
||||
}
|
||||
|
||||
@@ -176,9 +193,24 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
ps, err := c.conn.Prepare("", query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return c.queryPrepared("", argsV)
|
||||
}
|
||||
|
||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
args := valueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.Query(query, args...)
|
||||
rows, err := c.conn.Query(name, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,6 +218,18 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||
return &Rows{rows: rows}, nil
|
||||
}
|
||||
|
||||
// Anything that isn't a database/sql compatible type needs to be forced to
|
||||
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
||||
// (e.g. []int32)
|
||||
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
|
||||
for i, _ := range ps.FieldDescriptions {
|
||||
intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType]
|
||||
if !intrinsic {
|
||||
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Stmt struct {
|
||||
ps *pgx.PreparedStatement
|
||||
conn *Conn
|
||||
@@ -204,7 +248,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
||||
}
|
||||
|
||||
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||
return s.conn.Query(s.ps.Name, argsV)
|
||||
return s.conn.queryPrepared(s.ps.Name, argsV)
|
||||
}
|
||||
|
||||
// TODO - rename to avoid alloc
|
||||
|
||||
@@ -336,6 +336,28 @@ func TestConnQueryFailure(t *testing.T) {
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
// Test type that pgx would handle natively in binary, but since it is not a
|
||||
// database/sql native type should be passed through as a string
|
||||
func TestConnQueryRowPgxBinary(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
||||
sql := "select $1::int4[]"
|
||||
expected := "{1,2,3}"
|
||||
var actual string
|
||||
|
||||
err := db.QueryRow(sql, expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
||||
}
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
|
||||
}
|
||||
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestConnQueryRowUnknownType(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
||||
@@ -19,6 +19,12 @@ const (
|
||||
TextOid = 25
|
||||
Float4Oid = 700
|
||||
Float8Oid = 701
|
||||
Int2ArrayOid = 1005
|
||||
Int4ArrayOid = 1007
|
||||
TextArrayOid = 1009
|
||||
Int8ArrayOid = 1016
|
||||
Float4ArrayOid = 1021
|
||||
Float8ArrayOid = 1022
|
||||
VarcharOid = 1043
|
||||
DateOid = 1082
|
||||
TimestampOid = 1114
|
||||
@@ -31,6 +37,27 @@ const (
|
||||
BinaryFormatCode = 1
|
||||
)
|
||||
|
||||
var DefaultOidFormats map[Oid]int16
|
||||
|
||||
func init() {
|
||||
DefaultOidFormats = make(map[Oid]int16)
|
||||
DefaultOidFormats[BoolOid] = BinaryFormatCode
|
||||
DefaultOidFormats[ByteaOid] = BinaryFormatCode
|
||||
DefaultOidFormats[Int2Oid] = BinaryFormatCode
|
||||
DefaultOidFormats[Int4Oid] = BinaryFormatCode
|
||||
DefaultOidFormats[Int8Oid] = BinaryFormatCode
|
||||
DefaultOidFormats[Float4Oid] = BinaryFormatCode
|
||||
DefaultOidFormats[Float8Oid] = BinaryFormatCode
|
||||
DefaultOidFormats[DateOid] = BinaryFormatCode
|
||||
DefaultOidFormats[TimestampTzOid] = BinaryFormatCode
|
||||
DefaultOidFormats[Int2ArrayOid] = BinaryFormatCode
|
||||
DefaultOidFormats[Int4ArrayOid] = BinaryFormatCode
|
||||
DefaultOidFormats[Int8ArrayOid] = BinaryFormatCode
|
||||
DefaultOidFormats[Float4ArrayOid] = BinaryFormatCode
|
||||
DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode
|
||||
DefaultOidFormats[TextArrayOid] = BinaryFormatCode
|
||||
}
|
||||
|
||||
type SerializationError string
|
||||
|
||||
func (e SerializationError) Error() string {
|
||||
@@ -945,3 +972,406 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
||||
numDims := vr.ReadInt32()
|
||||
if numDims == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if numDims != 1 {
|
||||
return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims))
|
||||
}
|
||||
|
||||
vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care
|
||||
vr.ReadInt32() // element oid
|
||||
|
||||
length = vr.ReadInt32()
|
||||
|
||||
idxFirstElem := vr.ReadInt32()
|
||||
if idxFirstElem != 1 {
|
||||
return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem))
|
||||
}
|
||||
|
||||
return length, nil
|
||||
}
|
||||
|
||||
func decodeInt2Array(vr *ValueReader) []int16 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int2ArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2ArrayOid, vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]int16, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 2:
|
||||
a[i] = vr.ReadInt16()
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeInt2Array(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]int16)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []int16, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*6
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Int2Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(2)
|
||||
w.WriteInt16(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeInt4Array(vr *ValueReader) []int32 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int4ArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4ArrayOid, vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]int32, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 4:
|
||||
a[i] = vr.ReadInt32()
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeInt4Array(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []int32, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*8
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Int4Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(4)
|
||||
w.WriteInt32(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeInt8Array(vr *ValueReader) []int64 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int8ArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8ArrayOid, vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]int64, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 8:
|
||||
a[i] = vr.ReadInt64()
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeInt8Array(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []int64, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*12
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Int8Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(8)
|
||||
w.WriteInt64(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeFloat4Array(vr *ValueReader) []float32 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Float4ArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float4ArrayOid, vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]float32, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 4:
|
||||
n := vr.ReadInt32()
|
||||
p := unsafe.Pointer(&n)
|
||||
a[i] = *(*float32)(p)
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeFloat4Array(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]float32)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []float32, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*8
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Float4Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(4)
|
||||
|
||||
p := unsafe.Pointer(&v)
|
||||
w.WriteInt32(*(*int32)(p))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeFloat8Array(vr *ValueReader) []float64 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Float8ArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float8ArrayOid, vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]float64, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case 8:
|
||||
n := vr.ReadInt64()
|
||||
p := unsafe.Pointer(&n)
|
||||
a[i] = *(*float64)(p)
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeFloat8Array(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []float64, received %T", value)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*12
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(Float8Oid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(8)
|
||||
|
||||
p := unsafe.Pointer(&v)
|
||||
w.WriteInt64(*(*int64)(p))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeTextArray(vr *ValueReader) []string {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != TextArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TextArrayOid, vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([]string, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
if elSize == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
}
|
||||
|
||||
a[i] = vr.ReadString(elSize)
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeTextArray(w *WriteBuf, value interface{}) error {
|
||||
slice, ok := value.([]string)
|
||||
if !ok {
|
||||
return fmt.Errorf("Expected []string, received %T", value)
|
||||
}
|
||||
|
||||
var totalStringSize int
|
||||
for _, v := range slice {
|
||||
totalStringSize += len(v)
|
||||
}
|
||||
|
||||
size := 20 + len(slice)*4 + totalStringSize
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(TextOid) // type of elements
|
||||
w.WriteInt32(int32(len(slice))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, v := range slice {
|
||||
w.WriteInt32(int32(len(v)))
|
||||
w.WriteBytes([]byte(v))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user