ab885b375b
It was a mistake to use it in other contexts. This made interop difficult between packages that depended on pgtype, such as pgx, and packages that did not, like pgconn and pgproto3. In particular, this was awkward for prepared statements. Because pgx depends on pgtype and the tests for pgtype depend on pgx, this change will require a couple of back-and-forth commits to get the go.mod dependencies correct.
287 lines
7.8 KiB
Go
287 lines
7.8 KiB
Go
package testutil
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/jackc/pgtype"
|
|
"github.com/jackc/pgx/v4"
|
|
_ "github.com/jackc/pgx/v4/stdlib"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// MustConnectDatabaseSQL opens a database/sql handle for the driver labelled
// by driverName, using the data source name stored in the PGX_TEST_DATABASE
// environment variable.
//
// driverName is an import-path style label: "github.com/lib/pq" selects the
// registered "postgres" driver and "github.com/jackc/pgx/stdlib" selects the
// registered "pgx" driver. Any other label fails the test via t.Fatalf, and an
// error returned by sql.Open fails it via t.Fatal.
//
// The function is marked as a test helper so that failures are attributed to
// the calling test rather than to this file.
func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
	t.Helper()

	var sqlDriverName string
	switch driverName {
	case "github.com/lib/pq":
		sqlDriverName = "postgres"
	case "github.com/jackc/pgx/stdlib":
		sqlDriverName = "pgx"
	default:
		t.Fatalf("Unknown driver %v", driverName)
	}

	db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}

	return db
}
|
|
|
|
func MustConnectPgx(t testing.TB) *pgx.Conn {
|
|
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
// MustClose calls conn.Close and fails the test via t.Fatal if it returns an
// error. conn may be any value with a Close() error method.
//
// The function is marked as a test helper so that failures are attributed to
// the calling test rather than to this file.
func MustClose(t testing.TB, conn interface {
	Close() error
}) {
	t.Helper()

	if err := conn.Close(); err != nil {
		t.Fatal(err)
	}
}
|
|
|
|
// MustCloseContext calls conn.Close with context.Background() and fails the
// test via t.Fatal if it returns an error. conn may be any value with a
// Close(context.Context) error method.
//
// The function is marked as a test helper so that failures are attributed to
// the calling test rather than to this file.
func MustCloseContext(t testing.TB, conn interface {
	Close(context.Context) error
}) {
	t.Helper()

	if err := conn.Close(context.Background()); err != nil {
		t.Fatal(err)
	}
}
|
|
|
|
type forceTextEncoder struct {
|
|
e pgtype.TextEncoder
|
|
}
|
|
|
|
func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
return f.e.EncodeText(ci, buf)
|
|
}
|
|
|
|
type forceBinaryEncoder struct {
|
|
e pgtype.BinaryEncoder
|
|
}
|
|
|
|
func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
return f.e.EncodeBinary(ci, buf)
|
|
}
|
|
|
|
func ForceEncoder(e interface{}, formatCode int16) interface{} {
|
|
switch formatCode {
|
|
case pgx.TextFormatCode:
|
|
if e, ok := e.(pgtype.TextEncoder); ok {
|
|
return forceTextEncoder{e: e}
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if e, ok := e.(pgtype.BinaryEncoder); ok {
|
|
return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
|
|
TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustCloseContext(t, conn)
|
|
|
|
_, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, v := range values {
|
|
for _, paramFormat := range formats {
|
|
for _, resultFormat := range formats {
|
|
vEncoder := ForceEncoder(v, paramFormat.formatCode)
|
|
if vEncoder == nil {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name)
|
|
continue
|
|
}
|
|
switch resultFormat.formatCode {
|
|
case pgx.TextFormatCode:
|
|
if _, ok := v.(pgtype.TextEncoder); !ok {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
|
|
continue
|
|
}
|
|
case pgx.BinaryFormatCode:
|
|
if _, ok := v.(pgtype.BinaryEncoder); !ok {
|
|
t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
|
|
err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for i, v := range values {
|
|
// Derefence value if it is a pointer
|
|
derefV := v
|
|
refVal := reflect.ValueOf(v)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err := ps.QueryRow(v).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
|
|
// NormalizeTest is one case for the normalize test helpers: a SQL statement
// together with the value its single result is expected to scan into.
type NormalizeTest struct {
	// SQL is the statement that is prepared and queried without arguments.
	SQL string

	// Value is the expected result. When it is a pointer, the helpers
	// compare against the pointed-to value.
	Value interface{}
}
|
|
|
|
func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
|
|
TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
|
|
return reflect.DeepEqual(a, b)
|
|
})
|
|
}
|
|
|
|
func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
|
|
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
|
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
|
|
}
|
|
}
|
|
|
|
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectPgx(t)
|
|
defer MustCloseContext(t, conn)
|
|
|
|
formats := []struct {
|
|
name string
|
|
formatCode int16
|
|
}{
|
|
{name: "TextFormat", formatCode: pgx.TextFormatCode},
|
|
{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
for _, fc := range formats {
|
|
psName := fmt.Sprintf("test%d", i)
|
|
_, err := conn.Prepare(context.Background(), psName, tt.SQL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
queryResultFormats := pgx.QueryResultFormats{fc.formatCode}
|
|
if ForceEncoder(tt.Value, fc.formatCode) == nil {
|
|
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
|
|
continue
|
|
}
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", fc.name, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
|
conn := MustConnectDatabaseSQL(t, driverName)
|
|
defer MustClose(t, conn)
|
|
|
|
for i, tt := range tests {
|
|
ps, err := conn.Prepare(tt.SQL)
|
|
if err != nil {
|
|
t.Errorf("%d. %v", i, err)
|
|
continue
|
|
}
|
|
|
|
// Derefence value if it is a pointer
|
|
derefV := tt.Value
|
|
refVal := reflect.ValueOf(tt.Value)
|
|
if refVal.Kind() == reflect.Ptr {
|
|
derefV = refVal.Elem().Interface()
|
|
}
|
|
|
|
result := reflect.New(reflect.TypeOf(derefV))
|
|
err = ps.QueryRow().Scan(result.Interface())
|
|
if err != nil {
|
|
t.Errorf("%v %d: %v", driverName, i, err)
|
|
}
|
|
|
|
if !eqFunc(result.Elem().Interface(), derefV) {
|
|
t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
|
|
}
|
|
}
|
|
}
|