2
0

Conn.LoadType supports domain types

If the underlying type is registered then use the same Codec.

fixes https://github.com/jackc/pgx/issues/1373
This commit is contained in:
Jack Christensen
2022-11-12 08:10:46 -06:00
parent b265fedd75
commit 5b6fb75669
2 changed files with 15 additions and 10 deletions
+9 -1
View File
@@ -1147,8 +1147,9 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
} }
var typtype string var typtype string
var typbasetype uint32
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1173,6 +1174,13 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
} }
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
case "d": // domain
dt, ok := c.TypeMap().TypeForOID(typbasetype)
if !ok {
return nil, errors.New("domain base type OID not registered")
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
case "e": // enum case "e": // enum
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
default: default:
+6 -9
View File
@@ -837,24 +837,21 @@ func TestDomainType(t *testing.T) {
// uint64 but a result OID of the underlying numeric. // uint64 but a result OID of the underlying numeric.
var s string var s string
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s) err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "24", s) require.Equal(t, "24", s)
// Register type // Register type
var uint64OID uint32 uint64Type, err := conn.LoadType(ctx, "uint64")
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) require.NoError(t, err)
if err != nil { conn.TypeMap().RegisterType(uint64Type)
t.Fatalf("did not find uint64 OID, %v", err)
}
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
var n uint64 var n uint64
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n)
require.NoError(t, err) require.NoError(t, err)
// String is still an acceptable argument after registration // String is still an acceptable argument after registration
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }