2
0

Add more testing of Encode*

Handle case where TextEncoder is used to a core type that the driver
could otherwise have handled as binary.
This commit is contained in:
Jack Christensen
2014-07-11 11:16:12 -05:00
parent 6884fdfb52
commit 24395d98df
4 changed files with 177 additions and 92 deletions
+37 -35
View File
@@ -734,16 +734,19 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
for i, oid := range ps.ParameterOids {
switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
switch arg := arguments[i].(type) {
case BinaryEncoder:
wbuf.WriteInt16(BinaryFormatCode)
case TextOid, VarcharOid, DateOid, TimestampTzOid:
case TextEncoder:
wbuf.WriteInt16(TextFormatCode)
default:
if _, ok := arguments[i].(BinaryEncoder); ok {
switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
wbuf.WriteInt16(BinaryFormatCode)
} else {
case TextOid, VarcharOid, DateOid, TimestampTzOid:
wbuf.WriteInt16(TextFormatCode)
default:
return SerializationError(fmt.Sprintf("Parameter %d oid %d is not a core type and argument type %T does not implement TextEncoder or BinaryEncoder", i, oid, arg))
}
}
}
@@ -755,41 +758,40 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
continue
}
switch oid {
case BoolOid:
err = encodeBool(wbuf, arguments[i])
case ByteaOid:
err = encodeBytea(wbuf, arguments[i])
case Int2Oid:
err = encodeInt2(wbuf, arguments[i])
case Int4Oid:
err = encodeInt4(wbuf, arguments[i])
case Int8Oid:
err = encodeInt8(wbuf, arguments[i])
case Float4Oid:
err = encodeFloat4(wbuf, arguments[i])
case Float8Oid:
err = encodeFloat8(wbuf, arguments[i])
case TextOid, VarcharOid:
err = encodeText(wbuf, arguments[i])
case DateOid:
err = encodeDate(wbuf, arguments[i])
case TimestampTzOid:
err = encodeTimestampTz(wbuf, arguments[i])
switch arg := arguments[i].(type) {
case BinaryEncoder:
err = arg.EncodeBinary(wbuf)
case TextEncoder:
var s string
s, err = arg.EncodeText()
wbuf.WriteInt32(int32(len(s)))
wbuf.WriteBytes([]byte(s))
default:
switch arg := arguments[i].(type) {
case BinaryEncoder:
err = arg.EncodeBinary(wbuf)
case TextEncoder:
var s string
s, err = arg.EncodeText()
wbuf.WriteInt32(int32(len(s)))
wbuf.WriteBytes([]byte(s))
switch oid {
case BoolOid:
err = encodeBool(wbuf, arguments[i])
case ByteaOid:
err = encodeBytea(wbuf, arguments[i])
case Int2Oid:
err = encodeInt2(wbuf, arguments[i])
case Int4Oid:
err = encodeInt4(wbuf, arguments[i])
case Int8Oid:
err = encodeInt8(wbuf, arguments[i])
case Float4Oid:
err = encodeFloat4(wbuf, arguments[i])
case Float8Oid:
err = encodeFloat8(wbuf, arguments[i])
case TextOid, VarcharOid:
err = encodeText(wbuf, arguments[i])
case DateOid:
err = encodeDate(wbuf, arguments[i])
case TimestampTzOid:
err = encodeTimestampTz(wbuf, arguments[i])
default:
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg))
}
}
if err != nil {
return err
}
+27
View File
@@ -604,6 +604,33 @@ func TestQueryPreparedEncodeError(t *testing.T) {
}
}
// Ensure that an argument that implements TextEncoder, but not BinaryEncoder
// works when the parameter type is a core type.
type coreTextEncoder struct{}
func (n *coreTextEncoder) EncodeText() (string, error) {
return "42", nil
}
func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustPrepare(t, conn, "testTranscode", "select $1::integer")
var n int32
err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n)
if err != nil {
t.Fatalf("Unexpected conn.QueryRow error: %v", err)
}
if n != 42 {
t.Errorf("Expected 42, got %v", n)
}
}
func TestPrepare(t *testing.T) {
t.Parallel()
+69 -57
View File
@@ -65,6 +65,33 @@ type NullInt64 struct {
Valid bool // Valid is true if Int64 is not NULL
}
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
if size == -1 {
n.Int64, n.Valid = 0, false
return nil
}
n.Valid = true
n.Int64 = decodeInt8(rows, fd, size)
return rows.Err()
}
func (n *NullInt64) EncodeText() (string, error) {
if n.Valid {
return strconv.FormatInt(int64(n.Int64), 10), nil
} else {
return "null", nil
}
}
func (n *NullInt64) EncodeBinary(w *WriteBuf) error {
if !n.Valid {
w.WriteInt32(-1)
return nil
}
return encodeInt8(w, n.Int64)
}
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
// QuoteString escapes and quotes a string making it safe for interpolation
@@ -96,70 +123,55 @@ func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
return
}
switch arg := args[n-1].(type) {
case string:
return QuoteString(arg)
case int:
return strconv.FormatInt(int64(arg), 10)
case int8:
return strconv.FormatInt(int64(arg), 10)
case int16:
return strconv.FormatInt(int64(arg), 10)
case int32:
return strconv.FormatInt(int64(arg), 10)
case int64:
return strconv.FormatInt(int64(arg), 10)
case time.Time:
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
case uint:
return strconv.FormatUint(uint64(arg), 10)
case uint8:
return strconv.FormatUint(uint64(arg), 10)
case uint16:
return strconv.FormatUint(uint64(arg), 10)
case uint32:
return strconv.FormatUint(uint64(arg), 10)
case uint64:
return strconv.FormatUint(uint64(arg), 10)
case float32:
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
case float64:
return strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
return strconv.FormatBool(arg)
case []byte:
return `E'\\x` + hex.EncodeToString(arg) + `'`
case nil:
return "null"
case TextEncoder:
var s string
s, err = arg.EncodeText()
return s
default:
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
return ""
}
var s string
s, err = sanitizeArg(args[n-1])
return s
}
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
return
}
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
if size == -1 {
n.Int64, n.Valid = 0, false
return nil
}
n.Valid = true
n.Int64 = decodeInt8(rows, fd, size)
return rows.Err()
}
func (n *NullInt64) EncodeText() (string, error) {
if n.Valid {
return strconv.FormatInt(int64(n.Int64), 10), nil
} else {
func sanitizeArg(arg interface{}) (string, error) {
switch arg := arg.(type) {
case string:
return QuoteString(arg), nil
case int:
return strconv.FormatInt(int64(arg), 10), nil
case int8:
return strconv.FormatInt(int64(arg), 10), nil
case int16:
return strconv.FormatInt(int64(arg), 10), nil
case int32:
return strconv.FormatInt(int64(arg), 10), nil
case int64:
return strconv.FormatInt(int64(arg), 10), nil
case time.Time:
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
case uint:
return strconv.FormatUint(uint64(arg), 10), nil
case uint8:
return strconv.FormatUint(uint64(arg), 10), nil
case uint16:
return strconv.FormatUint(uint64(arg), 10), nil
case uint32:
return strconv.FormatUint(uint64(arg), 10), nil
case uint64:
return strconv.FormatUint(uint64(arg), 10), nil
case float32:
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
case float64:
return strconv.FormatFloat(arg, 'f', -1, 64), nil
case bool:
return strconv.FormatBool(arg), nil
case []byte:
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
case nil:
return "null", nil
case TextEncoder:
return arg.EncodeText()
default:
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
}
}
+44
View File
@@ -1,6 +1,7 @@
package pgx_test
import (
"fmt"
"github.com/jackc/pgx"
"strings"
"testing"
@@ -185,3 +186,46 @@ func TestTimestampTzTranscode(t *testing.T) {
t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
}
}
func TestNullX(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
type allTypes struct {
i64 pgx.NullInt64
}
var actual, zero allTypes
tests := []struct {
sql string
queryArgs []interface{}
scanArgs []interface{}
expected allTypes
}{
{"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
{"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
}
for i, tt := range tests {
psName := fmt.Sprintf("success%d", i)
mustPrepare(t, conn, psName, tt.sql)
for _, sql := range []string{tt.sql, psName} {
actual = zero
err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs)
}
if actual != tt.expected {
t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs)
}
ensureConnValid(t, conn)
}
}
}