2
0
Files
pgx/pgtype/pgtype_test.go
T
Jack Christensen 222e3b37bc Prefer driver.Value over wrap plans when encoding
This is tricky due to driver.Valuer returning any. For example, we can
plan for fmt.Stringer because it always returns a string.

Because of this driver.Valuer was always handled as the last option. But
with pgx v5 now having the ability to find underlying types like a
string and supporting fmt.Stringer it meant that driver.Valuer was
often not getting called because something else was found first.

This change tries driver.Valuer immediately after the initial PlanScan
for the Codec. So a type that directly implements a pgx interface should
be used, but driver.Valuer will be preferred before all the attempts to
handle renamed types, pointer dereferencing, etc.

fixes https://github.com/jackc/pgx/issues/1319
fixes https://github.com/jackc/pgx/issues/1311
2022-10-01 12:20:23 -05:00

397 lines
9.2 KiB
Go

package pgtype_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"os"
"regexp"
"strconv"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// defaultConnTestRunner is a package-level ConnTestRunner that is configured in init.
var defaultConnTestRunner pgxtest.ConnTestRunner

// init starts from pgxtest's default runner and replaces its CreateConfig hook so
// that the connection config is parsed from the PGX_TEST_DATABASE environment
// variable. A parse failure fails the requesting test via require.NoError.
func init() {
	runner := pgxtest.DefaultConnTestRunner()
	runner.CreateConfig = func(_ context.Context, tb testing.TB) *pgx.ConnConfig {
		cfg, parseErr := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
		require.NoError(tb, parseErr)
		return cfg
	}
	defaultConnTestRunner = runner
}
// Renamed types used by the tests: each one is a distinct named type whose
// underlying type is a Go builtin or a slice of a builtin.
type (
	_string       string
	_bool         bool
	_int8         int8
	_int16        int16
	_int16Slice   []int16
	_int32Slice   []int32
	_int64Slice   []int64
	_float32Slice []float32
	_float64Slice []float64
	_byteSlice    []byte
)
// mustParseCIDR parses s with net.ParseCIDR and returns the resulting network.
// A parse error is reported through t.Fatal.
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
	_, network, parseErr := net.ParseCIDR(s)
	if parseErr == nil {
		return network
	}

	t.Fatal(parseErr)
	return network
}
// mustParseInet parses s either as CIDR notation or as a bare IP address.
//
// For CIDR input the returned network keeps the mask from net.ParseCIDR, but its
// IP is set to the address as written (host bits are kept), in 4-byte form when
// it is an IPv4 address. For a bare address the mask covers the full address
// width: /32 for IPv4 and /128 otherwise. An unparseable input is reported
// through t.Fatal.
func mustParseInet(t testing.TB, s string) *net.IPNet {
	if addr, network, err := net.ParseCIDR(s); err == nil {
		network.IP = addr
		if v4 := addr.To4(); v4 != nil {
			network.IP = v4
		}
		return network
	}

	// Not CIDR notation; it may be a bare IP address.
	addr := net.ParseIP(s)
	if addr == nil {
		t.Fatal(errors.New("unable to parse inet address"))
	}

	if v4 := addr.To4(); v4 != nil {
		return &net.IPNet{IP: v4, Mask: net.CIDRMask(32, 32)}
	}
	return &net.IPNet{IP: addr, Mask: net.CIDRMask(128, 128)}
}
// mustParseMacaddr parses s with net.ParseMAC and returns the hardware address.
// A parse error is reported through t.Fatal.
func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
	hw, parseErr := net.ParseMAC(s)
	if parseErr == nil {
		return hw
	}

	t.Fatal(parseErr)
	return hw
}
// skipCockroachDB skips the calling test with msg when the server addressed by
// the PGX_TEST_DATABASE environment variable reports a non-empty "crdb_version"
// parameter status. A connection failure is reported through t.Fatal.
func skipCockroachDB(t testing.TB, msg string) {
	ctx := context.Background()

	conn, connErr := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if connErr != nil {
		t.Fatal(connErr)
	}
	defer conn.Close(ctx)

	crdbVersion := conn.PgConn().ParameterStatus("crdb_version")
	if crdbVersion == "" {
		return
	}
	t.Skip(msg)
}
// skipPostgreSQLVersionLessThan skips the calling test when the leading integer
// of the server's "server_version" parameter status is below minVersion.
//
// It connects using the PGX_TEST_DATABASE environment variable. When
// server_version has no leading digits the server is treated as not being
// PostgreSQL and nothing is done. A connection failure is reported through
// t.Fatal.
func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) {
	ctx := context.Background()

	conn, connErr := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if connErr != nil {
		t.Fatal(connErr)
	}
	defer conn.Close(ctx)

	// Keep only the digits at the start of the reported version string.
	leadingDigits := regexp.MustCompile(`^[0-9]+`)
	versionPrefix := leadingDigits.FindString(conn.PgConn().ParameterStatus("server_version"))
	if versionPrefix == "" {
		// Not PostgreSQL: nothing to check.
		return
	}

	version, parseErr := strconv.ParseInt(versionPrefix, 10, 64)
	require.NoError(t, parseErr)

	if version < minVersion {
		t.Skipf("Test requires PostgreSQL v%d+", minVersion)
	}
}
// TestMapScanNilIsNoOp checks that Map.Scan returns no error when the
// destination is nil.
func TestMapScanNilIsNoOp(t *testing.T) {
	typeMap := pgtype.NewMap()

	assert.NoError(t, typeMap.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil))
}
// TestMapScanTextFormatInterfacePtr checks that scanning a text-format text
// value into a pointer to an empty interface stores the Go string "foo".
func TestMapScanTextFormatInterfacePtr(t *testing.T) {
	var dst any

	scanErr := pgtype.NewMap().Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &dst)

	require.NoError(t, scanErr)
	assert.Equal(t, "foo", dst)
}
// TestMapScanTextFormatNonByteaIntoByteSlice checks that a text-format value of
// a non-bytea type (here JSONB) scanned into a []byte yields the source bytes.
func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) {
	var dst []byte

	scanErr := pgtype.NewMap().Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &dst)

	require.NoError(t, scanErr)
	assert.Equal(t, []byte("{}"), dst)
}
// TestMapScanBinaryFormatInterfacePtr checks that scanning a binary-format text
// value into a pointer to an empty interface stores the Go string "foo".
func TestMapScanBinaryFormatInterfacePtr(t *testing.T) {
	var dst any

	scanErr := pgtype.NewMap().Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &dst)

	require.NoError(t, scanErr)
	assert.Equal(t, "foo", dst)
}
// TestMapScanUnknownOIDToStringsAndBytes checks that a text-format value with
// OID 999999 can be scanned into string, []byte, and renamed versions of both.
func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) {
	const unknownOID uint32 = 999999

	typeMap := pgtype.NewMap()
	raw := []byte("foo")

	// Plain string.
	var str string
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, raw, &str))
	assert.Equal(t, "foo", str)

	// Renamed string.
	var renamedStr _string
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, raw, &renamedStr))
	assert.Equal(t, "foo", string(renamedStr))

	// Plain byte slice.
	var bytesDst []byte
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, raw, &bytesDst))
	assert.Equal(t, []byte("foo"), bytesDst)

	// Renamed byte slice.
	var renamedBytes _byteSlice
	assert.NoError(t, typeMap.Scan(unknownOID, pgx.TextFormatCode, raw, &renamedBytes))
	assert.Equal(t, []byte("foo"), []byte(renamedBytes))
}
// TestMapScanPointerToNilStructDoesNotCrash checks that scanning a text value
// with OID 0 into a pointer to a nil struct pointer returns a non-nil error
// rather than panicking.
func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) {
	type myStruct struct{}

	var dst *myStruct
	scanErr := pgtype.NewMap().Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &dst)

	require.NotNil(t, scanErr)
}
// TestMapScanUnknownOIDTextFormat checks that the text value "123" with OID 0
// can be scanned into an int32.
func TestMapScanUnknownOIDTextFormat(t *testing.T) {
	var dst int32

	scanErr := pgtype.NewMap().Scan(0, pgx.TextFormatCode, []byte("123"), &dst)

	assert.NoError(t, scanErr)
	assert.EqualValues(t, 123, dst)
}
// TestMapScanUnknownOIDIntoSQLScanner checks that scanning a nil source with
// OID 0 into a sql.NullString succeeds and leaves it empty and not Valid.
func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) {
	var (
		nilSrc []byte // deliberately nil
		dst    sql.NullString
	)

	scanErr := pgtype.NewMap().Scan(0, pgx.TextFormatCode, nilSrc, &dst)

	assert.NoError(t, scanErr)
	assert.Equal(t, "", dst.String)
	assert.False(t, dst.Valid)
}
// pgCustomInt is a renamed int64 with a Scan method, giving the tests a
// destination that is filled through its own Scan implementation.
type pgCustomInt int64

// Scan stores src in ci when src is an int64.
//
// Any other source, including nil, yields an error and leaves ci unchanged,
// where an unchecked type assertion would panic instead.
func (ci *pgCustomInt) Scan(src any) error {
	n, ok := src.(int64)
	if !ok {
		return fmt.Errorf("cannot scan %T into pgCustomInt", src)
	}

	*ci = pgCustomInt(n)
	return nil
}
// TestScanPlanBinaryInt32ScanScanner exercises scan plans for a binary-format
// int2 value (bytes 0, 42) whose destination is a pgCustomInt, reached either
// directly or through a pointer to a pointer, with both nil and non-nil sources.
func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 42}

	// Destination is a pointer to the value itself.
	var direct pgCustomInt
	plan := typeMap.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &direct)
	scanErr := plan.Scan(encoded, &direct)
	require.NoError(t, scanErr)
	require.EqualValues(t, 42, direct)

	// Destination is a pointer to an already allocated pointer.
	indirect := new(pgCustomInt)
	plan = typeMap.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &indirect)
	scanErr = plan.Scan(encoded, &indirect)
	require.NoError(t, scanErr)
	require.EqualValues(t, 42, *indirect)

	// Same plan as above: a nil source resets an allocated pointer to nil.
	indirect = new(pgCustomInt)
	scanErr = plan.Scan(nil, &indirect)
	require.NoError(t, scanErr)
	assert.Nil(t, indirect)

	// A nil pointer is populated when the source is not nil.
	indirect = nil
	plan = typeMap.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &indirect)
	scanErr = plan.Scan(encoded, &indirect)
	require.NoError(t, scanErr)
	require.EqualValues(t, 42, *indirect)

	// A nil pointer stays nil when the source is nil.
	indirect = nil
	plan = typeMap.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &indirect)
	scanErr = plan.Scan(nil, &indirect)
	require.NoError(t, scanErr)
	assert.Nil(t, indirect)
}
// TestScanPlanInterface checks that planning and scanning into a nil interface
// value (passed as-is, not as a pointer) produces an error.
//
// Test for https://github.com/jackc/pgtype/issues/164
func TestScanPlanInterface(t *testing.T) {
	var dst any // nil interface, deliberately not addressed

	plan := pgtype.NewMap().PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, dst)
	scanErr := plan.Scan([]byte{0, 42}, dst)

	assert.Error(t, scanErr)
}
// TestMapScanPtrToPtrToSlice checks that a text-format text array can be
// scanned through a pointer to a nil *[]string.
//
// https://github.com/jackc/pgx/issues/1263
func TestMapScanPtrToPtrToSlice(t *testing.T) {
	var dst *[]string

	plan := pgtype.NewMap().PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &dst)
	scanErr := plan.Scan([]byte("{foo,bar}"), &dst)

	require.NoError(t, scanErr)
	require.Equal(t, []string{"foo", "bar"}, *dst)
}
// databaseValuerString is a renamed string with a Value method whose result
// differs from the string's own content.
type databaseValuerString string

// Value returns the byte length of s formatted as a decimal string, and a nil
// error.
func (s databaseValuerString) Value() (driver.Value, error) {
	length := len(s)
	return fmt.Sprint(length), nil
}
// https://github.com/jackc/pgx/issues/1319
func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerString("foo")
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
require.NoError(t, err)
require.Equal(t, "3", string(buf))
}
// databaseValuerFmtStringer is a renamed string that has both a Value method
// and a String method, with deliberately different results.
type databaseValuerFmtStringer string

// Value always returns a nil value and a nil error, whatever the receiver holds.
func (databaseValuerFmtStringer) Value() (driver.Value, error) {
	var none driver.Value
	return none, nil
}

// String always returns "foobar", whatever the receiver holds.
func (databaseValuerFmtStringer) String() string {
	const text = "foobar"
	return text
}
// https://github.com/jackc/pgx/issues/1311
func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerFmtStringer("")
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
require.NoError(t, err)
require.Nil(t, buf)
}
// databaseValuerStringFormat wraps an int32 and has a Value method that returns
// it as a string.
type databaseValuerStringFormat struct {
	n int32
}

// Value returns n formatted as a decimal string, and a nil error.
func (v databaseValuerStringFormat) Value() (driver.Value, error) {
	return fmt.Sprintf("%d", v.n), nil
}
func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerStringFormat{n: 42}
buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil)
require.NoError(t, err)
require.Equal(t, []byte{0, 0, 0, 42}, buf)
}
// BenchmarkMapScanInt4IntoBinaryDecoder measures Map.Scan of a binary-format
// int4 into a pgtype.Int4, resetting the destination on every iteration and
// checking the scanned value.
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}

	var dst pgtype.Int4
	for iter := 0; iter < b.N; iter++ {
		dst = pgtype.Int4{}

		if scanErr := typeMap.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, encoded, &dst); scanErr != nil {
			b.Fatal(scanErr)
		}
		if dst != (pgtype.Int4{Int32: 42, Valid: true}) {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkMapScanInt4IntoGoInt32 measures Map.Scan of a binary-format int4
// into a Go int32, resetting the destination on every iteration and checking
// the scanned value.
func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}

	var dst int32
	for iter := 0; iter < b.N; iter++ {
		dst = 0

		if scanErr := typeMap.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, encoded, &dst); scanErr != nil {
			b.Fatal(scanErr)
		}
		if dst != 42 {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkScanPlanScanInt4IntoBinaryDecoder measures scanning a binary-format
// int4 into a pgtype.Int4 through a scan plan that is created once, before the
// loop.
func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}

	var dst pgtype.Int4
	plan := typeMap.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &dst)

	for iter := 0; iter < b.N; iter++ {
		dst = pgtype.Int4{}

		if scanErr := plan.Scan(encoded, &dst); scanErr != nil {
			b.Fatal(scanErr)
		}
		if dst != (pgtype.Int4{Int32: 42, Valid: true}) {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// BenchmarkScanPlanScanInt4IntoGoInt32 measures scanning a binary-format int4
// into a Go int32 through a scan plan that is created once, before the loop.
func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
	typeMap := pgtype.NewMap()
	encoded := []byte{0, 0, 0, 42}

	var dst int32
	plan := typeMap.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &dst)

	for iter := 0; iter < b.N; iter++ {
		dst = 0

		if scanErr := plan.Scan(encoded, &dst); scanErr != nil {
			b.Fatal(scanErr)
		}
		if dst != 42 {
			b.Fatal("scan failed due to bad value")
		}
	}
}
// isExpectedEq returns a predicate that reports whether its argument equals a
// under Go's == for interface values, so both the dynamic type and the value
// must match.
func isExpectedEq(a any) func(any) bool {
	matches := func(candidate any) bool {
		return candidate == a
	}
	return matches
}