From 02372f1c3c30d54d271ac0e32d16a866f22ac620 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 15:12:09 -0600 Subject: [PATCH] Add DecodeValue to composites --- pgtype/composite.go | 69 ++++++++++++++++++++++++++++++++++------ pgtype/composite_test.go | 39 +++++++++++++++++++++++ pgtype/record_codec.go | 11 ++++++- 3 files changed, 109 insertions(+), 10 deletions(-) diff --git a/pgtype/composite.go b/pgtype/composite.go index d21ab665..2ccc7b1d 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -209,11 +209,16 @@ func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format return nil, nil } - // var n int64 - // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - // return n, err - - return nil, fmt.Errorf("not implemented") + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } } func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { @@ -221,11 +226,57 @@ func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src return nil, nil } - // var n int16 - // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - // return n, err + switch format { + case TextFormatCode: + scanner := NewCompositeTextScanner(ci, src) + values := make(map[string]interface{}, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v interface{} + fieldPlan := ci.PlanScan(c.Fields[i].DataType.OID, TextFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].DataType.OID, v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(ci, src) + values := make(map[string]interface{}, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v interface{} + fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } - return nil, fmt.Errorf("not implemented") } type CompositeBinaryScanner struct { diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index 9a0eff2a..66db4281 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -162,3 +162,42 @@ create type point3d as ( require.Equalf(t, input, output, "%v", format.name) } } + +func TestCompositeCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type point3d") + + dt, err := conn.LoadDataType(context.Background(), "point3d") + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(*dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + rows, err := conn.Query(context.Background(), "select '(1,2,3)'::point3d", pgx.QueryResultFormats{format.code}) + require.NoErrorf(t, err, "%v", format.name) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoErrorf(t, err, "%v", format.name) + require.Lenf(t, values, 1, "%v", format.name) + require.Equalf(t, map[string]interface{}{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) + require.False(t, rows.Next()) + require.NoErrorf(t, rows.Err(), "%v", format.name) + } +} diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go index 31001b1f..92c197b2 100644 --- a/pgtype/record_codec.go +++ b/pgtype/record_codec.go @@ -75,7 +75,16 @@ func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 return nil, nil } - return nil, fmt.Errorf("not implemented") + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } } func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {