Add support for pointers to pointers
Using types like **string allows the inner pointer to be nil’ed out, avoiding the need for NullX types. Signed-off-by: Jonathan Rudenberg <jonathan@titanous.com>
This commit is contained in:
committed by
Jack Christensen
parent
4ebb0508b6
commit
272262536b
@@ -14,6 +14,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -752,6 +753,14 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
case string:
|
case string:
|
||||||
err = encodeText(wbuf, arguments[i])
|
err = encodeText(wbuf, arguments[i])
|
||||||
default:
|
default:
|
||||||
|
if v := reflect.ValueOf(arguments[i]); v.Kind() == reflect.Ptr {
|
||||||
|
if v.IsNil() {
|
||||||
|
wbuf.WriteInt32(-1)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
arguments[i] = v.Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid:
|
case BoolOid:
|
||||||
err = encodeBool(wbuf, arguments[i])
|
err = encodeBool(wbuf, arguments[i])
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -242,53 +243,74 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||||||
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
||||||
decodeJson(vr, &d)
|
decodeJson(vr, &d)
|
||||||
} else {
|
} else {
|
||||||
switch d := d.(type) {
|
decode:
|
||||||
|
switch v := d.(type) {
|
||||||
case *bool:
|
case *bool:
|
||||||
*d = decodeBool(vr)
|
*v = decodeBool(vr)
|
||||||
case *int64:
|
case *int64:
|
||||||
*d = decodeInt8(vr)
|
*v = decodeInt8(vr)
|
||||||
case *int16:
|
case *int16:
|
||||||
*d = decodeInt2(vr)
|
*v = decodeInt2(vr)
|
||||||
case *int32:
|
case *int32:
|
||||||
*d = decodeInt4(vr)
|
*v = decodeInt4(vr)
|
||||||
case *Oid:
|
case *Oid:
|
||||||
*d = decodeOid(vr)
|
*v = decodeOid(vr)
|
||||||
case *string:
|
case *string:
|
||||||
*d = decodeText(vr)
|
*v = decodeText(vr)
|
||||||
case *float32:
|
case *float32:
|
||||||
*d = decodeFloat4(vr)
|
*v = decodeFloat4(vr)
|
||||||
case *float64:
|
case *float64:
|
||||||
*d = decodeFloat8(vr)
|
*v = decodeFloat8(vr)
|
||||||
case *[]bool:
|
case *[]bool:
|
||||||
*d = decodeBoolArray(vr)
|
*v = decodeBoolArray(vr)
|
||||||
case *[]int16:
|
case *[]int16:
|
||||||
*d = decodeInt2Array(vr)
|
*v = decodeInt2Array(vr)
|
||||||
case *[]int32:
|
case *[]int32:
|
||||||
*d = decodeInt4Array(vr)
|
*v = decodeInt4Array(vr)
|
||||||
case *[]int64:
|
case *[]int64:
|
||||||
*d = decodeInt8Array(vr)
|
*v = decodeInt8Array(vr)
|
||||||
case *[]float32:
|
case *[]float32:
|
||||||
*d = decodeFloat4Array(vr)
|
*v = decodeFloat4Array(vr)
|
||||||
case *[]float64:
|
case *[]float64:
|
||||||
*d = decodeFloat8Array(vr)
|
*v = decodeFloat8Array(vr)
|
||||||
case *[]string:
|
case *[]string:
|
||||||
*d = decodeTextArray(vr)
|
*v = decodeTextArray(vr)
|
||||||
case *[]time.Time:
|
case *[]time.Time:
|
||||||
*d = decodeTimestampArray(vr)
|
*v = decodeTimestampArray(vr)
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch vr.Type().DataType {
|
switch vr.Type().DataType {
|
||||||
case DateOid:
|
case DateOid:
|
||||||
*d = decodeDate(vr)
|
*v = decodeDate(vr)
|
||||||
case TimestampTzOid:
|
case TimestampTzOid:
|
||||||
*d = decodeTimestampTz(vr)
|
*v = decodeTimestampTz(vr)
|
||||||
case TimestampOid:
|
case TimestampOid:
|
||||||
*d = decodeTimestamp(vr)
|
*v = decodeTimestamp(vr)
|
||||||
default:
|
default:
|
||||||
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
||||||
}
|
}
|
||||||
case *net.IPNet:
|
case *net.IPNet:
|
||||||
*d = decodeInet(vr)
|
*v = decodeInet(vr)
|
||||||
default:
|
default:
|
||||||
|
// if d is a pointer to pointer, strip the pointer and try again
|
||||||
|
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
||||||
|
if el := v.Elem(); el.Kind() == reflect.Ptr {
|
||||||
|
// -1 is a null value
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
if !el.IsNil() {
|
||||||
|
// if the destination pointer is not nil, nil it out
|
||||||
|
el.Set(reflect.Zero(el.Type()))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
if el.IsNil() {
|
||||||
|
// allocate destination
|
||||||
|
el.Set(reflect.New(el.Type().Elem()))
|
||||||
|
}
|
||||||
|
d = el.Interface()
|
||||||
|
goto decode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d))
|
rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -500,3 +500,101 @@ func TestNullXMismatch(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPointerPointer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
type allTypes struct {
|
||||||
|
s *string
|
||||||
|
i16 *int16
|
||||||
|
i32 *int32
|
||||||
|
i64 *int64
|
||||||
|
f32 *float32
|
||||||
|
f64 *float64
|
||||||
|
b *bool
|
||||||
|
t *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual, zero, expected allTypes
|
||||||
|
|
||||||
|
{
|
||||||
|
s := "foo"
|
||||||
|
expected.s = &s
|
||||||
|
i16 := int16(1)
|
||||||
|
expected.i16 = &i16
|
||||||
|
i32 := int32(1)
|
||||||
|
expected.i32 = &i32
|
||||||
|
i64 := int64(1)
|
||||||
|
expected.i64 = &i64
|
||||||
|
f32 := float32(1.23)
|
||||||
|
expected.f32 = &f32
|
||||||
|
f64 := float64(1.23)
|
||||||
|
expected.f64 = &f64
|
||||||
|
b := true
|
||||||
|
expected.b = &b
|
||||||
|
t := time.Unix(123, 5000)
|
||||||
|
expected.t = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
queryArgs []interface{}
|
||||||
|
scanArgs []interface{}
|
||||||
|
expected allTypes
|
||||||
|
}{
|
||||||
|
{"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}},
|
||||||
|
{"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}},
|
||||||
|
{"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}},
|
||||||
|
{"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}},
|
||||||
|
{"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}},
|
||||||
|
{"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}},
|
||||||
|
{"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}},
|
||||||
|
{"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}},
|
||||||
|
{"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}},
|
||||||
|
{"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}},
|
||||||
|
{"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}},
|
||||||
|
{"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}},
|
||||||
|
{"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}},
|
||||||
|
{"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}},
|
||||||
|
{"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}},
|
||||||
|
{"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}},
|
||||||
|
{"select $1::timestamp", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}},
|
||||||
|
{"select $1::timestamp", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
actual = zero
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPointerPointerNonZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
f := "foo"
|
||||||
|
dest := &f
|
||||||
|
|
||||||
|
err := conn.QueryRow("select $1::text", nil).Scan(&dest)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected failure scanning: %v", err)
|
||||||
|
}
|
||||||
|
if dest != nil {
|
||||||
|
t.Errorf("Expected dest to be nil, got %#v", dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user