Merge SetFields functionality into Set
This commit is contained in:
@@ -104,7 +104,7 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
c.SetFields(f1, f2)
|
c.Set([]interface{}{f1, f2})
|
||||||
buf, _ = c.EncodeBinary(ci, buf[:0])
|
buf, _ = c.EncodeBinary(ci, buf[:0])
|
||||||
}
|
}
|
||||||
x = buf
|
x = buf
|
||||||
|
|||||||
+15
-20
@@ -20,13 +20,17 @@ type CompositeType struct {
|
|||||||
// To read composite fields back pass result of Scan() method
|
// To read composite fields back pass result of Scan() method
|
||||||
// to query Scan function.
|
// to query Scan function.
|
||||||
func NewCompositeType(fields ...Value) *CompositeType {
|
func NewCompositeType(fields ...Value) *CompositeType {
|
||||||
return &CompositeType{fields, Present}
|
return &CompositeType{fields, Undefined}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (src CompositeType) Get() interface{} {
|
func (src CompositeType) Get() interface{} {
|
||||||
switch src.status {
|
switch src.status {
|
||||||
case Present:
|
case Present:
|
||||||
return src
|
results := make([]interface{}, len(src.fields))
|
||||||
|
for i := range results {
|
||||||
|
results[i] = src.fields[i].Get()
|
||||||
|
}
|
||||||
|
return results
|
||||||
case Null:
|
case Null:
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
@@ -34,17 +38,16 @@ func (src CompositeType) Get() interface{} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set is called internally when passing query arguments.
|
|
||||||
func (dst *CompositeType) Set(src interface{}) error {
|
func (dst *CompositeType) Set(src interface{}) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = CompositeType{status: Null}
|
dst.status = Null
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch value := src.(type) {
|
switch value := src.(type) {
|
||||||
case []Value:
|
case []interface{}:
|
||||||
if len(value) != len(dst.fields) {
|
if len(value) != len(dst.fields) {
|
||||||
return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields))
|
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields))
|
||||||
}
|
}
|
||||||
for i, v := range value {
|
for i, v := range value {
|
||||||
if err := dst.fields[i].Set(v); err != nil {
|
if err := dst.fields[i].Set(v); err != nil {
|
||||||
@@ -52,6 +55,12 @@ func (dst *CompositeType) Set(src interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
dst.status = Present
|
dst.status = Present
|
||||||
|
case *[]interface{}:
|
||||||
|
if value == nil {
|
||||||
|
dst.status = Null
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return dst.Set(*value)
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("Can not convert %v to Composite", src)
|
return errors.Errorf("Can not convert %v to Composite", src)
|
||||||
}
|
}
|
||||||
@@ -138,20 +147,6 @@ func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFields sets Composite's fields to corresponding values
|
|
||||||
func (dst *CompositeType) SetFields(values ...interface{}) error {
|
|
||||||
if len(values) != len(dst.fields) {
|
|
||||||
return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields))
|
|
||||||
}
|
|
||||||
for i, v := range values {
|
|
||||||
if err := dst.fields[i].Set(v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
dst.status = Present
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type CompositeBinaryScanner struct {
|
type CompositeBinaryScanner struct {
|
||||||
rp int
|
rp int
|
||||||
src []byte
|
src []byte
|
||||||
|
|||||||
+45
-1
@@ -4,11 +4,55 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgtype"
|
||||||
pgx "github.com/jackc/pgx/v4"
|
pgx "github.com/jackc/pgx/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestCompositeTypeSetAndGet(t *testing.T) {
|
||||||
|
ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{})
|
||||||
|
assert.Equal(t, pgtype.Undefined, ct.Get())
|
||||||
|
|
||||||
|
nilTests := []struct {
|
||||||
|
src interface{}
|
||||||
|
}{
|
||||||
|
{nil}, // nil interface
|
||||||
|
{(*[]interface{})(nil)}, // typed nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range nilTests {
|
||||||
|
err := ct.Set(tt.src)
|
||||||
|
assert.NoErrorf(t, err, "%d", i)
|
||||||
|
assert.Equal(t, nil, ct.Get())
|
||||||
|
}
|
||||||
|
|
||||||
|
compatibleValuesTests := []struct {
|
||||||
|
src []interface{}
|
||||||
|
expected []interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
src: []interface{}{"foo", int32(42)},
|
||||||
|
expected: []interface{}{"foo", int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
src: []interface{}{nil, nil},
|
||||||
|
expected: []interface{}{nil, nil},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}},
|
||||||
|
expected: []interface{}{"hi", int32(7)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range compatibleValuesTests {
|
||||||
|
err := ct.Set(tt.src)
|
||||||
|
assert.NoErrorf(t, err, "%d", i)
|
||||||
|
assert.EqualValues(t, tt.expected, ct.Get())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//ExampleComposite demonstrates use of Row() function to pass and receive
|
//ExampleComposite demonstrates use of Row() function to pass and receive
|
||||||
// back composite types without creating boilderplate custom types.
|
// back composite types without creating boilderplate custom types.
|
||||||
func Example_composite() {
|
func Example_composite() {
|
||||||
@@ -32,7 +76,7 @@ create type mytype as (
|
|||||||
var b *string
|
var b *string
|
||||||
|
|
||||||
c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{})
|
c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{})
|
||||||
c.SetFields(2, "bar")
|
c.Set([]interface{}{2, "bar"})
|
||||||
|
|
||||||
err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c).
|
err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c).
|
||||||
Scan(c.Scan(&isNull, &a, &b))
|
Scan(c.Scan(&isNull, &a, &b))
|
||||||
|
|||||||
Reference in New Issue
Block a user