Add composite to arbitrary struct encoding and decoding
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -619,3 +620,39 @@ func (w byteSliceWrapper) UUIDValue() (UUID, error) {
|
|||||||
copy(uuid.Bytes[:], w)
|
copy(uuid.Bytes[:], w)
|
||||||
return uuid, nil
|
return uuid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// structWrapper implements CompositeIndexGetter for a struct.
|
||||||
|
type structWrapper struct {
|
||||||
|
s interface{}
|
||||||
|
exportedFields []reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w structWrapper) IsNull() bool {
|
||||||
|
return w.s == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w structWrapper) Index(i int) interface{} {
|
||||||
|
if i >= len(w.exportedFields) {
|
||||||
|
return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.exportedFields[i].Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct.
|
||||||
|
type ptrStructWrapper struct {
|
||||||
|
s interface{}
|
||||||
|
exportedFields []reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ptrStructWrapper) ScanNull() error {
|
||||||
|
return fmt.Errorf("cannot scan NULL into %#v", w.s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ptrStructWrapper) ScanIndex(i int) interface{} {
|
||||||
|
if i >= len(w.exportedFields) {
|
||||||
|
return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.exportedFields[i].Addr().Interface()
|
||||||
|
}
|
||||||
|
|||||||
@@ -123,3 +123,42 @@ create type point3d as (
|
|||||||
require.Equalf(t, input, output, "%v", format.name)
|
require.Equalf(t, input, output, "%v", format.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
||||||
|
conn := testutil.MustConnectPgx(t)
|
||||||
|
defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
_, err := conn.Exec(context.Background(), `drop type if exists point3d;
|
||||||
|
|
||||||
|
create type point3d as (
|
||||||
|
x float8,
|
||||||
|
y float8,
|
||||||
|
z float8
|
||||||
|
);`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Exec(context.Background(), "drop type point3d")
|
||||||
|
|
||||||
|
dt, err := conn.LoadDataType(context.Background(), "point3d")
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn.ConnInfo().RegisterDataType(*dt)
|
||||||
|
|
||||||
|
formats := []struct {
|
||||||
|
name string
|
||||||
|
code int16
|
||||||
|
}{
|
||||||
|
{name: "TextFormat", code: pgx.TextFormatCode},
|
||||||
|
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
|
||||||
|
}
|
||||||
|
|
||||||
|
type anotherPoint struct {
|
||||||
|
X, Y, Z float64
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, format := range formats {
|
||||||
|
input := anotherPoint{X: 1, Y: 2, Z: 3}
|
||||||
|
var output anotherPoint
|
||||||
|
err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output)
|
||||||
|
require.NoErrorf(t, err, "%v", format.name)
|
||||||
|
require.Equalf(t, input, output, "%v", format.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -203,12 +203,14 @@ func NewConnInfo() *ConnInfo {
|
|||||||
TryWrapDerefPointerEncodePlan,
|
TryWrapDerefPointerEncodePlan,
|
||||||
TryWrapBuiltinTypeEncodePlan,
|
TryWrapBuiltinTypeEncodePlan,
|
||||||
TryWrapFindUnderlyingTypeEncodePlan,
|
TryWrapFindUnderlyingTypeEncodePlan,
|
||||||
|
TryWrapStructEncodePlan,
|
||||||
},
|
},
|
||||||
|
|
||||||
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
|
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
|
||||||
TryPointerPointerScanPlan,
|
TryPointerPointerScanPlan,
|
||||||
TryWrapBuiltinTypeScanPlan,
|
TryWrapBuiltinTypeScanPlan,
|
||||||
TryFindUnderlyingTypeScanPlan,
|
TryFindUnderlyingTypeScanPlan,
|
||||||
|
TryWrapStructScanPlan,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -887,6 +889,47 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
|
||||||
|
func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) {
|
||||||
|
targetValue := reflect.ValueOf(target)
|
||||||
|
if targetValue.Kind() != reflect.Ptr {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElemValue := targetValue.Elem()
|
||||||
|
targetElemType := targetElemValue.Type()
|
||||||
|
|
||||||
|
if targetElemType.Kind() == reflect.Struct {
|
||||||
|
exportedFields := getExportedFieldValues(targetElemValue)
|
||||||
|
if len(exportedFields) == 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
w := ptrStructWrapper{
|
||||||
|
s: target,
|
||||||
|
exportedFields: exportedFields,
|
||||||
|
}
|
||||||
|
return &wrapAnyPtrStructScanPlan{}, &w, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
type wrapAnyPtrStructScanPlan struct {
|
||||||
|
next ScanPlan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next }
|
||||||
|
|
||||||
|
func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error {
|
||||||
|
w := ptrStructWrapper{
|
||||||
|
s: target,
|
||||||
|
exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()),
|
||||||
|
}
|
||||||
|
|
||||||
|
return plan.next.Scan(src, &w)
|
||||||
|
}
|
||||||
|
|
||||||
// PlanScan prepares a plan to scan a value into target.
|
// PlanScan prepares a plan to scan a value into target.
|
||||||
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan {
|
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan {
|
||||||
if _, ok := target.(*UndecodedBytes); ok {
|
if _, ok := target.(*UndecodedBytes); ok {
|
||||||
@@ -1406,6 +1449,52 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (ne
|
|||||||
return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf)
|
return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
|
||||||
|
func TryWrapStructEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) {
|
||||||
|
if reflect.TypeOf(value).Kind() == reflect.Struct {
|
||||||
|
exportedFields := getExportedFieldValues(reflect.ValueOf(value))
|
||||||
|
if len(exportedFields) == 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
w := structWrapper{
|
||||||
|
s: value,
|
||||||
|
exportedFields: exportedFields,
|
||||||
|
}
|
||||||
|
return &wrapAnyStructEncodePlan{}, w, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
type wrapAnyStructEncodePlan struct {
|
||||||
|
next EncodePlan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next }
|
||||||
|
|
||||||
|
func (plan *wrapAnyStructEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
|
||||||
|
w := structWrapper{
|
||||||
|
s: value,
|
||||||
|
exportedFields: getExportedFieldValues(reflect.ValueOf(value)),
|
||||||
|
}
|
||||||
|
|
||||||
|
return plan.next.Encode(w, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExportedFieldValues(structValue reflect.Value) []reflect.Value {
|
||||||
|
structType := structValue.Type()
|
||||||
|
exportedFields := make([]reflect.Value, 0, structValue.NumField())
|
||||||
|
for i := 0; i < structType.NumField(); i++ {
|
||||||
|
sf := structType.Field(i)
|
||||||
|
if sf.IsExported() {
|
||||||
|
exportedFields = append(exportedFields, structValue.Field(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return exportedFields
|
||||||
|
}
|
||||||
|
|
||||||
// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return
|
// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return
|
||||||
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
|
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
|
||||||
// written.
|
// written.
|
||||||
|
|||||||
Reference in New Issue
Block a user