Conn.LoadType supports domain types
If the underlying type is registered then use the same Codec. fixes https://github.com/jackc/pgx/issues/1373
This commit is contained in:
@@ -1147,8 +1147,9 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
var typtype string
|
var typtype string
|
||||||
|
var typbasetype uint32
|
||||||
|
|
||||||
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
|
err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1173,6 +1174,13 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
|
||||||
|
case "d": // domain
|
||||||
|
dt, ok := c.TypeMap().TypeForOID(typbasetype)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("domain base type OID not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
|
||||||
case "e": // enum
|
case "e": // enum
|
||||||
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
|
||||||
default:
|
default:
|
||||||
|
|||||||
+6
-9
@@ -837,24 +837,21 @@ func TestDomainType(t *testing.T) {
|
|||||||
// uint64 but a result OID of the underlying numeric.
|
// uint64 but a result OID of the underlying numeric.
|
||||||
|
|
||||||
var s string
|
var s string
|
||||||
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
|
err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "24", s)
|
require.Equal(t, "24", s)
|
||||||
|
|
||||||
// Register type
|
// Register type
|
||||||
var uint64OID uint32
|
uint64Type, err := conn.LoadType(ctx, "uint64")
|
||||||
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
|
require.NoError(t, err)
|
||||||
if err != nil {
|
conn.TypeMap().RegisterType(uint64Type)
|
||||||
t.Fatalf("did not find uint64 OID, %v", err)
|
|
||||||
}
|
|
||||||
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
|
|
||||||
|
|
||||||
var n uint64
|
var n uint64
|
||||||
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// String is still an acceptable argument after registration
|
// String is still an acceptable argument after registration
|
||||||
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
|
err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user