Add more testing of Encode*
Handle case where TextEncoder is used to a core type that the driver could otherwise have handled as binary.
This commit is contained in:
@@ -734,16 +734,19 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
|
|
||||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
switch oid {
|
switch arg := arguments[i].(type) {
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
|
case BinaryEncoder:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
case TextOid, VarcharOid, DateOid, TimestampTzOid:
|
case TextEncoder:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
default:
|
default:
|
||||||
if _, ok := arguments[i].(BinaryEncoder); ok {
|
switch oid {
|
||||||
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
} else {
|
case TextOid, VarcharOid, DateOid, TimestampTzOid:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
|
default:
|
||||||
|
return SerializationError(fmt.Sprintf("Parameter %d oid %d is not a core type and argument type %T does not implement TextEncoder or BinaryEncoder", i, oid, arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -755,41 +758,40 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch oid {
|
switch arg := arguments[i].(type) {
|
||||||
case BoolOid:
|
case BinaryEncoder:
|
||||||
err = encodeBool(wbuf, arguments[i])
|
err = arg.EncodeBinary(wbuf)
|
||||||
case ByteaOid:
|
case TextEncoder:
|
||||||
err = encodeBytea(wbuf, arguments[i])
|
var s string
|
||||||
case Int2Oid:
|
s, err = arg.EncodeText()
|
||||||
err = encodeInt2(wbuf, arguments[i])
|
wbuf.WriteInt32(int32(len(s)))
|
||||||
case Int4Oid:
|
wbuf.WriteBytes([]byte(s))
|
||||||
err = encodeInt4(wbuf, arguments[i])
|
|
||||||
case Int8Oid:
|
|
||||||
err = encodeInt8(wbuf, arguments[i])
|
|
||||||
case Float4Oid:
|
|
||||||
err = encodeFloat4(wbuf, arguments[i])
|
|
||||||
case Float8Oid:
|
|
||||||
err = encodeFloat8(wbuf, arguments[i])
|
|
||||||
case TextOid, VarcharOid:
|
|
||||||
err = encodeText(wbuf, arguments[i])
|
|
||||||
case DateOid:
|
|
||||||
err = encodeDate(wbuf, arguments[i])
|
|
||||||
case TimestampTzOid:
|
|
||||||
err = encodeTimestampTz(wbuf, arguments[i])
|
|
||||||
default:
|
default:
|
||||||
switch arg := arguments[i].(type) {
|
switch oid {
|
||||||
case BinaryEncoder:
|
case BoolOid:
|
||||||
err = arg.EncodeBinary(wbuf)
|
err = encodeBool(wbuf, arguments[i])
|
||||||
case TextEncoder:
|
case ByteaOid:
|
||||||
var s string
|
err = encodeBytea(wbuf, arguments[i])
|
||||||
s, err = arg.EncodeText()
|
case Int2Oid:
|
||||||
wbuf.WriteInt32(int32(len(s)))
|
err = encodeInt2(wbuf, arguments[i])
|
||||||
wbuf.WriteBytes([]byte(s))
|
case Int4Oid:
|
||||||
|
err = encodeInt4(wbuf, arguments[i])
|
||||||
|
case Int8Oid:
|
||||||
|
err = encodeInt8(wbuf, arguments[i])
|
||||||
|
case Float4Oid:
|
||||||
|
err = encodeFloat4(wbuf, arguments[i])
|
||||||
|
case Float8Oid:
|
||||||
|
err = encodeFloat8(wbuf, arguments[i])
|
||||||
|
case TextOid, VarcharOid:
|
||||||
|
err = encodeText(wbuf, arguments[i])
|
||||||
|
case DateOid:
|
||||||
|
err = encodeDate(wbuf, arguments[i])
|
||||||
|
case TimestampTzOid:
|
||||||
|
err = encodeTimestampTz(wbuf, arguments[i])
|
||||||
default:
|
default:
|
||||||
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg))
|
return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder or BinaryEncoder", arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -604,6 +604,33 @@ func TestQueryPreparedEncodeError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that an argument that implements TextEncoder, but not BinaryEncoder
|
||||||
|
// works when the parameter type is a core type.
|
||||||
|
type coreTextEncoder struct{}
|
||||||
|
|
||||||
|
func (n *coreTextEncoder) EncodeText() (string, error) {
|
||||||
|
return "42", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustPrepare(t, conn, "testTranscode", "select $1::integer")
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected conn.QueryRow error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != 42 {
|
||||||
|
t.Errorf("Expected 42, got %v", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrepare(t *testing.T) {
|
func TestPrepare(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,33 @@ type NullInt64 struct {
|
|||||||
Valid bool // Valid is true if Int64 is not NULL
|
Valid bool // Valid is true if Int64 is not NULL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
||||||
|
if size == -1 {
|
||||||
|
n.Int64, n.Valid = 0, false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n.Valid = true
|
||||||
|
n.Int64 = decodeInt8(rows, fd, size)
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NullInt64) EncodeText() (string, error) {
|
||||||
|
if n.Valid {
|
||||||
|
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||||
|
} else {
|
||||||
|
return "null", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NullInt64) EncodeBinary(w *WriteBuf) error {
|
||||||
|
if !n.Valid {
|
||||||
|
w.WriteInt32(-1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return encodeInt8(w, n.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||||
|
|
||||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||||
@@ -96,70 +123,55 @@ func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch arg := args[n-1].(type) {
|
var s string
|
||||||
case string:
|
s, err = sanitizeArg(args[n-1])
|
||||||
return QuoteString(arg)
|
return s
|
||||||
case int:
|
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
|
||||||
case int8:
|
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
|
||||||
case int16:
|
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
|
||||||
case int32:
|
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
|
||||||
case int64:
|
|
||||||
return strconv.FormatInt(int64(arg), 10)
|
|
||||||
case time.Time:
|
|
||||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
|
||||||
case uint:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10)
|
|
||||||
case uint8:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10)
|
|
||||||
case uint16:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10)
|
|
||||||
case uint32:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10)
|
|
||||||
case uint64:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10)
|
|
||||||
case float32:
|
|
||||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
|
||||||
case float64:
|
|
||||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
|
||||||
case bool:
|
|
||||||
return strconv.FormatBool(arg)
|
|
||||||
case []byte:
|
|
||||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
|
||||||
case nil:
|
|
||||||
return "null"
|
|
||||||
case TextEncoder:
|
|
||||||
var s string
|
|
||||||
s, err = arg.EncodeText()
|
|
||||||
return s
|
|
||||||
default:
|
|
||||||
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
func sanitizeArg(arg interface{}) (string, error) {
|
||||||
if size == -1 {
|
switch arg := arg.(type) {
|
||||||
n.Int64, n.Valid = 0, false
|
case string:
|
||||||
return nil
|
return QuoteString(arg), nil
|
||||||
}
|
case int:
|
||||||
n.Valid = true
|
return strconv.FormatInt(int64(arg), 10), nil
|
||||||
n.Int64 = decodeInt8(rows, fd, size)
|
case int8:
|
||||||
return rows.Err()
|
return strconv.FormatInt(int64(arg), 10), nil
|
||||||
}
|
case int16:
|
||||||
|
return strconv.FormatInt(int64(arg), 10), nil
|
||||||
func (n *NullInt64) EncodeText() (string, error) {
|
case int32:
|
||||||
if n.Valid {
|
return strconv.FormatInt(int64(arg), 10), nil
|
||||||
return strconv.FormatInt(int64(n.Int64), 10), nil
|
case int64:
|
||||||
} else {
|
return strconv.FormatInt(int64(arg), 10), nil
|
||||||
|
case time.Time:
|
||||||
|
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
|
||||||
|
case uint:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10), nil
|
||||||
|
case uint8:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10), nil
|
||||||
|
case uint16:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10), nil
|
||||||
|
case uint32:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10), nil
|
||||||
|
case uint64:
|
||||||
|
return strconv.FormatUint(uint64(arg), 10), nil
|
||||||
|
case float32:
|
||||||
|
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
|
||||||
|
case float64:
|
||||||
|
return strconv.FormatFloat(arg, 'f', -1, 64), nil
|
||||||
|
case bool:
|
||||||
|
return strconv.FormatBool(arg), nil
|
||||||
|
case []byte:
|
||||||
|
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
|
||||||
|
case nil:
|
||||||
return "null", nil
|
return "null", nil
|
||||||
|
case TextEncoder:
|
||||||
|
return arg.EncodeText()
|
||||||
|
default:
|
||||||
|
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -185,3 +186,46 @@ func TestTimestampTzTranscode(t *testing.T) {
|
|||||||
t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
|
t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNullX(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
type allTypes struct {
|
||||||
|
i64 pgx.NullInt64
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual, zero allTypes
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
queryArgs []interface{}
|
||||||
|
scanArgs []interface{}
|
||||||
|
expected allTypes
|
||||||
|
}{
|
||||||
|
{"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
|
||||||
|
{"select $1::int8", []interface{}{&pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
psName := fmt.Sprintf("success%d", i)
|
||||||
|
mustPrepare(t, conn, psName, tt.sql)
|
||||||
|
|
||||||
|
for _, sql := range []string{tt.sql, psName} {
|
||||||
|
actual = zero
|
||||||
|
|
||||||
|
err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual != tt.expected {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user