Add compatibility with database/sql custom types
Support database/sql.Scanner Support database/sql/driver.Valuer
This commit is contained in:
+113
@@ -2,10 +2,13 @@ package pgx_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func TestConnQueryScan(t *testing.T) {
|
||||
@@ -904,3 +907,113 @@ func TestReadingNullByteArrays(t *testing.T) {
|
||||
t.Errorf("Expected to read 2 rows, read: ", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Use github.com/shopspring/decimal as real-world database/sql custom type
|
||||
// to test against.
|
||||
func TestConnQueryDatabaseSQLScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var num decimal.Decimal
|
||||
|
||||
err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
expected, err := decimal.NewFromString("1234.567")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !num.Equals(expected) {
|
||||
t.Errorf("Expected num to be %v, but it was %v", expected, num)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Use github.com/shopspring/decimal as real-world database/sql custom type
|
||||
// to test against.
|
||||
func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
expected, err := decimal.NewFromString("1234.567")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var num decimal.Decimal
|
||||
|
||||
err = conn.QueryRow("select $1::decimal", expected).Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
if !num.Equals(expected) {
|
||||
t.Errorf("Expected num to be %v, but it was %v", expected, num)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryDatabaseSQLNullX(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type row struct {
|
||||
boolValid sql.NullBool
|
||||
boolNull sql.NullBool
|
||||
int64Valid sql.NullInt64
|
||||
int64Null sql.NullInt64
|
||||
float64Valid sql.NullFloat64
|
||||
float64Null sql.NullFloat64
|
||||
stringValid sql.NullString
|
||||
stringNull sql.NullString
|
||||
}
|
||||
|
||||
expected := row{
|
||||
boolValid: sql.NullBool{Bool: true, Valid: true},
|
||||
int64Valid: sql.NullInt64{Int64: 123, Valid: true},
|
||||
float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
|
||||
stringValid: sql.NullString{String: "pgx", Valid: true},
|
||||
}
|
||||
|
||||
var actual row
|
||||
|
||||
err := conn.QueryRow(
|
||||
"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
|
||||
expected.boolValid,
|
||||
expected.boolNull,
|
||||
expected.int64Valid,
|
||||
expected.int64Null,
|
||||
expected.float64Valid,
|
||||
expected.float64Null,
|
||||
expected.stringValid,
|
||||
expected.stringNull,
|
||||
).Scan(
|
||||
&actual.boolValid,
|
||||
&actual.boolNull,
|
||||
&actual.int64Valid,
|
||||
&actual.int64Null,
|
||||
&actual.float64Valid,
|
||||
&actual.float64Null,
|
||||
&actual.stringValid,
|
||||
&actual.stringNull,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
if expected != actual {
|
||||
t.Errorf("Expected %v, but got %v", expected, actual)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user