From e51cb1ef09a161010a263455ce67125d9c42d8e5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 14:04:11 -0500 Subject: [PATCH] Add CompositeBinaryBuilder --- composite_bench_test.go | 23 +++++++------ composite_fields.go | 29 +++------------- composite_type.go | 76 ++++++++++++++++++++++++++++++++++------- convert.go | 23 ++++--------- 4 files changed, 87 insertions(+), 64 deletions(-) diff --git a/composite_bench_test.go b/composite_bench_test.go index e1dd6d04..4858ccad 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -3,6 +3,7 @@ package pgtype_test import ( "testing" + "github.com/jackc/pgio" "github.com/jackc/pgtype" errors "golang.org/x/xerrors" ) @@ -12,22 +13,22 @@ type MyCompositeRaw struct { B *string } -func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { - a := pgtype.Int4{src.A, pgtype.Present} +func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + buf = pgio.AppendUint32(buf, 2) - fieldBytes := make([]byte, 0, 64) - fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0]) - - newBuf = pgtype.RecordStart(buf, 2) - newBuf = pgtype.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) + buf = pgio.AppendUint32(buf, pgtype.Int4OID) + buf = pgio.AppendInt32(buf, 4) + buf = pgio.AppendInt32(buf, src.A) + buf = pgio.AppendUint32(buf, pgtype.TextOID) if src.B != nil { - fieldBytes, _ = pgtype.Text{*src.B, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) - newBuf = pgtype.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) + buf = pgio.AppendInt32(buf, int32(len(*src.B))) + buf = append(buf, (*src.B)...) } else { - newBuf = pgtype.RecordAddNull(newBuf, pgtype.TextOID) + buf = pgio.AppendInt32(buf, -1) } - return + + return buf, nil } func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { diff --git a/composite_fields.go b/composite_fields.go index 751adce8..b97506eb 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -1,9 +1,6 @@ package pgtype import ( - "encoding/binary" - - "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) @@ -143,7 +140,7 @@ func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { // * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. // * No dereferencing will be done. e.g. *Text must be used instead of Text. func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - buf = pgio.AppendUint32(buf, uint32(len(cf))) + b := NewCompositeBinaryBuilder(ci, buf) for _, f := range cf { dt, ok := ci.DataTypeForValue(f) @@ -151,38 +148,20 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) return nil, errors.Errorf("Unknown OID for %#v", f) } - buf = pgio.AppendUint32(buf, dt.OID) - lengthPos := len(buf) - buf = pgio.AppendInt32(buf, -1) - if binaryEncoder, ok := f.(BinaryEncoder); ok { - fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) - buf = fieldBuf - } + b.AppendEncoder(dt.OID, binaryEncoder) } else { err := dt.Value.Set(f) if err != nil { return nil, err } if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { - fieldBuf, err := binaryEncoder.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - binary.BigEndian.PutUint32(buf[lengthPos:], uint32(len(fieldBuf)-len(buf))) - buf = fieldBuf - } + b.AppendEncoder(dt.OID, binaryEncoder) } else { return nil, errors.Errorf("Cannot encode binary format for %v", f) } } } - return buf, nil + return b.Finish() } diff --git a/composite_type.go b/composite_type.go index b4b1ab28..99f0189f 100644 --- a/composite_type.go +++ b/composite_type.go @@ -350,22 +350,74 @@ func (cfs *CompositeTextScanner) Err() error { return cfs.err } -// RecordStart adds record header to the buf -func RecordStart(buf []byte, fieldCount int) []byte { - return pgio.AppendUint32(buf, uint32(fieldCount)) +type CompositeBinaryBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error } -// RecordAdd adds record field to the buf -func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { - buf = pgio.AppendUint32(buf, oid) - buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) - buf = append(buf, fieldBytes...) - return buf +func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} } -// RecordAddNull adds null value as a field to the buf -func RecordAddNull(buf []byte, oid uint32) []byte { - return pgio.AppendInt32(buf, int32(-1)) +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { + if b.err != nil { + return + } + + dt, ok := b.ci.DataTypeForOID(oid) + if !ok { + b.err = errors.Errorf("unknown data type for OID: %d", oid) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + binaryEncoder, ok := dt.Value.(BinaryEncoder) + if !ok { + b.err = errors.Errorf("unable to encode binary for OID: %d", oid) + return + } + + b.AppendEncoder(oid, binaryEncoder) +} + +func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) { + if b.err != nil { + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := field.EncodeBinary(b.ci, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(b.buf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil } var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) diff --git a/convert.go b/convert.go index 6e70e82e..f170e05b 100644 --- a/convert.go +++ b/convert.go @@ -435,30 +435,21 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // EncodeRow builds a binary representation of row values (row(), composite types) func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { - fieldBytes := make([]byte, 0, 128) + b := NewCompositeBinaryBuilder(ci, buf) - newBuf = RecordStart(buf, len(fields)) for _, f := range fields { dt, ok := ci.DataTypeForValue(f) if !ok { return nil, errors.Errorf("Unknown OID for %s", f) } - if f.Get() != nil { - binaryEncoder, ok := f.(BinaryEncoder) - if !ok { - return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) - } - fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0]) - if err != nil { - return nil, err - } - newBuf = RecordAdd(newBuf, dt.OID, fieldBytes) - } else { - newBuf = RecordAddNull(newBuf, dt.OID) + binaryEncoder, ok := f.(BinaryEncoder) + if !ok { + return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) } - + b.AppendEncoder(dt.OID, binaryEncoder) } - return + + return b.Finish() } func init() {