From 24395d98df8489c755a05095f19fd8d8e439730f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 11 Jul 2014 11:16:12 -0500 Subject: [PATCH] Add more testing of Encode* Handle case where TextEncoder is used to a core type that the driver could otherwise have handled as binary. --- conn.go | 72 ++++++++++++++-------------- conn_test.go | 27 +++++++++++ values.go | 126 +++++++++++++++++++++++++++---------------------- values_test.go | 44 +++++++++++++++++ 4 files changed, 177 insertions(+), 92 deletions(-) diff --git a/conn.go b/conn.go index 780566a5..a65eb272 100644 --- a/conn.go +++ b/conn.go @@ -734,16 +734,19 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid: + switch arg := arguments[i].(type) { + case BinaryEncoder: wbuf.WriteInt16(BinaryFormatCode) - case TextOid, VarcharOid, DateOid, TimestampTzOid: + case TextEncoder: wbuf.WriteInt16(TextFormatCode) default: - if _, ok := arguments[i].(BinaryEncoder); ok { + switch oid { + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid: wbuf.WriteInt16(BinaryFormatCode) - } else { + case TextOid, VarcharOid, DateOid, TimestampTzOid: wbuf.WriteInt16(TextFormatCode) + default: + return SerializationError(fmt.Sprintf("Parameter %d oid %d is not a core type and argument type %T does not implement TextEncoder or BinaryEncoder", i, oid, arg)) } } } @@ -755,41 +758,40 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} continue } - switch oid { - case BoolOid: - err = encodeBool(wbuf, arguments[i]) - case ByteaOid: - err = encodeBytea(wbuf, arguments[i]) - case Int2Oid: - err = encodeInt2(wbuf, arguments[i]) - case Int4Oid: - err = encodeInt4(wbuf, arguments[i]) - case Int8Oid: - err = encodeInt8(wbuf, arguments[i]) - case Float4Oid: - err = encodeFloat4(wbuf, arguments[i]) - case Float8Oid: - err = encodeFloat8(wbuf, arguments[i]) - case TextOid, VarcharOid: - err = encodeText(wbuf, arguments[i]) - case DateOid: - err = encodeDate(wbuf, arguments[i]) - case TimestampTzOid: - err = encodeTimestampTz(wbuf, arguments[i]) + switch arg := arguments[i].(type) { + case BinaryEncoder: + err = arg.EncodeBinary(wbuf) + case TextEncoder: + var s string + s, err = arg.EncodeText() + wbuf.WriteInt32(int32(len(s))) + wbuf.WriteBytes([]byte(s)) default: - switch arg := arguments[i].(type) { - case BinaryEncoder: - err = arg.EncodeBinary(wbuf) - case TextEncoder: - var s string - s, err = arg.EncodeText() - wbuf.WriteInt32(int32(len(s))) - wbuf.WriteBytes([]byte(s)) + switch oid { + case BoolOid: + err = encodeBool(wbuf, arguments[i]) + case ByteaOid: + err = encodeBytea(wbuf, arguments[i]) + case Int2Oid: + err = encodeInt2(wbuf, arguments[i]) + case Int4Oid: + err = encodeInt4(wbuf, arguments[i]) + case Int8Oid: + err = encodeInt8(wbuf, arguments[i]) + case Float4Oid: + err = encodeFloat4(wbuf, arguments[i]) + case Float8Oid: + err = encodeFloat8(wbuf, arguments[i]) + case TextOid, VarcharOid: + err = encodeText(wbuf, arguments[i]) + case DateOid: + err = encodeDate(wbuf, arguments[i]) + case TimestampTzOid: + err = encodeTimestampTz(wbuf, arguments[i]) default: return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg)) } } - if err != nil { return err } diff --git a/conn_test.go b/conn_test.go index 218d0200..901fddfd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -604,6 +604,33 @@ func TestQueryPreparedEncodeError(t *testing.T) { } } +// Ensure that an argument that implements TextEncoder, but not BinaryEncoder +// works when the parameter type is a core type. +type coreTextEncoder struct{} + +func (n *coreTextEncoder) EncodeText() (string, error) { + return "42", nil +} + +func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustPrepare(t, conn, "testTranscode", "select $1::integer") + + var n int32 + err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n) + if err != nil { + t.Fatalf("Unexpected conn.QueryRow error: %v", err) + } + + if n != 42 { + t.Errorf("Expected 42, got %v", n) + } +} + func TestPrepare(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index 5cb07d0e..603a7731 100644 --- a/values.go +++ b/values.go @@ -65,6 +65,33 @@ type NullInt64 struct { Valid bool // Valid is true if Int64 is not NULL } +func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error { + if size == -1 { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + n.Int64 = decodeInt8(rows, fd, size) + return rows.Err() +} + +func (n *NullInt64) EncodeText() (string, error) { + if n.Valid { + return strconv.FormatInt(int64(n.Int64), 10), nil + } else { + return "null", nil + } +} + +func (n *NullInt64) EncodeBinary(w *WriteBuf) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeInt8(w, n.Int64) +} + var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) // QuoteString escapes and quotes a string making it safe for interpolation @@ -96,70 +123,55 @@ func SanitizeSql(sql string, args ...interface{}) (output string, err error) { return } - switch arg := args[n-1].(type) { - case string: - return QuoteString(arg) - case int: - return strconv.FormatInt(int64(arg), 10) - case int8: - return strconv.FormatInt(int64(arg), 10) - case int16: - return strconv.FormatInt(int64(arg), 10) - case int32: - return strconv.FormatInt(int64(arg), 10) - case int64: - return strconv.FormatInt(int64(arg), 10) - case time.Time: - return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")) - case uint: - return strconv.FormatUint(uint64(arg), 10) - case uint8: - return strconv.FormatUint(uint64(arg), 10) - case uint16: - return strconv.FormatUint(uint64(arg), 10) - case uint32: - return strconv.FormatUint(uint64(arg), 10) - case uint64: - return strconv.FormatUint(uint64(arg), 10) - case float32: - return strconv.FormatFloat(float64(arg), 'f', -1, 32) - case float64: - return strconv.FormatFloat(arg, 'f', -1, 64) - case bool: - return strconv.FormatBool(arg) - case []byte: - return `E'\\x` + hex.EncodeToString(arg) + `'` - case nil: - return "null" - case TextEncoder: - var s string - s, err = arg.EncodeText() - return s - default: - err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg)) - return "" - } + var s string + s, err = sanitizeArg(args[n-1]) + return s } output = literalPattern.ReplaceAllStringFunc(sql, replacer) return } -func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if size == -1 { - n.Int64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int64 = decodeInt8(rows, fd, size) - return rows.Err() -} - -func (n *NullInt64) EncodeText() (string, error) { - if n.Valid { - return strconv.FormatInt(int64(n.Int64), 10), nil - } else { +func sanitizeArg(arg interface{}) (string, error) { + switch arg := arg.(type) { + case string: + return QuoteString(arg), nil + case int: + return strconv.FormatInt(int64(arg), 10), nil + case int8: + return strconv.FormatInt(int64(arg), 10), nil + case int16: + return strconv.FormatInt(int64(arg), 10), nil + case int32: + return strconv.FormatInt(int64(arg), 10), nil + case int64: + return strconv.FormatInt(int64(arg), 10), nil + case time.Time: + return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil + case uint: + return strconv.FormatUint(uint64(arg), 10), nil + case uint8: + return strconv.FormatUint(uint64(arg), 10), nil + case uint16: + return strconv.FormatUint(uint64(arg), 10), nil + case uint32: + return strconv.FormatUint(uint64(arg), 10), nil + case uint64: + return strconv.FormatUint(uint64(arg), 10), nil + case float32: + return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil + case float64: + return strconv.FormatFloat(arg, 'f', -1, 64), nil + case bool: + return strconv.FormatBool(arg), nil + case []byte: + return `E'\\x` + hex.EncodeToString(arg) + `'`, nil + case nil: return "null", nil + case TextEncoder: + return arg.EncodeText() + default: + return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg)) } } diff --git a/values_test.go b/values_test.go index 2285ad5d..f36cef33 100644 --- a/values_test.go +++ b/values_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "fmt" "github.com/jackc/pgx" "strings" "testing" @@ -185,3 +186,46 @@ func TestTimestampTzTranscode(t *testing.T) { t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) } } + +func TestNullX(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + i64 pgx.NullInt64 + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, + {"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, + } + + for i, tt := range tests { + psName := fmt.Sprintf("success%d", i) + mustPrepare(t, conn, psName, tt.sql) + + for _, sql := range []string{tt.sql, psName} { + actual = zero + + err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs) + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } + } +}