diff --git a/pgtype/composite_bench_test.go b/pgtype/composite_bench_test.go deleted file mode 100644 index ef57709b..00000000 --- a/pgtype/composite_bench_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgio" - "github.com/jackc/pgx/v5/pgtype" - "github.com/stretchr/testify/require" -) - -type MyCompositeRaw struct { - A int32 - B *string -} - -func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - buf = pgio.AppendUint32(buf, 2) - - buf = pgio.AppendUint32(buf, pgtype.Int4OID) - buf = pgio.AppendInt32(buf, 4) - buf = pgio.AppendInt32(buf, src.A) - - buf = pgio.AppendUint32(buf, pgtype.TextOID) - if src.B != nil { - buf = pgio.AppendInt32(buf, int32(len(*src.B))) - buf = append(buf, (*src.B)...) - } else { - buf = pgio.AppendInt32(buf, -1) - } - - return buf, nil -} - -func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - a := pgtype.Int4{} - b := pgtype.Text{} - - scanner := pgtype.NewCompositeBinaryScanner(ci, src) - scanner.ScanDecoder(&a) - scanner.ScanDecoder(&b) - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.A = a.Int - if b.Valid { - dst.B = &b.String - } else { - dst.B = nil - } - - return nil -} - -var x []byte - -func BenchmarkBinaryEncodingManual(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyCompositeRaw{4, ptrS("ABCDEFG")} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = v.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingHelper(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyType{4, ptrS("ABCDEFG")} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - buf, _ = v.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingComposite(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - f1 := 2 - f2 := ptrS("bar") - c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, ci) - require.NoError(b, err) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - c.Set([]interface{}{f1, f2}) - buf, _ = c.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -func BenchmarkBinaryEncodingJSON(b *testing.B) { - buf := make([]byte, 0, 128) - ci := pgtype.NewConnInfo() - v := MyCompositeRaw{4, ptrS("ABCDEFG")} - j := pgtype.JSON{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - j.Set(v) - buf, _ = j.EncodeBinary(ci, buf[:0]) - } - x = buf -} - -var dstRaw MyCompositeRaw - -func BenchmarkBinaryDecodingManual(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - dst := MyCompositeRaw{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := dst.DecodeBinary(ci, buf) - E(err) - } - dstRaw = dst -} - -var dstMyType MyType - -func BenchmarkBinaryDecodingHelpers(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - dst := MyType{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := dst.DecodeBinary(ci, buf) - E(err) - } - dstMyType = dst -} - -var gf1 int -var gf2 *string - -func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { - ci := pgtype.NewConnInfo() - buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - var f1 int - var f2 *string - - c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, ci) - require.NoError(b, err) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := c.DecodeBinary(ci, buf) - if err != nil { - b.Fatal(err) - } - err = c.AssignTo([]interface{}{&f1, &f2}) - if err != nil { - b.Fatal(err) - } - } - gf1 = f1 - gf2 = f2 -} - -func BenchmarkBinaryDecodingJSON(b *testing.B) { - ci := pgtype.NewConnInfo() - j := pgtype.JSON{} - j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) - buf, _ := j.EncodeBinary(ci, nil) - - j = pgtype.JSON{} - dst := MyCompositeRaw{} - - b.ResetTimer() - for n := 0; n < b.N; n++ { - err := j.DecodeBinary(ci, buf) - E(err) - err = j.AssignTo(&dst) - E(err) - } - dstRaw = dst -} diff --git a/pgtype/composite_fields.go b/pgtype/composite_fields.go deleted file mode 100644 index e7ca89c7..00000000 --- a/pgtype/composite_fields.go +++ /dev/null @@ -1,107 +0,0 @@ -package pgtype - -import "fmt" - -// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a -// nullable value use a *CompositeFields. It will be set to nil in case of null. -// -// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not -// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType. -type CompositeFields []interface{} - -func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { - if len(cf) == 0 { - return fmt.Errorf("cannot decode into empty CompositeFields") - } - - if src == nil { - return fmt.Errorf("cannot decode unexpected null into CompositeFields") - } - - scanner := NewCompositeBinaryScanner(ci, src) - - for _, f := range cf { - scanner.ScanValue(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - return nil -} - -func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { - if len(cf) == 0 { - return fmt.Errorf("cannot decode into empty CompositeFields") - } - - if src == nil { - return fmt.Errorf("cannot decode unexpected null into CompositeFields") - } - - scanner := NewCompositeTextScanner(ci, src) - - for _, f := range cf { - scanner.ScanValue(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - return nil -} - -// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using -// CompositeFields to encode directly. -func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - b := NewCompositeTextBuilder(ci, buf) - - for _, f := range cf { - if paramEncoder, ok := f.(ParamEncoder); ok { - b.AppendEncoder(paramEncoder) - } else { - b.AppendValue(f) - } - } - - return b.Finish() -} - -// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is -// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary -// composite format requires the OID of each field to be specified the only types that will work are those known to -// ConnInfo. -// -// In particular: -// -// * Nil cannot be used because there is no way to determine what type it. -// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. -// * No dereferencing will be done. e.g. *Text must be used instead of Text. -func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - b := NewCompositeBinaryBuilder(ci, buf) - - for _, f := range cf { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, fmt.Errorf("Unknown OID for %#v", f) - } - - if paramEncoder, ok := f.(ParamEncoder); ok { - b.AppendEncoder(dt.OID, paramEncoder) - } else { - err := dt.Value.Set(f) - if err != nil { - return nil, err - } - if paramEncoder, ok := dt.Value.(ParamEncoder); ok { - b.AppendEncoder(dt.OID, paramEncoder) - } else { - return nil, fmt.Errorf("Cannot encode binary format for %v", f) - } - } - } - - return b.Finish() -} diff --git a/pgtype/composite_fields_test.go b/pgtype/composite_fields_test.go deleted file mode 100644 index e73d8441..00000000 --- a/pgtype/composite_fields_test.go +++ /dev/null @@ -1,273 +0,0 @@ -package pgtype_test - -import ( - "context" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompositeFieldsDecode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} - - // Assorted values - { - var a int32 - var b string - var c float64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, 1, a, "Format: %v", format) - assert.EqualValuesf(t, "hi", b, "Format: %v", format) - assert.EqualValuesf(t, 2.1, c, "Format: %v", format) - } - } - - // nulls, string "null", and empty string fields - { - var a pgtype.Text - var b string - var c pgtype.Text - var d string - var e pgtype.Text - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.Nilf(t, a.Get(), "Format: %v", format) - assert.EqualValuesf(t, "null", b, "Format: %v", format) - assert.Nilf(t, c.Get(), "Format: %v", format) - assert.EqualValuesf(t, "", d, "Format: %v", format) - assert.Nilf(t, e.Get(), "Format: %v", format) - } - } - - // null record - { - var a pgtype.Text - var b string - cf := pgtype.CompositeFields{&a, &b} - - for _, format := range formats { - // Cannot scan nil into - err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( - cf, - ) - if assert.Errorf(t, err, "Format: %v", format) { - continue - } - assert.NotNilf(t, cf, "Format: %v", format) - - // But can scan nil into *pgtype.CompositeFields - err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( - &cf, - ) - if assert.Errorf(t, err, "Format: %v", format) { - continue - } - assert.Nilf(t, cf, "Format: %v", format) - } - } - - // quotes and special characters - { - var a, b, c, d string - - for _, format := range formats { - err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b, &c, &d}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.Equalf(t, `"`, a, "Format: %v", format) - assert.Equalf(t, `foo bar`, b, "Format: %v", format) - assert.Equalf(t, `foo'bar`, c, "Format: %v", format) - assert.Equalf(t, `baz)bar`, d, "Format: %v", format) - } - } - - // arrays - { - var a []string - var b []int64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, &b}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) - assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) - } - } - - // Skip nil fields - { - var a int32 - var c float64 - - for _, format := range formats { - err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( - pgtype.CompositeFields{&a, nil, &c}, - ) - if !assert.NoErrorf(t, err, "Format: %v", format) { - continue - } - - assert.EqualValuesf(t, 1, a, "Format: %v", format) - assert.EqualValuesf(t, 2.1, c, "Format: %v", format) - } - } -} - -func TestCompositeFieldsEncode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; - -create type cf_encode as ( - a text, - b int4, - c text, - d float8, - e text -);`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type cf_encode") - - // Use simple protocol to force text or binary encoding - simpleProtocols := []bool{true, false} - - // Assorted values - { - var a string - var b int32 - var c string - var d float64 - var e string - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) - } - } - } - - // untyped nil - { - var a pgtype.Text - var b int32 - var c string - var d pgtype.Float8 - var e pgtype.Text - - simpleProtocol := true - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) - } - - // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema - // of the composite type. - simpleProtocol = false - err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) - } - - // nulls, string "null", and empty string fields - { - var a pgtype.Text - var b int32 - var c string - var d pgtype.Float8 - var e pgtype.Text - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{&pgtype.Text{}, int32(1), "null", &pgtype.Float8{}, &pgtype.Text{}}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) - assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) - } - } - } - - // quotes and special characters - { - var a string - var b int32 - var c string - var d float64 - var e string - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow( - context.Background(), - `select $1::cf_encode`, - pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, - ).Scan( - pgtype.CompositeFields{&a, &b, &c, &d, &e}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) - assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) - } - } - } -} diff --git a/pgtype/composite_type.go b/pgtype/composite_type.go deleted file mode 100644 index 85ab5910..00000000 --- a/pgtype/composite_type.go +++ /dev/null @@ -1,715 +0,0 @@ -package pgtype - -import ( - "encoding/binary" - "errors" - "fmt" - "reflect" - "strings" - - "github.com/jackc/pgio" -) - -type CompositeTypeField struct { - Name string - OID uint32 -} - -type CompositeType struct { - valid bool - - typeName string - - fields []CompositeTypeField - valueTranscoders []ValueTranscoder -} - -// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used -// for fields. All field OIDs must be previously registered in ci. -func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) { - valueTranscoders := make([]ValueTranscoder, len(fields)) - - for i := range fields { - dt, ok := ci.DataTypeForOID(fields[i].OID) - if !ok { - return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID) - } - - value := NewValue(dt.Value) - valueTranscoder, ok := value.(ValueTranscoder) - if !ok { - return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) - } - - valueTranscoders[i] = valueTranscoder - } - - return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil -} - -// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length. -// Prefer NewCompositeType unless overriding the transcoding of fields is required. -func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) { - if len(fields) != len(values) { - return nil, errors.New("fields and valueTranscoders must have same length") - } - - return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil -} - -func (src CompositeType) Get() interface{} { - if !src.valid { - return nil - } - - results := make(map[string]interface{}, len(src.valueTranscoders)) - for i := range src.valueTranscoders { - results[src.fields[i].Name] = src.valueTranscoders[i].Get() - } - return results -} - -func (ct *CompositeType) NewTypeValue() Value { - a := &CompositeType{ - typeName: ct.typeName, - fields: ct.fields, - valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)), - } - - for i := range ct.valueTranscoders { - a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder) - } - - return a -} - -func (ct *CompositeType) TypeName() string { - return ct.typeName -} - -func (ct *CompositeType) Fields() []CompositeTypeField { - return ct.fields -} - -func (dst *CompositeType) setNil() { - dst.valid = false -} - -func (dst *CompositeType) Set(src interface{}) error { - if src == nil { - dst.setNil() - return nil - } - - switch value := src.(type) { - case []interface{}: - if len(value) != len(dst.valueTranscoders) { - return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) - } - for i, v := range value { - if err := dst.valueTranscoders[i].Set(v); err != nil { - return err - } - } - dst.valid = true - case *[]interface{}: - if value == nil { - dst.setNil() - return nil - } - return dst.Set(*value) - default: - return fmt.Errorf("Can not convert %v to Composite", src) - } - - return nil -} - -// AssignTo should never be called on composite value directly -func (src CompositeType) AssignTo(dst interface{}) error { - if !src.valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case []interface{}: - if len(v) != len(src.valueTranscoders) { - return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) - } - for i := range src.valueTranscoders { - if v[i] == nil { - continue - } - - err := assignToOrSet(src.valueTranscoders[i], v[i]) - if err != nil { - return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) - } - } - return nil - case *[]interface{}: - return src.AssignTo(*v) - default: - if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { - return err - } - - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func assignToOrSet(src Value, dst interface{}) error { - assignToErr := src.AssignTo(dst) - if assignToErr != nil { - // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. - setSucceeded := false - if setter, ok := dst.(Value); ok { - err := setter.Set(src.Get()) - setSucceeded = err == nil - } - if !setSucceeded { - return assignToErr - } - } - - return nil -} - -func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return false, nil - } - - if dstValue.IsNil() { - return false, nil - } - - dstElemValue := dstValue.Elem() - dstElemType := dstElemValue.Type() - - if dstElemType.Kind() != reflect.Struct { - return false, nil - } - - exportedFields := make([]int, 0, dstElemType.NumField()) - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) - if sf.PkgPath == "" { - exportedFields = append(exportedFields, i) - } - } - - if len(exportedFields) != len(src.valueTranscoders) { - return false, nil - } - - for i := range exportedFields { - err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) - if err != nil { - return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) - } - } - - return true, nil -} - -func (ct *CompositeType) BinaryFormatSupported() bool { - for _, vt := range ct.valueTranscoders { - if !vt.BinaryFormatSupported() { - return false - } - } - return true -} - -func (ct *CompositeType) TextFormatSupported() bool { - for _, vt := range ct.valueTranscoders { - if !vt.TextFormatSupported() { - return false - } - } - return true -} - -func (ct *CompositeType) PreferredFormat() int16 { - if ct.BinaryFormatSupported() { - return BinaryFormatCode - } - return TextFormatCode -} - -func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - if src == nil { - dst.setNil() - return nil - } - - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} - -func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - if !src.valid { - return nil, nil - } - - b := NewCompositeBinaryBuilder(ci, buf) - for i := range src.valueTranscoders { - b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i]) - } - - return b.Finish() -} - -// DecodeBinary implements BinaryDecoder interface. -// Opposite to Record, fields in a composite act as a "schema" -// and decoding fails if SQL value can't be assigned due to -// type mismatch -func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { - scanner := NewCompositeBinaryScanner(ci, buf) - - for _, f := range dst.valueTranscoders { - scanner.ScanDecoder(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.valid = true - - return nil -} - -func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { - scanner := NewCompositeTextScanner(ci, buf) - - for _, f := range dst.valueTranscoders { - scanner.ScanDecoder(f) - } - - if scanner.Err() != nil { - return scanner.Err() - } - - dst.valid = true - - return nil -} - -func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - if !src.valid { - return nil, nil - } - - b := NewCompositeTextBuilder(ci, buf) - for _, f := range src.valueTranscoders { - b.AppendEncoder(f) - } - - return b.Finish() -} - -type CompositeBinaryScanner struct { - ci *ConnInfo - rp int - src []byte - - fieldCount int32 - fieldBytes []byte - fieldOID uint32 - err error -} - -// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. -func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { - rp := 0 - if len(src[rp:]) < 4 { - return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} - } - - fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - - return &CompositeBinaryScanner{ - ci: ci, - rp: rp, - src: src, - fieldCount: fieldCount, - } -} - -// ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// ScanDecoder calls Next and scans the result into d. -func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Next returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeBinaryScanner) Next() bool { - if cfs.err != nil { - return false - } - - if cfs.rp == len(cfs.src) { - return false - } - - if len(cfs.src[cfs.rp:]) < 8 { - cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) - return false - } - cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) - cfs.rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) - cfs.rp += 4 - - if fieldLen >= 0 { - if len(cfs.src[cfs.rp:]) < fieldLen { - cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) - return false - } - cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] - cfs.rp += fieldLen - } else { - cfs.fieldBytes = nil - } - - return true -} - -func (cfs *CompositeBinaryScanner) FieldCount() int { - return int(cfs.fieldCount) -} - -// Bytes returns the bytes of the field most recently read by Scan(). -func (cfs *CompositeBinaryScanner) Bytes() []byte { - return cfs.fieldBytes -} - -// OID returns the OID of the field most recently read by Scan(). -func (cfs *CompositeBinaryScanner) OID() uint32 { - return cfs.fieldOID -} - -// Err returns any error encountered by the scanner. -func (cfs *CompositeBinaryScanner) Err() error { - return cfs.err -} - -type CompositeTextScanner struct { - ci *ConnInfo - rp int - src []byte - - fieldBytes []byte - err error -} - -// NewCompositeTextScanner a scanner over a text encoded composite value. -func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { - if len(src) < 2 { - return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} - } - - if src[0] != '(' { - return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} - } - - if src[len(src)-1] != ')' { - return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} - } - - return &CompositeTextScanner{ - ci: ci, - rp: 1, - src: src, - } -} - -// ScanDecoder calls Next and decodes the result with d. -func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// ScanDecoder calls Next and scans the result into d. -func (cfs *CompositeTextScanner) ScanValue(d interface{}) { - if cfs.err != nil { - return - } - - if cfs.Next() { - cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d) - } else { - cfs.err = errors.New("read past end of composite") - } -} - -// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After -// Next returns false, the Err method can be called to check if any errors occurred. -func (cfs *CompositeTextScanner) Next() bool { - if cfs.err != nil { - return false - } - - if cfs.rp == len(cfs.src) { - return false - } - - switch cfs.src[cfs.rp] { - case ',', ')': // null - cfs.rp++ - cfs.fieldBytes = nil - return true - case '"': // quoted value - cfs.rp++ - cfs.fieldBytes = make([]byte, 0, 16) - for { - ch := cfs.src[cfs.rp] - - if ch == '"' { - cfs.rp++ - if cfs.src[cfs.rp] == '"' { - cfs.fieldBytes = append(cfs.fieldBytes, '"') - cfs.rp++ - } else { - break - } - } else if ch == '\\' { - cfs.rp++ - cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) - cfs.rp++ - } else { - cfs.fieldBytes = append(cfs.fieldBytes, ch) - cfs.rp++ - } - } - cfs.rp++ - return true - default: // unquoted value - start := cfs.rp - for { - ch := cfs.src[cfs.rp] - if ch == ',' || ch == ')' { - break - } - cfs.rp++ - } - cfs.fieldBytes = cfs.src[start:cfs.rp] - cfs.rp++ - return true - } -} - -// Bytes returns the bytes of the field most recently read by Scan(). -func (cfs *CompositeTextScanner) Bytes() []byte { - return cfs.fieldBytes -} - -// Err returns any error encountered by the scanner. -func (cfs *CompositeTextScanner) Err() error { - return cfs.err -} - -type CompositeBinaryBuilder struct { - ci *ConnInfo - buf []byte - startIdx int - fieldCount uint32 - err error -} - -func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { - startIdx := len(buf) - buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields - return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} -} - -func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { - if b.err != nil { - return - } - - dt, ok := b.ci.DataTypeForOID(oid) - if !ok { - b.err = fmt.Errorf("unknown data type for OID: %d", oid) - return - } - - err := dt.Value.Set(field) - if err != nil { - b.err = err - return - } - - paramEncoder, ok := dt.Value.(ParamEncoder) - if !ok { - b.err = fmt.Errorf("unable to encode for OID: %d", oid) - return - } - - b.AppendEncoder(oid, paramEncoder) -} - -func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) { - if b.err != nil { - return - } - - b.buf = pgio.AppendUint32(b.buf, oid) - lengthPos := len(b.buf) - b.buf = pgio.AppendInt32(b.buf, -1) - fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf) - if err != nil { - b.err = err - return - } - if fieldBuf != nil { - binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) - b.buf = fieldBuf - } - - b.fieldCount++ -} - -func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { - if b.err != nil { - return nil, b.err - } - - binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) - return b.buf, nil -} - -type CompositeTextBuilder struct { - ci *ConnInfo - buf []byte - startIdx int - fieldCount uint32 - err error - fieldBuf [32]byte -} - -func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { - buf = append(buf, '(') // allocate room for number of fields - return &CompositeTextBuilder{ci: ci, buf: buf} -} - -func (b *CompositeTextBuilder) AppendValue(field interface{}) { - if b.err != nil { - return - } - - if field == nil { - b.buf = append(b.buf, ',') - return - } - - dt, ok := b.ci.DataTypeForValue(field) - if !ok { - b.err = fmt.Errorf("unknown data type for field: %v", field) - return - } - - err := dt.Value.Set(field) - if err != nil { - b.err = err - return - } - - paramEncoder, ok := dt.Value.(ParamEncoder) - if !ok { - b.err = fmt.Errorf("unable to encode for value: %v", field) - return - } - - b.AppendEncoder(paramEncoder) -} - -func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) { - if b.err != nil { - return - } - - fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0]) - if err != nil { - b.err = err - return - } - if fieldBuf != nil { - b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) - } - - b.buf = append(b.buf, ',') -} - -func (b *CompositeTextBuilder) Finish() ([]byte, error) { - if b.err != nil { - return nil, b.err - } - - b.buf[len(b.buf)-1] = ')' - return b.buf, nil -} - -var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) - -func quoteCompositeField(src string) string { - return `"` + quoteCompositeReplacer.Replace(src) + `"` -} - -func quoteCompositeFieldIfNeeded(src string) string { - if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { - return quoteCompositeField(src) - } - return src -} diff --git a/pgtype/composite_type_test.go b/pgtype/composite_type_test.go deleted file mode 100644 index a41ad0f4..00000000 --- a/pgtype/composite_type_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package pgtype_test - -import ( - "context" - "fmt" - "os" - "testing" - - pgx "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCompositeTypeSetAndGet(t *testing.T) { - ci := pgtype.NewConnInfo() - ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, ci) - require.NoError(t, err) - assert.Equal(t, nil, ct.Get()) - - nilTests := []struct { - src interface{} - }{ - {nil}, // nil interface - {(*[]interface{})(nil)}, // typed nil - } - - for i, tt := range nilTests { - err := ct.Set(tt.src) - assert.NoErrorf(t, err, "%d", i) - assert.Equal(t, nil, ct.Get()) - } - - compatibleValuesTests := []struct { - src []interface{} - expected map[string]interface{} - }{ - { - src: []interface{}{"foo", int32(42)}, - expected: map[string]interface{}{"a": "foo", "b": int32(42)}, - }, - { - src: []interface{}{nil, nil}, - expected: map[string]interface{}{"a": nil, "b": nil}, - }, - { - src: []interface{}{&pgtype.Text{String: "hi", Valid: true}, &pgtype.Int4{Int: 7, Valid: true}}, - expected: map[string]interface{}{"a": "hi", "b": int32(7)}, - }, - } - - for i, tt := range compatibleValuesTests { - err := ct.Set(tt.src) - assert.NoErrorf(t, err, "%d", i) - assert.EqualValues(t, tt.expected, ct.Get()) - } -} - -func TestCompositeTypeAssignTo(t *testing.T) { - ci := pgtype.NewConnInfo() - ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, ci) - require.NoError(t, err) - - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a string - var b int32 - - err = ct.AssignTo([]interface{}{&a, &b}) - assert.NoError(t, err) - - assert.Equal(t, "foo", a) - assert.Equal(t, int32(42), b) - } - - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - - err = ct.AssignTo([]interface{}{&a, &b}) - assert.NoError(t, err) - - assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) - } - - // Allow nil destination component as no-op - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var b int32 - - err = ct.AssignTo([]interface{}{nil, &b}) - assert.NoError(t, err) - - assert.Equal(t, int32(42), b) - } - - // *[]interface{} dest when null - { - err := ct.Set(nil) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - dst := []interface{}{&a, &b} - - err = ct.AssignTo(&dst) - assert.NoError(t, err) - - assert.Nil(t, dst) - } - - // *[]interface{} dest when not null - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - var a pgtype.Text - var b pgtype.Int4 - dst := []interface{}{&a, &b} - - err = ct.AssignTo(&dst) - assert.NoError(t, err) - - assert.NotNil(t, dst) - assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) - assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) - } - - // Struct fields positionally via reflection - { - err := ct.Set([]interface{}{"foo", int32(42)}) - assert.NoError(t, err) - - s := struct { - A string - B int32 - }{} - - err = ct.AssignTo(&s) - if assert.NoError(t, err) { - assert.Equal(t, "foo", s.A) - assert.Equal(t, int32(42), s.B) - } - } -} - -func TestCompositeTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists ct_test; - -create type ct_test as ( - a text, - b int4 -);`) - require.NoError(t, err) - defer conn.Exec(context.Background(), "drop type ct_test") - - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) - require.NoError(t, err) - - defer conn.Exec(context.Background(), "drop type ct_test") - - ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ - {"a", pgtype.TextOID}, - {"b", pgtype.Int4OID}, - }, conn.ConnInfo()) - require.NoError(t, err) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) - - // Use simple protocol to force text or binary encoding - simpleProtocols := []bool{true, false} - - var a string - var b int32 - - for _, simpleProtocol := range simpleProtocols { - err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), - pgtype.CompositeFields{"hi", int32(42)}, - ).Scan( - []interface{}{&a, &b}, - ) - if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { - assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) - assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) - } - } -} - -// https://github.com/jackc/pgx/issues/874 -func TestCompositeTypeTextDecodeNested(t *testing.T) { - newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { - fields := make([]pgtype.CompositeTypeField, len(fieldNames)) - for i, name := range fieldNames { - fields[i] = pgtype.CompositeTypeField{Name: name} - } - - rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) - require.NoError(t, err) - return rowType - } - - dimensionsType := func() pgtype.ValueTranscoder { - return newCompositeType( - "dimensions", - []string{"width", "height"}, - &pgtype.Int4{}, - &pgtype.Int4{}, - ) - } - productImageType := func() pgtype.ValueTranscoder { - return newCompositeType( - "product_image_type", - []string{"source", "dimensions"}, - &pgtype.Text{}, - dimensionsType(), - ) - } - productImageSetType := newCompositeType( - "product_image_set_type", - []string{"name", "orig_image", "images"}, - &pgtype.Text{}, - productImageType(), - pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { - return productImageType() - }), - ) - - err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) - require.NoError(t, err) -} - -func Example_composite() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - fmt.Println(err) - return - } - - defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) - if err != nil { - fmt.Println(err) - return - } - - _, err = conn.Exec(context.Background(), `create type mytype as ( - a int4, - b text -);`) - if err != nil { - fmt.Println(err) - return - } - defer conn.Exec(context.Background(), "drop type mytype") - - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) - if err != nil { - fmt.Println(err) - return - } - - ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ - {"a", pgtype.Int4OID}, - {"b", pgtype.TextOID}, - }, conn.ConnInfo()) - if err != nil { - fmt.Println(err) - return - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) - - var a int - var b *string - - err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("First: a=%d b=%s\n", a, *b) - - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("Second: a=%d b=%v\n", a, b) - - scanTarget := []interface{}{&a, &b} - err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) - E(err) - - fmt.Printf("Third: isNull=%v\n", scanTarget == nil) - - // Output: - // First: a=2 b=bar - // Second: a=1 b= - // Third: isNull=true -} diff --git a/pgtype/custom_composite_test.go b/pgtype/custom_composite_test.go deleted file mode 100644 index e5f2166e..00000000 --- a/pgtype/custom_composite_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package pgtype_test - -import ( - "context" - "errors" - "fmt" - "os" - - pgx "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" -) - -type MyType struct { - a int32 // NULL will cause decoding error - b *string // there can be NULL in this position in SQL -} - -func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") - } - - if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { - return err - } - - return nil -} - -func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.a, true} - var b pgtype.Text - if src.b != nil { - b = pgtype.Text{*src.b, true} - } else { - b = pgtype.Text{} - } - - return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) -} - -func ptrS(s string) *string { - return &s -} - -func E(err error) { - if err != nil { - panic(err) - } -} - -// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL -// composites can be added. -func Example_customCompositeTypes() { - conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - E(err) - - defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype; - -create type mytype as ( - a int4, - b text -);`) - E(err) - defer conn.Exec(context.Background(), "drop type mytype") - - var result *MyType - - // Demonstrates both passing and reading back composite values - err = conn.QueryRow(context.Background(), "select $1::mytype", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). - Scan(&result) - E(err) - - fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) - - // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result - err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) - E(err) - - fmt.Printf("Second row: %v\n", result) - - // Output: - // First row: a=1 b=foo - // Second row: -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 50ca29c3..6df3a582 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -343,7 +343,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) - ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) diff --git a/pgtype/record.go b/pgtype/record.go deleted file mode 100644 index 5bb4d701..00000000 --- a/pgtype/record.go +++ /dev/null @@ -1,141 +0,0 @@ -package pgtype - -import ( - "fmt" - "reflect" -) - -// Record is the generic PostgreSQL record type such as is created with the -// "row" function. Record only implements BinaryEncoder and Value. The text -// format output format from PostgreSQL does not include type information and is -// therefore impossible to decode. No encoders are implemented because -// PostgreSQL does not support input of generic records. -type Record struct { - Fields []Value - Valid bool -} - -func (dst *Record) Set(src interface{}) error { - if src == nil { - *dst = Record{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case []Value: - *dst = Record{Fields: value, Valid: true} - default: - return fmt.Errorf("cannot convert %v to Record", src) - } - - return nil -} - -func (dst Record) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Fields -} - -func (src *Record) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *[]Value: - *v = make([]Value, len(src.Fields)) - copy(*v, src.Fields) - return nil - case *[]interface{}: - *v = make([]interface{}, len(src.Fields)) - for i := range *v { - (*v)[i] = src.Fields[i].Get() - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { - var binaryDecoder BinaryDecoder - - if dt, ok := ci.DataTypeForOID(fieldOID); ok { - binaryDecoder, _ = dt.Value.(BinaryDecoder) - } else { - return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID) - } - - if binaryDecoder == nil { - return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID) - } - - // Duplicate struct to scan into - binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) - *v = binaryDecoder.(Value) - return binaryDecoder, nil -} - -func (Record) BinaryFormatSupported() bool { - return true -} - -func (Record) TextFormatSupported() bool { - return false -} - -func (Record) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Record) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return fmt.Errorf("text format is not supported") - } - return fmt.Errorf("unknown format code %d", format) -} - -func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Record{} - return nil - } - - scanner := NewCompositeBinaryScanner(ci, src) - - fields := make([]Value, scanner.FieldCount()) - - for i := 0; scanner.Next(); i++ { - binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) - if err != nil { - return err - } - - if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { - return err - } - } - - if scanner.Err() != nil { - return scanner.Err() - } - - *dst = Record{Fields: fields, Valid: true} - - return nil -} diff --git a/pgtype/record_test.go b/pgtype/record_test.go deleted file mode 100644 index 921f0975..00000000 --- a/pgtype/record_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package pgtype_test - -import ( - "context" - "fmt" - "reflect" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -var recordTests = []struct { - sql string - expected pgtype.Record -}{ - { - sql: `select row()`, - expected: pgtype.Record{ - Fields: []pgtype.Value{}, - Valid: true, - }, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Valid: true}, - &pgtype.Float4{Float: 1.09, Valid: true}, - }, - Valid: true, - }, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Valid: true}, - {Int: 2, Valid: true}, - {}, - {Int: 4, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Valid: true, - }, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - }, - { - sql: `select row(null)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Unknown{}, - }, - Valid: true, - }, - }, - { - sql: `select null::record`, - expected: pgtype.Record{}, - }, -} - -func TestRecordTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for i, tt := range recordTests { - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.sql) - if err != nil { - t.Fatal(err) - } - - t.Run(tt.sql, func(t *testing.T) { - var result pgtype.Record - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - t.Errorf("%v", err) - return - } - - if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("expected %#v, got %#v", tt.expected, result) - } - }) - - } -} - -func TestRecordWithUnknownOID(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Exec(context.Background(), `drop type if exists floatrange; - -create type floatrange as range ( - subtype = float8, - subtype_diff = float8mi -);`) - if err != nil { - t.Fatal(err) - } - defer conn.Exec(context.Background(), "drop type floatrange") - - var result pgtype.Record - err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) - if err == nil { - t.Errorf("expected error but none") - } -} - -func TestRecordAssignTo(t *testing.T) { - var valueSlice []pgtype.Value - var interfaceSlice []interface{} - - simpleTests := []struct { - src pgtype.Record - dst interface{} - expected interface{} - }{ - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - dst: &valueSlice, - expected: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - }, - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Valid: true}, - &pgtype.Int4{Int: 42, Valid: true}, - }, - Valid: true, - }, - dst: &interfaceSlice, - expected: []interface{}{"foo", int32(42)}, - }, - { - src: pgtype.Record{}, - dst: &valueSlice, - expected: (([]pgtype.Value)(nil)), - }, - { - src: pgtype.Record{}, - dst: &interfaceSlice, - expected: (([]interface{})(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -}