Add database/sql support to pgtype
This commit is contained in:
+60
-1
@@ -1,6 +1,7 @@
|
||||
package pgtype_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -10,6 +11,8 @@ import (
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
_ "github.com/jackc/pgx/stdlib"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Test for renamed types
|
||||
@@ -24,6 +27,25 @@ type _float32Slice []float32
|
||||
type _float64Slice []float64
|
||||
type _byteSlice []byte
|
||||
|
||||
func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
||||
var sqlDriverName string
|
||||
switch driverName {
|
||||
case "github.com/lib/pq":
|
||||
sqlDriverName = "postgres"
|
||||
case "github.com/jackc/pgx/stdlib":
|
||||
sqlDriverName = "pgx"
|
||||
default:
|
||||
t.Fatalf("Unknown driver %v", driverName)
|
||||
}
|
||||
|
||||
db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func mustConnectPgx(t testing.TB) *pgx.Conn {
|
||||
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
@@ -93,6 +115,13 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface
|
||||
}
|
||||
|
||||
func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
@@ -114,7 +143,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int
|
||||
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||
vEncoder := forceEncoder(v, fc.formatCode)
|
||||
if vEncoder == nil {
|
||||
t.Logf("%#v does not implement %v", v, fc.name)
|
||||
t.Logf("Skipping: %#v does not implement %v", v, fc.name)
|
||||
continue
|
||||
}
|
||||
// Derefence value if it is a pointer
|
||||
@@ -136,3 +165,33 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectDatabaseSQL(t, driverName)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i, v := range values {
|
||||
// Derefence value if it is a pointer
|
||||
derefV := v
|
||||
refVal := reflect.ValueOf(v)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := ps.QueryRow(v).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", driverName, i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user