Add compatibility with database/sql custom types
Support database/sql.Scanner Support database/sql/driver.Valuer
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
# Tip
|
# Tip
|
||||||
|
|
||||||
|
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
|
||||||
* Go float64 can no longer be encoded to a PostgreSQL float4
|
* Go float64 can no longer be encoded to a PostgreSQL float4
|
||||||
* Add ConnPool.Reset method
|
* Add ConnPool.Reset method
|
||||||
* []byte skips encoding/decoding
|
* []byte skips encoding/decoding
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ Pgx supports many additional features beyond what is available through database/
|
|||||||
* Maps inet and cidr PostgreSQL types to net.IPNet
|
* Maps inet and cidr PostgreSQL types to net.IPNet
|
||||||
* Large object support
|
* Large object support
|
||||||
* Null mapping to Null* struct or pointer to pointer.
|
* Null mapping to Null* struct or pointer to pointer.
|
||||||
|
* Supports database/sql.Scanner and database/sql/driver/Valuer interfaces for custom types
|
||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -851,15 +852,20 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
|
|
||||||
wbuf.WriteInt16(int16(len(arguments)))
|
wbuf.WriteInt16(int16(len(arguments)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
|
encode:
|
||||||
if arguments[i] == nil {
|
if arguments[i] == nil {
|
||||||
wbuf.WriteInt32(-1)
|
wbuf.WriteInt32(-1)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
encode:
|
|
||||||
switch arg := arguments[i].(type) {
|
switch arg := arguments[i].(type) {
|
||||||
case Encoder:
|
case Encoder:
|
||||||
err = arg.Encode(wbuf, oid)
|
err = arg.Encode(wbuf, oid)
|
||||||
|
case driver.Valuer:
|
||||||
|
arguments[i], err = arg.Value()
|
||||||
|
if err == nil {
|
||||||
|
goto encode
|
||||||
|
}
|
||||||
case string:
|
case string:
|
||||||
err = encodeText(wbuf, arguments[i])
|
err = encodeText(wbuf, arguments[i])
|
||||||
case []byte:
|
case []byte:
|
||||||
|
|||||||
@@ -181,6 +181,9 @@ Conn.PgTypes.
|
|||||||
See example_custom_type_test.go for an example of a custom type for the
|
See example_custom_type_test.go for an example of a custom type for the
|
||||||
PostgreSQL point type.
|
PostgreSQL point type.
|
||||||
|
|
||||||
|
pgx also includes support for custom types implementing the database/sql.Scanner
|
||||||
|
and database/sql/driver.Valuer interfaces.
|
||||||
|
|
||||||
Raw Bytes Mapping
|
Raw Bytes Mapping
|
||||||
|
|
||||||
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
|
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package pgx
|
package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -255,6 +256,40 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
rows.Fatal(scanArgError{col: i, err: err})
|
rows.Fatal(scanArgError{col: i, err: err})
|
||||||
}
|
}
|
||||||
|
} else if s, ok := d.(sql.Scanner); ok {
|
||||||
|
var val interface{}
|
||||||
|
if 0 <= vr.Len() {
|
||||||
|
switch vr.Type().DataType {
|
||||||
|
case BoolOid:
|
||||||
|
val = decodeBool(vr)
|
||||||
|
case Int8Oid:
|
||||||
|
val = int64(decodeInt8(vr))
|
||||||
|
case Int2Oid:
|
||||||
|
val = int64(decodeInt2(vr))
|
||||||
|
case Int4Oid:
|
||||||
|
val = int64(decodeInt4(vr))
|
||||||
|
case TextOid, VarcharOid:
|
||||||
|
val = decodeText(vr)
|
||||||
|
case OidOid:
|
||||||
|
val = int64(decodeOid(vr))
|
||||||
|
case Float4Oid:
|
||||||
|
val = float64(decodeFloat4(vr))
|
||||||
|
case Float8Oid:
|
||||||
|
val = decodeFloat8(vr)
|
||||||
|
case DateOid:
|
||||||
|
val = decodeDate(vr)
|
||||||
|
case TimestampOid:
|
||||||
|
val = decodeTimestamp(vr)
|
||||||
|
case TimestampTzOid:
|
||||||
|
val = decodeTimestampTz(vr)
|
||||||
|
default:
|
||||||
|
val = vr.ReadBytes(vr.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = s.Scan(val)
|
||||||
|
if err != nil {
|
||||||
|
rows.Fatal(scanArgError{col: i, err: err})
|
||||||
|
}
|
||||||
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
||||||
decodeJson(vr, &d)
|
decodeJson(vr, &d)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+113
@@ -2,10 +2,13 @@ package pgx_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnQueryScan(t *testing.T) {
|
func TestConnQueryScan(t *testing.T) {
|
||||||
@@ -904,3 +907,113 @@ func TestReadingNullByteArrays(t *testing.T) {
|
|||||||
t.Errorf("Expected to read 2 rows, read: ", count)
|
t.Errorf("Expected to read 2 rows, read: ", count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use github.com/shopspring/decimal as real-world database/sql custom type
|
||||||
|
// to test against.
|
||||||
|
func TestConnQueryDatabaseSQLScanner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var num decimal.Decimal
|
||||||
|
|
||||||
|
err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected, err := decimal.NewFromString("1234.567")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !num.Equals(expected) {
|
||||||
|
t.Errorf("Expected num to be %v, but it was %v", expected, num)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use github.com/shopspring/decimal as real-world database/sql custom type
|
||||||
|
// to test against.
|
||||||
|
func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
expected, err := decimal.NewFromString("1234.567")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var num decimal.Decimal
|
||||||
|
|
||||||
|
err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !num.Equals(expected) {
|
||||||
|
t.Errorf("Expected num to be %v, but it was %v", expected, num)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnQueryDatabaseSQLNullX(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
type row struct {
|
||||||
|
boolValid sql.NullBool
|
||||||
|
boolNull sql.NullBool
|
||||||
|
int64Valid sql.NullInt64
|
||||||
|
int64Null sql.NullInt64
|
||||||
|
float64Valid sql.NullFloat64
|
||||||
|
float64Null sql.NullFloat64
|
||||||
|
stringValid sql.NullString
|
||||||
|
stringNull sql.NullString
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := row{
|
||||||
|
boolValid: sql.NullBool{Bool: true, Valid: true},
|
||||||
|
int64Valid: sql.NullInt64{Int64: 123, Valid: true},
|
||||||
|
float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
|
||||||
|
stringValid: sql.NullString{String: "pgx", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual row
|
||||||
|
|
||||||
|
err := conn.QueryRow(
|
||||||
|
"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
|
||||||
|
expected.boolValid,
|
||||||
|
expected.boolNull,
|
||||||
|
expected.int64Valid,
|
||||||
|
expected.int64Null,
|
||||||
|
expected.float64Valid,
|
||||||
|
expected.float64Null,
|
||||||
|
expected.stringValid,
|
||||||
|
expected.stringNull,
|
||||||
|
).Scan(
|
||||||
|
&actual.boolValid,
|
||||||
|
&actual.boolNull,
|
||||||
|
&actual.int64Valid,
|
||||||
|
&actual.int64Null,
|
||||||
|
&actual.float64Valid,
|
||||||
|
&actual.float64Null,
|
||||||
|
&actual.stringValid,
|
||||||
|
&actual.stringNull,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected != actual {
|
||||||
|
t.Errorf("Expected %v, but got %v", expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user