diff --git a/composite.go b/composite.go new file mode 100644 index 00000000..1caa24d6 --- /dev/null +++ b/composite.go @@ -0,0 +1,128 @@ +package pgtype + +import ( + errors "golang.org/x/xerrors" +) + +type composite struct { + fields []Value + status Status +} + +// helper struct to act both as a scanning target and query argument +type rowValue struct { + args []interface{} +} + +// Row helper function builds a value which can be both used to +// "assemble" composite quiery arguments and to scan results back. +// +// When passed as an argument to query, values from Row args will +// be assigned to corresponding fields in a composite type and a single +// composite type will be passed to the PostgreSQL. Composite type need +// to be registered in ConnInfo first. This is required so that pgx +// can know which SQL types to use when constructing SQL composite argument +// +// When passed to Scan individual fields from composite query result +// are assigned to corresponding Row arguments. First argument MUST +// be of type *bool to flag when NULL value received. So total number +// of Row arguments, when passed to Scan should be number of composite +// fields you expect to read + 1 +func Row(fields ...interface{}) rowValue { + return rowValue{fields} +} + +// Composite types is meant to be passed to ConnInfo.RegisterDataType only, +// so it is made private on purpose. Once registered, it allows Row +// function to correctly pass query arguments. +func Composite(fields ...Value) *composite { + return &composite{fields, Undefined} +} + +func (src composite) Get() interface{} { + switch src.status { + case Present: + return src + case Null: + return nil + default: + return src.status + } +} + +// Set is called internally when passing query arguments. +// Only valid src is a result of pgtype.Row() or nil +func (dst *composite) Set(src interface{}) error { + if src == nil { + *dst = composite{status: Null} + return nil + } + + switch value := src.(type) { + case rowValue: + if len(value.args) != len(dst.fields) { + return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + } + for i, v := range value.args { + if err := dst.fields[i].Set(v); err != nil { + return err + } + } + dst.status = Present + default: + return errors.Errorf("Use pgtype.Row() as query parameter") + } + + return nil +} + +// AssignTo is never called on composite value directly, it is here +// to satisfy Valuer interface +func (src composite) AssignTo(dst interface{}) error { + return errors.New("BUG: should never be called, because pgtype.composite doesn't support decoding") +} + +func (src composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + return EncodeRow(ci, buf, src.fields...) +} + +// DecodeBinary here is just to make pgx use binary result format by default. +// Users should be using Row function or their own types to scan composites +func (src composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { + return errors.New("Pass pgtype.Row() to Scan to deconstruct Composite") +} + +// DecodeBinary is called when pgtype.Row() is passed to Scan() to +// deconstruct composite value +func (r rowValue) DecodeBinary(ci *ConnInfo, src []byte) error { + if len(r.args) == 0 { + return errors.New("pgtype.Row must have 'isNull *bool' as a first argument when used in Scan") + } + + isNull, ok := r.args[0].(*bool) + if !ok { + return errors.New("pgtype.Row must have 'isNull *bool' as a first argument when used in Scan") + } + args := r.args[1:] + + var record Record + if err := record.DecodeBinary(ci, src); err != nil { + return err + } + + if record.Status == Null { + *isNull = true + return nil + } + + if len(record.Fields) != len(args) { + return errors.Errorf("SQL composite can't be read, 'pgtype.Row' has wrong field cout. %d != %d", len(record.Fields), len(args)) + } + + for i, f := range record.Fields { + if err := f.AssignTo(args[i]); err != nil { + return err + } + } + return nil +} diff --git a/composite_test.go b/composite_test.go index d0c48f6e..b38cdd45 100644 --- a/composite_test.go +++ b/composite_test.go @@ -81,7 +81,8 @@ create type mytype as ( // Demonstrates both passing and reading back composite values err = conn.QueryRow(context.Background(), "select $1::mytype", - pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}).Scan(&result) + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). + Scan(&result) E(err) fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) @@ -92,12 +93,21 @@ create type mytype as ( fmt.Printf("Second row: %v\n", result) - // Adhoc rows can be decoded inplace without boilerplate (works with composite types too) + //WIP + q, err := conn.Prepare(context.Background(), "z", "select $1::mytype") + E(err) + conn.ConnInfo().RegisterDataType(pgtype.DataType{pgtype.Composite(&pgtype.Int4{}, &pgtype.Text{}), "mytype", q.ParamOIDs[0]}) + + // Adhoc rows can be decoded inplace without boilerplate + // Composite types can be encoded/decoded inplace + var isNull bool var a int var b *string - err = conn.QueryRow(context.Background(), "select (2, 'bar')::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(pgtype.ROW(&isNull, &a, &b)) + err = conn.QueryRow(context.Background(), "select row(($1::mytype).a, ($1).b)", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, pgtype.Row(2, "bar")). + Scan(pgtype.Row(&isNull, &a, &b)) E(err) fmt.Printf("Adhoc: isNull=%v a=%d b=%s", isNull, a, *b) diff --git a/convert.go b/convert.go index d22a714f..134e123d 100644 --- a/convert.go +++ b/convert.go @@ -504,34 +504,6 @@ func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err er return } -// ROW allows deconstructing row values (records and composite types) into -// fields directly without creating your own type and implementing decoder interfaces -func ROW(isNull *bool, fields ...interface{}) BinaryDecoderFunc { - return func(ci *ConnInfo, src []byte) error { - var record Record - if err := record.DecodeBinary(ci, src); err != nil { - return err - } - - if record.Status == Null { - *isNull = true - return nil - } - - if len(record.Fields) != len(fields) { - return errors.Errorf("can't scan row value, number of fields don't match: row fields count=%d desired fields count=%d", len(record.Fields), len(fields)) - } - - for i, f := range record.Fields { - if err := f.AssignTo(fields[i]); err != nil { - return err - } - } - - return nil - } -} - func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype.go b/pgtype.go index 1749c8c2..e86255f4 100644 --- a/pgtype.go +++ b/pgtype.go @@ -167,6 +167,15 @@ func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error { return f(ci, src) } +//The BinaryEncoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. +// If f is a function with the appropriate signature, BinaryEncoderFunc(f) is a BinaryDecoder that calls f. +type BinaryEncoderFunc func(ci *ConnInfo, buf []byte) ([]byte, error) + +// EncodeBinary calls f(ci, buf) +func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + return f(ci, buf) +} + var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status")