Add more testing of Encode*
Handle case where TextEncoder is used to a core type that the driver could otherwise have handled as binary.
This commit is contained in:
@@ -65,6 +65,33 @@ type NullInt64 struct {
|
||||
Valid bool // Valid is true if Int64 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
||||
if size == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int64 = decodeInt8(rows, fd, size)
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (n *NullInt64) EncodeText() (string, error) {
|
||||
if n.Valid {
|
||||
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||
} else {
|
||||
return "null", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NullInt64) EncodeBinary(w *WriteBuf) error {
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
return encodeInt8(w, n.Int64)
|
||||
}
|
||||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||
@@ -96,70 +123,55 @@ func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
switch arg := args[n-1].(type) {
|
||||
case string:
|
||||
return QuoteString(arg)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case time.Time:
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
case nil:
|
||||
return "null"
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
return s
|
||||
default:
|
||||
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
s, err = sanitizeArg(args[n-1])
|
||||
return s
|
||||
}
|
||||
|
||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
||||
if size == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int64 = decodeInt8(rows, fd, size)
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (n *NullInt64) EncodeText() (string, error) {
|
||||
if n.Valid {
|
||||
return strconv.FormatInt(int64(n.Int64), 10), nil
|
||||
} else {
|
||||
func sanitizeArg(arg interface{}) (string, error) {
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
return QuoteString(arg), nil
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10), nil
|
||||
case time.Time:
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10), nil
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64), nil
|
||||
case bool:
|
||||
return strconv.FormatBool(arg), nil
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
|
||||
case nil:
|
||||
return "null", nil
|
||||
case TextEncoder:
|
||||
return arg.EncodeText()
|
||||
default:
|
||||
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user