From 11223497b3e7b4531bcd7cb827ad71f36ed4efcb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Jan 2022 20:42:12 -0600 Subject: [PATCH] Restore record support --- pgtype/pgtype.go | 4 +- pgtype/record_codec.go | 116 ++++++++++++++++++++++++++++++++++++ pgtype/record_codec_test.go | 72 ++++++++++++++++++++++ 3 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 pgtype/record_codec.go create mode 100644 pgtype/record_codec_test.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 0c8f4763..ab317f6e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -86,6 +86,7 @@ const ( VarbitArrayOID = 1563 NumericOID = 1700 RecordOID = 2249 + RecordArrayOID = 2287 UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 @@ -211,7 +212,6 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) - // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) @@ -245,6 +245,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}}) ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}}) ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + ci.RegisterDataType(DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) @@ -285,6 +286,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}}) ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}}) ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}}) + ci.RegisterDataType(DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}}) ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}}) ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}}) ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}}) diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go new file mode 100644 index 00000000..31001b1f --- /dev/null +++ b/pgtype/record_codec.go @@ -0,0 +1,116 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. + +// RecordCodec is a codec for the generic PostgreSQL record type such as is created with the "row" function. Record can +// only decode the binary format. The text format output format from PostgreSQL does not include type information and +// is therefore impossible to decode. Encoding is impossible because PostgreSQL does not support input of generic +// records. +type RecordCodec struct{} + +func (RecordCodec) FormatSupported(format int16) bool { + return format == BinaryFormatCode +} + +func (RecordCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (RecordCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + return nil +} + +func (RecordCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + if format == BinaryFormatCode { + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryRecordToCompositeIndexScanner{ci: ci} + } + } + + return nil +} + +type scanPlanBinaryRecordToCompositeIndexScanner struct { + ci *ConnInfo +} + +func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target interface{}) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.ci, src) + for i := 0; scanner.Next(); i++ { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.ci.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + return nil, fmt.Errorf("not implemented") +} + +func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(ci, src) + values := make([]interface{}, scanner.FieldCount()) + for i := 0; scanner.Next(); i++ { + var v interface{} + fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[i] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } + +} diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go new file mode 100644 index 00000000..14018e9e --- /dev/null +++ b/pgtype/record_codec_test.go @@ -0,0 +1,72 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestRecordCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var a string + var b int32 + err := conn.QueryRow(context.Background(), `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) + require.NoError(t, err) + + require.Equal(t, "foo", a) + require.Equal(t, int32(42), b) +} + +func TestRecordCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select row()`, + expected: []interface{}{}, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: []interface{}{"foo", int32(42)}, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: []interface{}{float32(100), float32(1.09)}, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)}, + }, + { + sql: `select row(null)`, + expected: []interface{}{nil}, + }, + { + sql: `select null::record`, + expected: nil, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } +}