Add more testing of Encode*
Handle case where TextEncoder is used to a core type that the driver could otherwise have handled as binary.
This commit is contained in:
@@ -734,16 +734,19 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||
|
||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
|
||||
switch arg := arguments[i].(type) {
|
||||
case BinaryEncoder:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
case TextOid, VarcharOid, DateOid, TimestampTzOid:
|
||||
case TextEncoder:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
if _, ok := arguments[i].(BinaryEncoder); ok {
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
} else {
|
||||
case TextOid, VarcharOid, DateOid, TimestampTzOid:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
return SerializationError(fmt.Sprintf("Parameter %d oid %d is not a core type and argument type %T does not implement TextEncoder or BinaryEncoder", i, oid, arg))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -755,41 +758,40 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||
continue
|
||||
}
|
||||
|
||||
switch oid {
|
||||
case BoolOid:
|
||||
err = encodeBool(wbuf, arguments[i])
|
||||
case ByteaOid:
|
||||
err = encodeBytea(wbuf, arguments[i])
|
||||
case Int2Oid:
|
||||
err = encodeInt2(wbuf, arguments[i])
|
||||
case Int4Oid:
|
||||
err = encodeInt4(wbuf, arguments[i])
|
||||
case Int8Oid:
|
||||
err = encodeInt8(wbuf, arguments[i])
|
||||
case Float4Oid:
|
||||
err = encodeFloat4(wbuf, arguments[i])
|
||||
case Float8Oid:
|
||||
err = encodeFloat8(wbuf, arguments[i])
|
||||
case TextOid, VarcharOid:
|
||||
err = encodeText(wbuf, arguments[i])
|
||||
case DateOid:
|
||||
err = encodeDate(wbuf, arguments[i])
|
||||
case TimestampTzOid:
|
||||
err = encodeTimestampTz(wbuf, arguments[i])
|
||||
switch arg := arguments[i].(type) {
|
||||
case BinaryEncoder:
|
||||
err = arg.EncodeBinary(wbuf)
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
wbuf.WriteInt32(int32(len(s)))
|
||||
wbuf.WriteBytes([]byte(s))
|
||||
default:
|
||||
switch arg := arguments[i].(type) {
|
||||
case BinaryEncoder:
|
||||
err = arg.EncodeBinary(wbuf)
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
wbuf.WriteInt32(int32(len(s)))
|
||||
wbuf.WriteBytes([]byte(s))
|
||||
switch oid {
|
||||
case BoolOid:
|
||||
err = encodeBool(wbuf, arguments[i])
|
||||
case ByteaOid:
|
||||
err = encodeBytea(wbuf, arguments[i])
|
||||
case Int2Oid:
|
||||
err = encodeInt2(wbuf, arguments[i])
|
||||
case Int4Oid:
|
||||
err = encodeInt4(wbuf, arguments[i])
|
||||
case Int8Oid:
|
||||
err = encodeInt8(wbuf, arguments[i])
|
||||
case Float4Oid:
|
||||
err = encodeFloat4(wbuf, arguments[i])
|
||||
case Float8Oid:
|
||||
err = encodeFloat8(wbuf, arguments[i])
|
||||
case TextOid, VarcharOid:
|
||||
err = encodeText(wbuf, arguments[i])
|
||||
case DateOid:
|
||||
err = encodeDate(wbuf, arguments[i])
|
||||
case TimestampTzOid:
|
||||
err = encodeTimestampTz(wbuf, arguments[i])
|
||||
default:
|
||||
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg))
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -604,6 +604,33 @@ func TestQueryPreparedEncodeError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that an argument that implements TextEncoder, but not BinaryEncoder
|
||||
// works when the parameter type is a core type.
|
||||
type coreTextEncoder struct{}
|
||||
|
||||
func (n *coreTextEncoder) EncodeText() (string, error) {
|
||||
return "42", nil
|
||||
}
|
||||
|
||||
func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustPrepare(t, conn, "testTranscode", "select $1::integer")
|
||||
|
||||
var n int32
|
||||
err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected conn.QueryRow error: %v", err)
|
||||
}
|
||||
|
||||
if n != 42 {
|
||||
t.Errorf("Expected 42, got %v", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -65,6 +65,33 @@ type NullInt64 struct {
|
||||
Valid bool // Valid is true if Int64 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
||||
if size == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int64 = decodeInt8(rows, fd, size)
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (n *NullInt64) EncodeText() (string, error) {
|
||||
if n.Valid {
|
||||
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||
} else {
|
||||
return "null", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NullInt64) EncodeBinary(w *WriteBuf) error {
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
return encodeInt8(w, n.Int64)
|
||||
}
|
||||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||
@@ -96,70 +123,55 @@ func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
switch arg := args[n-1].(type) {
|
||||
case string:
|
||||
return QuoteString(arg)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case time.Time:
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
case nil:
|
||||
return "null"
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
return s
|
||||
default:
|
||||
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
s, err = sanitizeArg(args[n-1])
|
||||
return s
|
||||
}
|
||||
|
||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
||||
if size == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int64 = decodeInt8(rows, fd, size)
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (n *NullInt64) EncodeText() (string, error) {
|
||||
if n.Valid {
|
||||
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||
} else {
|
||||
func sanitizeArg(arg interface{}) (string, error) {
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
return QuoteString(arg), nil
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case time.Time:
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64), nil
|
||||
case bool:
|
||||
return strconv.FormatBool(arg), nil
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
|
||||
case nil:
|
||||
return "null", nil
|
||||
case TextEncoder:
|
||||
return arg.EncodeText()
|
||||
default:
|
||||
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -185,3 +186,46 @@ func TestTimestampTzTranscode(t *testing.T) {
|
||||
t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullX(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type allTypes struct {
|
||||
i64 pgx.NullInt64
|
||||
}
|
||||
|
||||
var actual, zero allTypes
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
queryArgs []interface{}
|
||||
scanArgs []interface{}
|
||||
expected allTypes
|
||||
}{
|
||||
{"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
|
||||
{"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
psName := fmt.Sprintf("success%d", i)
|
||||
mustPrepare(t, conn, psName, tt.sql)
|
||||
|
||||
for _, sql := range []string{tt.sql, psName} {
|
||||
actual = zero
|
||||
|
||||
err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs)
|
||||
}
|
||||
|
||||
if actual != tt.expected {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user