Finish extraction of pgtype test helpers
This commit is contained in:
+17
-17
@@ -15,7 +15,7 @@ import (
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
||||
func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
||||
var sqlDriverName string
|
||||
switch driverName {
|
||||
case "github.com/lib/pq":
|
||||
@@ -34,7 +34,7 @@ func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
func mustConnectPgx(t testing.TB) *pgx.Conn {
|
||||
func MustConnectPgx(t testing.TB) *pgx.Conn {
|
||||
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -48,7 +48,7 @@ func mustConnectPgx(t testing.TB) *pgx.Conn {
|
||||
return conn
|
||||
}
|
||||
|
||||
func mustClose(t testing.TB, conn interface {
|
||||
func MustClose(t testing.TB, conn interface {
|
||||
Close() error
|
||||
}) {
|
||||
err := conn.Close()
|
||||
@@ -73,7 +73,7 @@ func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool
|
||||
return f.e.EncodeBinary(ci, w)
|
||||
}
|
||||
|
||||
func forceEncoder(e interface{}, formatCode int16) interface{} {
|
||||
func ForceEncoder(e interface{}, formatCode int16) interface{} {
|
||||
switch formatCode {
|
||||
case pgx.TextFormatCode:
|
||||
if e, ok := e.(pgtype.TextEncoder); ok {
|
||||
@@ -102,8 +102,8 @@ func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int
|
||||
}
|
||||
|
||||
func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
conn := MustConnectPgx(t)
|
||||
defer MustClose(t, conn)
|
||||
|
||||
ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName))
|
||||
if err != nil {
|
||||
@@ -121,7 +121,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
|
||||
for i, v := range values {
|
||||
for _, fc := range formats {
|
||||
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||
vEncoder := forceEncoder(v, fc.formatCode)
|
||||
vEncoder := ForceEncoder(v, fc.formatCode)
|
||||
if vEncoder == nil {
|
||||
t.Logf("Skipping: %#v does not implement %v", v, fc.name)
|
||||
continue
|
||||
@@ -134,7 +134,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface())
|
||||
err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("%v %d: %v", fc.name, i, err)
|
||||
}
|
||||
@@ -147,8 +147,8 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
|
||||
}
|
||||
|
||||
func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
conn := MustConnectPgx(t)
|
||||
defer MustClose(t, conn)
|
||||
|
||||
for i, v := range values {
|
||||
// Derefence value if it is a pointer
|
||||
@@ -176,8 +176,8 @@ func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName str
|
||||
}
|
||||
|
||||
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectDatabaseSQL(t, driverName)
|
||||
defer mustClose(t, conn)
|
||||
conn := MustConnectDatabaseSQL(t, driverName)
|
||||
defer MustClose(t, conn)
|
||||
|
||||
ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
|
||||
if err != nil {
|
||||
@@ -223,8 +223,8 @@ func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc f
|
||||
}
|
||||
|
||||
func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
conn := MustConnectPgx(t)
|
||||
defer MustClose(t, conn)
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
@@ -243,7 +243,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun
|
||||
}
|
||||
|
||||
ps.FieldDescriptions[0].FormatCode = fc.formatCode
|
||||
if forceEncoder(tt.Value, fc.formatCode) == nil {
|
||||
if ForceEncoder(tt.Value, fc.formatCode) == nil {
|
||||
t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
|
||||
continue
|
||||
}
|
||||
@@ -268,8 +268,8 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun
|
||||
}
|
||||
|
||||
func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectDatabaseSQL(t, driverName)
|
||||
defer mustClose(t, conn)
|
||||
conn := MustConnectDatabaseSQL(t, driverName)
|
||||
defer MustClose(t, conn)
|
||||
|
||||
for i, tt := range tests {
|
||||
ps, err := conn.Prepare(tt.SQL)
|
||||
|
||||
Reference in New Issue
Block a user