From 7ff405ff840a0f1177039f5a2aa384dd3fb3e3c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 10 Apr 2017 08:58:51 -0500 Subject: [PATCH] Add simple protocol suuport with (Query|Exec)Ex --- cid_test.go | 17 +++++++++++++++-- json.go | 2 +- numeric.go | 21 +++++++++++++++++++-- numeric_test.go | 3 +++ pgtype_test.go | 31 +++++++++++++++++++++++++++++++ xid_test.go | 17 +++++++++++++++-- 6 files changed, 84 insertions(+), 7 deletions(-) diff --git a/cid_test.go b/cid_test.go index 0d114cda..210573f6 100644 --- a/cid_test.go +++ b/cid_test.go @@ -8,10 +8,23 @@ import ( ) func TestCidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "cid", []interface{}{ + pgTypeName := "cid" + values := []interface{}{ pgtype.Cid{Uint: 42, Status: pgtype.Present}, pgtype.Cid{Status: pgtype.Null}, - }) + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to cid, convert through text + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } } func TestCidSet(t *testing.T) { diff --git a/json.go b/json.go index 05d965ca..b1c061f9 100644 --- a/json.go +++ b/json.go @@ -145,7 +145,7 @@ func (dst *Json) Scan(src interface{}) error { func (src Json) Value() (driver.Value, error) { switch src.Status { case Present: - return src.Bytes, nil + return string(src.Bytes), nil case Null: return nil, nil default: diff --git a/numeric.go b/numeric.go index 0f3f6529..a26e8c89 100644 --- a/numeric.go +++ b/numeric.go @@ -121,13 +121,13 @@ func (src *Numeric) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case *float32: - f, err := strconv.ParseFloat(src.Int.String(), 64) + f, err := src.toFloat64() if err != nil { return err } return float64AssignTo(f, src.Status, dst) case *float64: - f, err := strconv.ParseFloat(src.Int.String(), 64) + f, err := src.toFloat64() if err != nil { return err } @@ -283,6 +283,23 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { return num, nil } +func (src *Numeric) toFloat64() (float64, error) { + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return 0, err + } + if src.Exp > 0 { + for i := 0; i < int(src.Exp); i++ { + f *= 10 + } + } else if src.Exp < 0 { + for i := 0; i > int(src.Exp); i-- { + f /= 10 + } + } + return f, nil +} + func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Numeric{Status: Null} diff --git a/numeric_test.go b/numeric_test.go index 64dea847..93aa8866 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -247,9 +247,12 @@ func TestNumericAssignTo(t *testing.T) { }{ {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, diff --git a/pgtype_test.go b/pgtype_test.go index 0b1ffc54..f486f077 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "database/sql" "fmt" "io" @@ -125,6 +126,7 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } @@ -175,6 +177,35 @@ func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } } +func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectDatabaseSQL(t, driverName) defer mustClose(t, conn) diff --git a/xid_test.go b/xid_test.go index fecfb64b..11dd0615 100644 --- a/xid_test.go +++ b/xid_test.go @@ -8,10 +8,23 @@ import ( ) func TestXidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "xid", []interface{}{ + pgTypeName := "xid" + values := []interface{}{ pgtype.Xid{Uint: 42, Status: pgtype.Present}, pgtype.Xid{Status: pgtype.Null}, - }) + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to xid, convert through text + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } } func TestXidSet(t *testing.T) {