2
0
Files
pgx/pgtype/pgtype_test.go
T
Jack Christensen 456a242f5c Unregistered OIDs are handled the same as unknown OIDs
This improves handling of unregistered types. In general, they should
"just work". But there are performance benefits gained and some edge
cases avoided by registering types. Updated documentation to mention
this.

https://github.com/jackc/pgx/issues/1445
2022-12-23 13:14:56 -06:00

472 lines
12 KiB
Go

package pgtype_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"os"
"regexp"
"strconv"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var defaultConnTestRunner pgxtest.ConnTestRunner
func init() {
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
return config
}
}
// Test for renamed types
type _string string
type _bool bool
type _int8 int8
type _int16 int16
type _int16Slice []int16
type _int32Slice []int32
type _int64Slice []int64
type _float32Slice []float32
type _float64Slice []float64
type _byteSlice []byte
// unregisteredOID represents a actual type that is not registered. Cannot use 0 because that represents that the type
// is not known (e.g. when using the simple protocol).
const unregisteredOID = uint32(1)
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
t.Fatal(err)
}
return ipnet
}
func mustParseInet(t testing.TB, s string) *net.IPNet {
ip, ipnet, err := net.ParseCIDR(s)
if err == nil {
if ipv4 := ip.To4(); ipv4 != nil {
ipnet.IP = ipv4
} else {
ipnet.IP = ip
}
return ipnet
}
// May be bare IP address.
//
ip = net.ParseIP(s)
if ip == nil {
t.Fatal(errors.New("unable to parse inet address"))
}
ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
if ipv4 := ip.To4(); ipv4 != nil {
ipnet.IP = ipv4
ipnet.Mask = net.CIDRMask(32, 32)
}
return ipnet
}
func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
addr, err := net.ParseMAC(s)
if err != nil {
t.Fatal(err)
}
return addr
}
func skipCockroachDB(t testing.TB, msg string) {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
t.Fatal(err)
}
defer conn.Close(context.Background())
if conn.PgConn().ParameterStatus("crdb_version") != "" {
t.Skip(msg)
}
}
func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
t.Fatal(err)
}
defer conn.Close(context.Background())
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
// if not PostgreSQL do nothing
if serverVersionStr == "" {
return
}
serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
require.NoError(t, err)
if serverVersion < minVersion {
t.Skipf("Test requires PostgreSQL v%d+", minVersion)
}
}
// sqlScannerFunc lets an arbitrary function be used as a sql.Scanner.
type sqlScannerFunc func(src any) error
func (f sqlScannerFunc) Scan(src any) error {
return f(src)
}
// driverValuerFunc lets an arbitrary function be used as a driver.Valuer.
type driverValuerFunc func() (driver.Value, error)
func (f driverValuerFunc) Value() (driver.Value, error) {
return f()
}
func TestMapScanNilIsNoOp(t *testing.T) {
m := pgtype.NewMap()
err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil)
assert.NoError(t, err)
}
func TestMapScanTextFormatInterfacePtr(t *testing.T) {
m := pgtype.NewMap()
var got any
err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got)
require.NoError(t, err)
assert.Equal(t, "foo", got)
}
func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) {
m := pgtype.NewMap()
var got []byte
err := m.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got)
require.NoError(t, err)
assert.Equal(t, []byte("{}"), got)
}
func TestMapScanBinaryFormatInterfacePtr(t *testing.T) {
m := pgtype.NewMap()
var got any
err := m.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got)
require.NoError(t, err)
assert.Equal(t, "foo", got)
}
func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) {
unknownOID := uint32(999999)
srcBuf := []byte("foo")
m := pgtype.NewMap()
var s string
err := m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s)
assert.NoError(t, err)
assert.Equal(t, "foo", s)
var rs _string
err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs)
assert.NoError(t, err)
assert.Equal(t, "foo", string(rs))
var b []byte
err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b)
assert.NoError(t, err)
assert.Equal(t, []byte("foo"), b)
var rb _byteSlice
err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb)
assert.NoError(t, err)
assert.Equal(t, []byte("foo"), []byte(rb))
}
func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) {
m := pgtype.NewMap()
type myStruct struct{}
var p *myStruct
err := m.Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &p)
require.NotNil(t, err)
}
func TestMapScanUnknownOIDTextFormat(t *testing.T) {
m := pgtype.NewMap()
var n int32
err := m.Scan(0, pgx.TextFormatCode, []byte("123"), &n)
assert.NoError(t, err)
assert.EqualValues(t, 123, n)
}
func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) {
m := pgtype.NewMap()
var s sql.NullString
err := m.Scan(0, pgx.TextFormatCode, []byte(nil), &s)
assert.NoError(t, err)
assert.Equal(t, "", s.String)
assert.False(t, s.Valid)
}
type pgCustomInt int64
func (ci *pgCustomInt) Scan(src interface{}) error {
*ci = pgCustomInt(src.(int64))
return nil
}
func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
m := pgtype.NewMap()
src := []byte{0, 42}
var v pgCustomInt
plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
err := plan.Scan(src, &v)
require.NoError(t, err)
require.EqualValues(t, 42, v)
ptr := new(pgCustomInt)
plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
err = plan.Scan(src, &ptr)
require.NoError(t, err)
require.EqualValues(t, 42, *ptr)
ptr = new(pgCustomInt)
err = plan.Scan(nil, &ptr)
require.NoError(t, err)
assert.Nil(t, ptr)
ptr = nil
plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
err = plan.Scan(src, &ptr)
require.NoError(t, err)
require.EqualValues(t, 42, *ptr)
ptr = nil
plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
err = plan.Scan(nil, &ptr)
require.NoError(t, err)
assert.Nil(t, ptr)
}
// Test for https://github.com/jackc/pgtype/issues/164
func TestScanPlanInterface(t *testing.T) {
m := pgtype.NewMap()
src := []byte{0, 42}
var v interface{}
plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v)
err := plan.Scan(src, v)
assert.Error(t, err)
}
// https://github.com/jackc/pgx/issues/1263
func TestMapScanPtrToPtrToSlice(t *testing.T) {
m := pgtype.NewMap()
src := []byte("{foo,bar}")
var v *[]string
plan := m.PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &v)
err := plan.Scan(src, &v)
require.NoError(t, err)
require.Equal(t, []string{"foo", "bar"}, *v)
}
type databaseValuerString string
func (s databaseValuerString) Value() (driver.Value, error) {
return fmt.Sprintf("%d", len(s)), nil
}
// https://github.com/jackc/pgx/issues/1319
func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerString("foo")
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
require.NoError(t, err)
require.Equal(t, "3", string(buf))
}
type databaseValuerFmtStringer string
func (s databaseValuerFmtStringer) Value() (driver.Value, error) {
return nil, nil
}
func (s databaseValuerFmtStringer) String() string {
return "foobar"
}
// https://github.com/jackc/pgx/issues/1311
func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerFmtStringer("")
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
require.NoError(t, err)
require.Nil(t, buf)
}
type databaseValuerStringFormat struct {
n int32
}
func (v databaseValuerStringFormat) Value() (driver.Value, error) {
return fmt.Sprint(v.n), nil
}
func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) {
m := pgtype.NewMap()
src := databaseValuerStringFormat{n: 42}
buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil)
require.NoError(t, err)
require.Equal(t, []byte{0, 0, 0, 42}, buf)
}
// https://github.com/jackc/pgx/issues/1445
func TestMapEncodeDatabaseValuerThatReturnsStringIntoUnregisteredTypeTextFormat(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, driverValuerFunc(func() (driver.Value, error) { return "foo", nil }), nil)
require.NoError(t, err)
require.Equal(t, []byte("foo"), buf)
}
// https://github.com/jackc/pgx/issues/1445
func TestMapEncodeDatabaseValuerThatReturnsByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, driverValuerFunc(func() (driver.Value, error) { return []byte{0, 1, 2, 3}, nil }), nil)
require.NoError(t, err)
require.Equal(t, []byte(`\x00010203`), buf)
}
func TestMapEncodeStringIntoUnregisteredTypeTextFormat(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, "foo", nil)
require.NoError(t, err)
require.Equal(t, []byte("foo"), buf)
}
func TestMapEncodeByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, []byte{0, 1, 2, 3}, nil)
require.NoError(t, err)
require.Equal(t, []byte(`\x00010203`), buf)
}
// https://github.com/jackc/pgx/issues/1326
func TestMapScanPointerToRenamedType(t *testing.T) {
srcBuf := []byte("foo")
m := pgtype.NewMap()
var rs *_string
err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, srcBuf, &rs)
assert.NoError(t, err)
require.NotNil(t, rs)
assert.Equal(t, "foo", string(*rs))
}
// https://github.com/jackc/pgx/issues/1326
func TestMapScanNullToWrongType(t *testing.T) {
m := pgtype.NewMap()
var n *int32
err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &n)
assert.NoError(t, err)
assert.Nil(t, n)
var pn pgtype.Int4
err = m.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &pn)
assert.NoError(t, err)
assert.False(t, pn.Valid)
}
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
m := pgtype.NewMap()
src := []byte{0, 0, 0, 42}
var v pgtype.Int4
for i := 0; i < b.N; i++ {
v = pgtype.Int4{}
err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
if err != nil {
b.Fatal(err)
}
if v != (pgtype.Int4{Int32: 42, Valid: true}) {
b.Fatal("scan failed due to bad value")
}
}
}
func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) {
m := pgtype.NewMap()
src := []byte{0, 0, 0, 42}
var v int32
for i := 0; i < b.N; i++ {
v = 0
err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v)
if err != nil {
b.Fatal(err)
}
if v != 42 {
b.Fatal("scan failed due to bad value")
}
}
}
func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) {
m := pgtype.NewMap()
src := []byte{0, 0, 0, 42}
var v pgtype.Int4
plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v)
for i := 0; i < b.N; i++ {
v = pgtype.Int4{}
err := plan.Scan(src, &v)
if err != nil {
b.Fatal(err)
}
if v != (pgtype.Int4{Int32: 42, Valid: true}) {
b.Fatal("scan failed due to bad value")
}
}
}
func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
m := pgtype.NewMap()
src := []byte{0, 0, 0, 42}
var v int32
plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v)
for i := 0; i < b.N; i++ {
v = 0
err := plan.Scan(src, &v)
if err != nil {
b.Fatal(err)
}
if v != 42 {
b.Fatal("scan failed due to bad value")
}
}
}
func isExpectedEq(a any) func(any) bool {
return func(v any) bool {
return a == v
}
}