66625e6489
Also remove PrepareEx. It's primary usage was for context. Supplying parameter OIDs is unnecessary when you can type cast in the query SQL. If it does become necessary or desirable to add options back it can be added in a backwards compatible way by adding a varargs as last argument.
287 lines
7.8 KiB
Go
287 lines
7.8 KiB
Go
package testutil
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/jackc/pgx"
|
|
"github.com/jackc/pgx/pgtype"
|
|
_ "github.com/jackc/pgx/stdlib"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
|
var sqlDriverName string
|
|
switch driverName {
|
|
case "github.com/lib/pq":
|
|
sqlDriverName = "postgres"
|
|
case "github.com/jackc/pgx/stdlib":
|
|
sqlDriverName = "pgx"
|
|
default:
|
|
t.Fatalf("Unknown driver %v", driverName)
|
|
}
|
|
|
|
db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return db
|
|
}
|
|
|
|
func MustConnectPgx(t testing.TB) *pgx.Conn {
|
|
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func MustClose(t testing.TB, conn interface {
|
|
Close() error
|
|
}) {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func MustCloseContext(t testing.TB, conn interface {
|
|
Close(context.Context) error
|
|
}) {
|
|
err := conn.Close(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
type forceTextEncoder struct {
|
|
e pgtype.TextEncoder
|
|
}
|
|
|
|
func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
return f.e.EncodeText(ci, buf)
|
|
}
|
|
|
|
type forceBinaryEncoder struct {
|
|
e pgtype.BinaryEncoder
|
|
}
|
|
|
|
func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
return f.e.EncodeBinary(ci, buf)
|
|
}
|
|
|
|
func ForceEncoder(e interface{}, formatCode int16) interface{} {
|
|
switch formatCode {
|
|
case pgx.TextFormatCode:
|
|
if e, ok := e.(pgtype.TextEncoder); ok {
|
|
return forceTextEncoder{e: e}
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if e, ok := e.(pgtype.BinaryEncoder); ok {
|
|
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
|
|
TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustCloseContext(t, conn)
|
|
|
|
_, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, v := range values {
|
|
for _, paramFormat := range formats {
|
|
for _, resultFormat := range formats {
|
|
vEncoder := ForceEncoder(v, paramFormat.formatCode)
|
|
if vEncoder == nil {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name)
|
|
continue
|
|
}
|
|
switch resultFormat.formatCode {
|
|
case pgx.TextFormatCode:
|
|
if _, ok := v.(pgtype.TextEncoder); !ok {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
|
|
continue
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if _, ok := v.(pgtype.BinaryEncoder); !ok {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
|
|
err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for i, v := range values {
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := ps.QueryRow(v).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
|
|
type NormalizeTest struct {
|
|
SQL string
|
|
Value interface{}
|
|
}
|
|
|
|
func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
|
|
TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustCloseContext(t, conn)
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
for _, fc := range formats {
|
|
psName := fmt.Sprintf("test%d", i)
|
|
ps, err := conn.Prepare(context.Background(), psName, tt.SQL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
|
if ForceEncoder(tt.Value, fc.formatCode) == nil {
|
|
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
|
|
continue
|
|
}
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = conn.QueryRow(context.Background(), psName).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
for i, tt := range tests {
|
|
ps, err := conn.Prepare(tt.SQL)
|
|
if err != nil {
|
|
t.Errorf("%d. %v", i, err)
|
|
continue
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = ps.QueryRow().Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|