diff --git a/conn.go b/conn.go index ce0539e6..5cb75126 100644 --- a/conn.go +++ b/conn.go @@ -14,6 +14,7 @@ import ( "os" "os/user" "path/filepath" + "reflect" "regexp" "strconv" "strings" @@ -752,6 +753,14 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case string: err = encodeText(wbuf, arguments[i]) default: + if v := reflect.ValueOf(arguments[i]); v.Kind() == reflect.Ptr { + if v.IsNil() { + wbuf.WriteInt32(-1) + continue + } else { + arguments[i] = v.Elem().Interface() + } + } switch oid { case BoolOid: err = encodeBool(wbuf, arguments[i]) diff --git a/query.go b/query.go index 93527e20..0fd4d6ea 100644 --- a/query.go +++ b/query.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "reflect" "time" ) @@ -242,53 +243,74 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { decodeJson(vr, &d) } else { - switch d := d.(type) { + decode: + switch v := d.(type) { case *bool: - *d = decodeBool(vr) + *v = decodeBool(vr) case *int64: - *d = decodeInt8(vr) + *v = decodeInt8(vr) case *int16: - *d = decodeInt2(vr) + *v = decodeInt2(vr) case *int32: - *d = decodeInt4(vr) + *v = decodeInt4(vr) case *Oid: - *d = decodeOid(vr) + *v = decodeOid(vr) case *string: - *d = decodeText(vr) + *v = decodeText(vr) case *float32: - *d = decodeFloat4(vr) + *v = decodeFloat4(vr) case *float64: - *d = decodeFloat8(vr) + *v = decodeFloat8(vr) case *[]bool: - *d = decodeBoolArray(vr) + *v = decodeBoolArray(vr) case *[]int16: - *d = decodeInt2Array(vr) + *v = decodeInt2Array(vr) case *[]int32: - *d = decodeInt4Array(vr) + *v = decodeInt4Array(vr) case *[]int64: - *d = decodeInt8Array(vr) + *v = decodeInt8Array(vr) case *[]float32: - *d = decodeFloat4Array(vr) + *v = decodeFloat4Array(vr) case *[]float64: - *d = decodeFloat8Array(vr) + *v = decodeFloat8Array(vr) case *[]string: - *d = decodeTextArray(vr) + *v = decodeTextArray(vr) case *[]time.Time: - *d = decodeTimestampArray(vr) + *v = decodeTimestampArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: - *d = decodeDate(vr) + *v = decodeDate(vr) case TimestampTzOid: - *d = decodeTimestampTz(vr) + *v = decodeTimestampTz(vr) case TimestampOid: - *d = decodeTimestamp(vr) + *v = decodeTimestamp(vr) default: rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) } case *net.IPNet: - *d = decodeInet(vr) + *v = decodeInet(vr) default: + // if d is a pointer to pointer, strip the pointer and try again + if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { + if el := v.Elem(); el.Kind() == reflect.Ptr { + // -1 is a null value + if vr.Len() == -1 { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + continue + } else { + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + d = el.Interface() + goto decode + } + } + } rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d)) } diff --git a/values_test.go b/values_test.go index 46e79d11..b0d141e3 100644 --- a/values_test.go +++ b/values_test.go @@ -500,3 +500,101 @@ func TestNullXMismatch(t *testing.T) { ensureConnValid(t, conn) } } + +func TestPointerPointer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s *string + i16 *int16 + i32 *int32 + i64 *int64 + f32 *float32 + f64 *float64 + b *bool + t *time.Time + } + + var actual, zero, expected allTypes + + { + s := "foo" + expected.s = &s + i16 := int16(1) + expected.i16 = &i16 + i32 := int32(1) + expected.i32 = &i32 + i64 := int64(1) + expected.i64 = &i64 + f32 := float32(1.23) + expected.f32 = &f32 + f64 := float64(1.23) + expected.f64 = &f64 + b := true + expected.b = &b + t := time.Unix(123, 5000) + expected.t = &t + } + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, + {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, + {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, + {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, + {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, + {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, + {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, + {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, + {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, + {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, + {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, + {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, + {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, + {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, + {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + {"select $1::timestamp", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamp", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestPointerPointerNonZero(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + f := "foo" + dest := &f + + err := conn.QueryRow("select $1::text", nil).Scan(&dest) + if err != nil { + t.Errorf("Unexpected failure scanning: %v", err) + } + if dest != nil { + t.Errorf("Expected dest to be nil, got %#v", dest) + } +}