Add CompositeFields type
This adds support for the text format and removes the need for the ScanRowValue function.
This commit is contained in:
@@ -233,6 +233,96 @@ func (cfs *CompositeBinaryScanner) Err() error {
|
|||||||
return cfs.err
|
return cfs.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompositeTextScanner struct {
|
||||||
|
rp int
|
||||||
|
src []byte
|
||||||
|
|
||||||
|
fieldBytes []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCompositeTextScanner a scanner over a text encoded composite balue.
|
||||||
|
func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
if src[0] != '(' {
|
||||||
|
return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('")
|
||||||
|
}
|
||||||
|
|
||||||
|
if src[len(src)-1] != ')' {
|
||||||
|
return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'")
|
||||||
|
}
|
||||||
|
|
||||||
|
return CompositeTextScanner{
|
||||||
|
rp: 1,
|
||||||
|
src: src,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
||||||
|
// Scan returns false, the Err method can be called to check if any errors occurred.
|
||||||
|
func (cfs *CompositeTextScanner) Scan() bool {
|
||||||
|
if cfs.err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfs.rp == len(cfs.src) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cfs.src[cfs.rp] {
|
||||||
|
case ',', ')': // null
|
||||||
|
cfs.rp++
|
||||||
|
cfs.fieldBytes = nil
|
||||||
|
return true
|
||||||
|
case '"': // quoted value
|
||||||
|
cfs.rp++
|
||||||
|
cfs.fieldBytes = make([]byte, 0, 16)
|
||||||
|
for {
|
||||||
|
ch := cfs.src[cfs.rp]
|
||||||
|
|
||||||
|
if ch == '"' {
|
||||||
|
cfs.rp++
|
||||||
|
if cfs.src[cfs.rp] == '"' {
|
||||||
|
cfs.fieldBytes = append(cfs.fieldBytes, '"')
|
||||||
|
cfs.rp++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cfs.fieldBytes = append(cfs.fieldBytes, ch)
|
||||||
|
cfs.rp++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfs.rp++
|
||||||
|
return true
|
||||||
|
default: // unquoted value
|
||||||
|
start := cfs.rp
|
||||||
|
for {
|
||||||
|
ch := cfs.src[cfs.rp]
|
||||||
|
if ch == ',' || ch == ')' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cfs.rp++
|
||||||
|
}
|
||||||
|
cfs.fieldBytes = cfs.src[start:cfs.rp]
|
||||||
|
cfs.rp++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns the bytes of the field most recently read by Scan().
|
||||||
|
func (cfs *CompositeTextScanner) Bytes() []byte {
|
||||||
|
return cfs.fieldBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err returns any error encountered by the scanner.
|
||||||
|
func (cfs *CompositeTextScanner) Err() error {
|
||||||
|
return cfs.err
|
||||||
|
}
|
||||||
|
|
||||||
// RecordStart adds record header to the buf
|
// RecordStart adds record header to the buf
|
||||||
func RecordStart(buf []byte, fieldCount int) []byte {
|
func RecordStart(buf []byte, fieldCount int) []byte {
|
||||||
return pgio.AppendUint32(buf, uint32(fieldCount))
|
return pgio.AppendUint32(buf, uint32(fieldCount))
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
errors "golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a
|
||||||
|
// nullable value use a *CompositeFields. It will be set to nil in case of null.
|
||||||
|
type CompositeFields []interface{}
|
||||||
|
|
||||||
|
func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if len(cf) == 0 {
|
||||||
|
return errors.Errorf("cannot decode into empty CompositeFields")
|
||||||
|
}
|
||||||
|
|
||||||
|
if src == nil {
|
||||||
|
return errors.Errorf("cannot decode unexpected null into CompositeFields")
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner, err := NewCompositeBinaryScanner(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(cf) != scanner.FieldCount() {
|
||||||
|
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; scanner.Scan(); i++ {
|
||||||
|
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if scanner.Err() != nil {
|
||||||
|
return scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if len(cf) == 0 {
|
||||||
|
return errors.Errorf("cannot decode into empty CompositeFields")
|
||||||
|
}
|
||||||
|
|
||||||
|
if src == nil {
|
||||||
|
return errors.Errorf("cannot decode unexpected null into CompositeFields")
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner, err := NewCompositeTextScanner(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldCount := 0
|
||||||
|
|
||||||
|
for i := 0; scanner.Scan(); i++ {
|
||||||
|
err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldCount += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if scanner.Err() != nil {
|
||||||
|
return scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cf) != fieldCount {
|
||||||
|
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
package pgtype_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgtype"
|
||||||
|
"github.com/jackc/pgtype/testutil"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompositeFieldsDecode(t *testing.T) {
|
||||||
|
conn := testutil.MustConnectPgx(t)
|
||||||
|
defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}
|
||||||
|
|
||||||
|
// Assorted values
|
||||||
|
{
|
||||||
|
var a int32
|
||||||
|
var b string
|
||||||
|
var c float64
|
||||||
|
|
||||||
|
for _, format := range formats {
|
||||||
|
err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
|
||||||
|
pgtype.CompositeFields{&a, &b, &c},
|
||||||
|
)
|
||||||
|
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.EqualValuesf(t, 1, a, "Format: %v", format)
|
||||||
|
assert.EqualValuesf(t, "hi", b, "Format: %v", format)
|
||||||
|
assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nulls, string "null", and empty string fields
|
||||||
|
{
|
||||||
|
var a pgtype.Text
|
||||||
|
var b string
|
||||||
|
var c pgtype.Text
|
||||||
|
var d string
|
||||||
|
var e pgtype.Text
|
||||||
|
|
||||||
|
for _, format := range formats {
|
||||||
|
err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan(
|
||||||
|
pgtype.CompositeFields{&a, &b, &c, &d, &e},
|
||||||
|
)
|
||||||
|
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Nilf(t, a.Get(), "Format: %v", format)
|
||||||
|
assert.EqualValuesf(t, "null", b, "Format: %v", format)
|
||||||
|
assert.Nilf(t, c.Get(), "Format: %v", format)
|
||||||
|
assert.EqualValuesf(t, "", d, "Format: %v", format)
|
||||||
|
assert.Nilf(t, e.Get(), "Format: %v", format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// null record
|
||||||
|
{
|
||||||
|
var a pgtype.Text
|
||||||
|
var b string
|
||||||
|
cf := pgtype.CompositeFields{&a, &b}
|
||||||
|
|
||||||
|
for _, format := range formats {
|
||||||
|
// Cannot scan nil into
|
||||||
|
err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
|
||||||
|
cf,
|
||||||
|
)
|
||||||
|
if assert.Errorf(t, err, "Format: %v", format) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert.NotNilf(t, cf, "Format: %v", format)
|
||||||
|
|
||||||
|
// But can scan nil into *pgtype.CompositeFields
|
||||||
|
err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
|
||||||
|
&cf,
|
||||||
|
)
|
||||||
|
if assert.Errorf(t, err, "Format: %v", format) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert.Nilf(t, cf, "Format: %v", format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// quotes and special characters
|
||||||
|
{
|
||||||
|
var a, b, c, d string
|
||||||
|
|
||||||
|
for _, format := range formats {
|
||||||
|
err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan(
|
||||||
|
pgtype.CompositeFields{&a, &b, &c, &d},
|
||||||
|
)
|
||||||
|
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equalf(t, `"`, a, "Format: %v", format)
|
||||||
|
assert.Equalf(t, `foo bar`, b, "Format: %v", format)
|
||||||
|
assert.Equalf(t, `foo'bar`, c, "Format: %v", format)
|
||||||
|
assert.Equalf(t, `baz)bar`, d, "Format: %v", format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// arrays
|
||||||
|
{
|
||||||
|
var a []string
|
||||||
|
var b []int64
|
||||||
|
|
||||||
|
for _, format := range formats {
|
||||||
|
err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan(
|
||||||
|
pgtype.CompositeFields{&a, &b},
|
||||||
|
)
|
||||||
|
if !assert.NoErrorf(t, err, "Format: %v", format) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format)
|
||||||
|
assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
-32
@@ -433,38 +433,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanRowValue decodes ROW()'s and composite type
|
|
||||||
// from src argument using provided decoders. Decoders should match
|
|
||||||
// order and count of fields of record being decoded.
|
|
||||||
//
|
|
||||||
// In practice you can pass pgtype.Value types as decoders, as
|
|
||||||
// most of them implement BinaryDecoder interface.
|
|
||||||
//
|
|
||||||
// ScanRowValue takes ownership of src, caller MUST not use it after call
|
|
||||||
func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error {
|
|
||||||
scanner, err := NewCompositeBinaryScanner(src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dst) != scanner.FieldCount() {
|
|
||||||
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; scanner.Scan(); i++ {
|
|
||||||
err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if scanner.Err() != nil {
|
|
||||||
return scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EncodeRow builds a binary representation of row values (row(), composite types)
|
// EncodeRow builds a binary representation of row values (row(), composite types)
|
||||||
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
|
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
|
||||||
fieldBytes := make([]byte, 0, 128)
|
fieldBytes := make([]byte, 0, 128)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
|||||||
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
|
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pgtype.ScanRowValue(ci, src, &dst.a, &dst.b); err != nil {
|
if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -79,54 +79,6 @@ var recordTests = []struct {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// row values are binary compatible with records, so we test our helper
|
|
||||||
// routines here
|
|
||||||
func TestScanRowValue(t *testing.T) {
|
|
||||||
conn := testutil.MustConnectPgx(t)
|
|
||||||
defer testutil.MustCloseContext(t, conn)
|
|
||||||
|
|
||||||
for i := 0; i < len(recordTests); i++ {
|
|
||||||
tt := recordTests[i]
|
|
||||||
psName := fmt.Sprintf("test%d", i)
|
|
||||||
_, err := conn.Prepare(context.Background(), psName, tt.sql)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
t.Run(tt.sql, func(t *testing.T) {
|
|
||||||
desc := []interface{}{}
|
|
||||||
for _, f := range tt.expected.Fields {
|
|
||||||
desc = append(desc, f.(pgtype.BinaryDecoder))
|
|
||||||
}
|
|
||||||
|
|
||||||
var raw pgtype.GenericBinary
|
|
||||||
|
|
||||||
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if raw.Status == pgtype.Null {
|
|
||||||
// ScanRowValue deals with complete rows only, NULL values (but NOT null fields)
|
|
||||||
// should be handled by the calling code
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// borrow fields from a neighbor test, this makes scan always fail
|
|
||||||
desc = desc[:0]
|
|
||||||
for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields {
|
|
||||||
desc = append(desc, f.(pgtype.BinaryDecoder))
|
|
||||||
}
|
|
||||||
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil {
|
|
||||||
t.Error("Matching scan didn't fail, despite fields not mathching query result")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecordTranscode(t *testing.T) {
|
func TestRecordTranscode(t *testing.T) {
|
||||||
conn := testutil.MustConnectPgx(t)
|
conn := testutil.MustConnectPgx(t)
|
||||||
defer testutil.MustCloseContext(t, conn)
|
defer testutil.MustCloseContext(t, conn)
|
||||||
|
|||||||
Reference in New Issue
Block a user