From 218663463828a6358d2a3004c00180d9d986a511 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 11:55:24 -0500 Subject: [PATCH] Add CompositeFields encoders --- composite_fields.go | 112 +++++++++++++++++++++++++++++ composite_fields_test.go | 147 +++++++++++++++++++++++++++++++++++++++ composite_type.go | 14 ++++ 3 files changed, 273 insertions(+) diff --git a/composite_fields.go b/composite_fields.go index 64a17b55..751adce8 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -1,11 +1,17 @@ package pgtype import ( + "encoding/binary" + + "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) // CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a // nullable value use a *CompositeFields. It will be set to nil in case of null. +// +// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not +// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType. type CompositeFields []interface{} func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { @@ -74,3 +80,109 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { return nil } + +// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using +// CompositeFields to encode directly. +func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + buf = append(buf, '(') + + fieldBuf := make([]byte, 0, 32) + + for _, f := range cf { + if f != nil { + fieldBuf = fieldBuf[0:0] + if textEncoder, ok := f.(TextEncoder); ok { + var err error + fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + } else { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("Unknown data type for %#v", f) + } + + err := dt.Value.Set(f) + if err != nil { + return nil, err + } + + if textEncoder, ok := dt.Value.(TextEncoder); ok { + var err error + fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + } else { + return nil, errors.Errorf("Cannot encode text format for %v", f) + } + } + } + buf = append(buf, ',') + } + + buf[len(buf)-1] = ')' + return buf, nil +} + +// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is +// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary +// composite format requires the OID of each field to be specified the only types that will work are those known to +// ConnInfo. +// +// In particular: +// +// * Nil cannot be used because there is no way to determine what type it. +// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. +// * No dereferencing will be done. e.g. *Text must be used instead of Text. +func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + buf = pgio.AppendUint32(buf, uint32(len(cf))) + + for _, f := range cf { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("Unknown OID for %#v", f) + } + + buf = pgio.AppendUint32(buf, dt.OID) + lengthPos := len(buf) + buf = pgio.AppendInt32(buf, -1) + + if binaryEncoder, ok := f.(BinaryEncoder); ok { + fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) + buf = fieldBuf + } + } else { + err := dt.Value.Set(f) + if err != nil { + return nil, err + } + if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { + fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) + buf = fieldBuf + } + } else { + return nil, errors.Errorf("Cannot encode binary format for %v", f) + } + } + } + + return buf, nil +} diff --git a/composite_fields_test.go b/composite_fields_test.go index d53e48ec..dc4d4c29 100644 --- a/composite_fields_test.go +++ b/composite_fields_test.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgtype/testutil" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCompositeFieldsDecode(t *testing.T) { @@ -123,4 +124,150 @@ func TestCompositeFieldsDecode(t *testing.T) { assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) } } + + // Skip nil fields + { + var a int32 + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, nil, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } +} + +func TestCompositeFieldsEncode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; + +create type cf_encode as ( + a text, + b int4, + c text, + d float8, + e text +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type cf_encode") + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + // Assorted values + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) + } + } + } + + // untyped nil + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + simpleProtocol := true + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + + // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema + // of the composite type. + simpleProtocol = false + err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + } + } + + // quotes and special characters + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow( + context.Background(), + `select $1::cf_encode`, + pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) + } + } + } } diff --git a/composite_type.go b/composite_type.go index 03d88aea..b4b1ab28 100644 --- a/composite_type.go +++ b/composite_type.go @@ -2,6 +2,7 @@ package pgtype import ( "encoding/binary" + "strings" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -366,3 +367,16 @@ func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { func RecordAddNull(buf []byte, oid uint32) []byte { return pgio.AppendInt32(buf, int32(-1)) } + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func QuoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +}