Refactor and add CompositeTextBuilder
This commit is contained in:
+6
-41
@@ -58,52 +58,17 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using
|
||||
// CompositeFields to encode directly.
|
||||
func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||
buf = append(buf, '(')
|
||||
|
||||
fieldBuf := make([]byte, 0, 32)
|
||||
b := NewCompositeTextBuilder(ci, buf)
|
||||
|
||||
for _, f := range cf {
|
||||
if f != nil {
|
||||
fieldBuf = fieldBuf[0:0]
|
||||
if textEncoder, ok := f.(TextEncoder); ok {
|
||||
var err error
|
||||
fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
} else {
|
||||
dt, ok := ci.DataTypeForValue(f)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Unknown data type for %#v", f)
|
||||
}
|
||||
|
||||
err := dt.Value.Set(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if textEncoder, ok := dt.Value.(TextEncoder); ok {
|
||||
var err error
|
||||
fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("Cannot encode text format for %v", f)
|
||||
}
|
||||
}
|
||||
if textEncoder, ok := f.(TextEncoder); ok {
|
||||
b.AppendEncoder(textEncoder)
|
||||
} else {
|
||||
b.AppendValue(f)
|
||||
}
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
buf[len(buf)-1] = ')'
|
||||
return buf, nil
|
||||
return b.Finish()
|
||||
}
|
||||
|
||||
// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is
|
||||
|
||||
@@ -177,6 +177,27 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
|
||||
if buf == nil {
|
||||
dst.status = Null
|
||||
return nil
|
||||
}
|
||||
|
||||
scanner := NewCompositeTextScanner(ci, buf)
|
||||
|
||||
for _, f := range dst.fields {
|
||||
scanner.ScanDecoder(f)
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
dst.status = Present
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type CompositeBinaryScanner struct {
|
||||
ci *ConnInfo
|
||||
rp int
|
||||
@@ -474,6 +495,77 @@ func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
type CompositeTextBuilder struct {
|
||||
ci *ConnInfo
|
||||
buf []byte
|
||||
startIdx int
|
||||
fieldCount uint32
|
||||
err error
|
||||
fieldBuf [32]byte
|
||||
}
|
||||
|
||||
func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
|
||||
buf = append(buf, '(') // allocate room for number of fields
|
||||
return &CompositeTextBuilder{ci: ci, buf: buf}
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) AppendValue(field interface{}) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
b.buf = append(b.buf, ',')
|
||||
return
|
||||
}
|
||||
|
||||
dt, ok := b.ci.DataTypeForValue(field)
|
||||
if !ok {
|
||||
b.err = errors.Errorf("unknown data type for field: %v", field)
|
||||
return
|
||||
}
|
||||
|
||||
err := dt.Value.Set(field)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
|
||||
textEncoder, ok := dt.Value.(TextEncoder)
|
||||
if !ok {
|
||||
b.err = errors.Errorf("unable to encode text for value: %v", field)
|
||||
return
|
||||
}
|
||||
|
||||
b.AppendEncoder(textEncoder)
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0])
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
b.buf = append(b.buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
|
||||
b.buf = append(b.buf, ',')
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) Finish() ([]byte, error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
b.buf[len(b.buf)-1] = ')'
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
|
||||
|
||||
func quoteCompositeField(src string) string {
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgtype/testutil"
|
||||
pgx "github.com/jackc/pgx/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeTypeSetAndGet(t *testing.T) {
|
||||
@@ -130,6 +132,47 @@ func TestCompositeTypeAssignTo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTypeTranscode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
|
||||
|
||||
create type ct_test as (
|
||||
a text,
|
||||
b int4
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
var oid uint32
|
||||
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{})
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid})
|
||||
|
||||
// Use simple protocol to force text or binary encoding
|
||||
simpleProtocols := []bool{true, false}
|
||||
|
||||
var a string
|
||||
var b int32
|
||||
|
||||
for _, simpleProtocol := range simpleProtocols {
|
||||
err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol),
|
||||
pgtype.CompositeFields{"hi", int32(42)},
|
||||
).Scan(
|
||||
[]interface{}{&a, &b},
|
||||
)
|
||||
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
|
||||
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
|
||||
assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Example_composite() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user