2
0

Always use bound parameters

PostgreSQL has two string syntaxes, one that allows backslash escapes and one
that does not (SQL standard conforming strings). By default PostgreSQL uses
standard conforming strings. QuoteString was only designed for use with
standard conforming strings. If PostgreSQL was configured with certain
combinations of the standard_conforming_strings and backslash_quote settings,
QuoteString may not correctly sanitize strings. QuoteString was only used in
unprepared queries, bound parameters are used for prepared queries.

This commit alters pgx to use always use bound parameters.

As a consequence of never doing string interpolation there is no need to have
separate Text and Binary encoders. There is now only the Encoder interface.

This change had a negative effect on the performance of simple unprepared
queries, but prepared statements should already be used for performance.

fixes #26

https://github.com/jackc/pgx/issues/26
This commit is contained in:
Jack Christensen
2014-07-18 14:44:34 -05:00
parent d57e4902a1
commit 61bf7d841a
9 changed files with 166 additions and 339 deletions
+46 -191
View File
@@ -4,9 +4,7 @@ import (
"encoding/hex"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"time"
"unsafe"
)
@@ -33,13 +31,6 @@ const (
BinaryFormatCode = 1
)
// EncodeText statuses
const (
NullText = iota
SafeText = iota
UnsafeText = iota
)
type SerializationError string
func (e SerializationError) Error() string {
@@ -49,32 +40,28 @@ func (e SerializationError) Error() string {
// Scanner is an interface used to decode values from the PostgreSQL server.
type Scanner interface {
// Scan MUST check r.Type().DataType and r.Type().FormatCode before decoding.
// It should not assume that it was called on the type of value.
// It should not assume that it was called on a data type or format that it
// understands.
Scan(r *ValueReader) error
}
// TextEncoder is an interface used to encode values in text format for
// transmission to the PostgreSQL server. It is used by unprepared
// queries and for prepared queries when the type does not implement
// BinaryEncoder
type TextEncoder interface {
// EncodeText returns the value encoded into a string. status must be
// NullText if the value is NULL, UnsafeText if the value should be quoted
// and escaped, or SafeText if the value should not be quoted.
EncodeText() (val string, status byte, err error)
}
// BinaryEncoder is an interface used to encode values in binary format for
// transmission to the PostgreSQL server. It is used by prepared queries.
type BinaryEncoder interface {
// EncodeBinary writes the binary value to w.
// Encoder is an interface used to encode values for transmission to the
// PostgreSQL server.
type Encoder interface {
// Encode writes the value to w.
//
// EncodeBinary MUST check oid to see if the parameter data type is
// compatible. If this is not done, the PostgreSQL server may detect the
// error if the expected data size or format of the encoded data does not
// match. But if the encoded data is a valid representation of the data type
// PostgreSQL expects such as date and int4, incorrect data may be stored.
EncodeBinary(w *WriteBuf, oid Oid) error
// If the value is NULL an int32(-1) should be written.
//
// Encode MUST check oid to see if the parameter data type is compatible. If
// this is not done, the PostgreSQL server may detect the error if the
// expected data size or format of the encoded data does not match. But if
// the encoded data is a valid representation of the data type PostgreSQL
// expects such as date and int4, incorrect data may be stored.
Encode(w *WriteBuf, oid Oid) error
// FormatCode returns the format that the encoder writes the value. It must be
// either pgx.TextFormatCode or pgx.BinaryFormatCode.
FormatCode() int16
}
// NullFloat32 represents an float4 that may be null.
@@ -102,17 +89,11 @@ func (n *NullFloat32) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullFloat32) EncodeText() (string, byte, error) {
if n.Valid {
return strconv.FormatFloat(float64(n.Float32), 'f', -1, 32), SafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode }
func (n NullFloat32) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error {
if oid != Float4Oid {
return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot encode into OID %d", oid))
return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid))
}
if !n.Valid {
@@ -148,15 +129,9 @@ func (n *NullFloat64) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullFloat64) EncodeText() (string, byte, error) {
if n.Valid {
return strconv.FormatFloat(n.Float64, 'f', -1, 64), SafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode }
func (n NullFloat64) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error {
if oid != Float8Oid {
return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", oid))
}
@@ -193,12 +168,15 @@ func (s *NullString) Scan(vr *ValueReader) error {
return vr.Err()
}
func (s NullString) EncodeText() (string, byte, error) {
if s.Valid {
return s.String, UnsafeText, nil
} else {
return "", NullText, nil
func (n NullString) FormatCode() int16 { return TextFormatCode }
func (s NullString) Encode(w *WriteBuf, oid Oid) error {
if !s.Valid {
w.WriteInt32(-1)
return nil
}
return encodeText(w, s.String)
}
// NullInt16 represents an smallint that may be null.
@@ -226,17 +204,11 @@ func (n *NullInt16) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullInt16) EncodeText() (string, byte, error) {
if n.Valid {
return strconv.FormatInt(int64(n.Int16), 10), SafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullInt16) FormatCode() int16 { return BinaryFormatCode }
func (n NullInt16) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullInt16) Encode(w *WriteBuf, oid Oid) error {
if oid != Int2Oid {
return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", oid))
return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid))
}
if !n.Valid {
@@ -272,17 +244,11 @@ func (n *NullInt32) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullInt32) EncodeText() (string, byte, error) {
if n.Valid {
return strconv.FormatInt(int64(n.Int32), 10), SafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullInt32) FormatCode() int16 { return BinaryFormatCode }
func (n NullInt32) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullInt32) Encode(w *WriteBuf, oid Oid) error {
if oid != Int4Oid {
return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", oid))
return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid))
}
if !n.Valid {
@@ -318,17 +284,11 @@ func (n *NullInt64) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullInt64) EncodeText() (string, byte, error) {
if n.Valid {
return strconv.FormatInt(int64(n.Int64), 10), SafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullInt64) FormatCode() int16 { return BinaryFormatCode }
func (n NullInt64) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullInt64) Encode(w *WriteBuf, oid Oid) error {
if oid != Int8Oid {
return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", oid))
return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid))
}
if !n.Valid {
@@ -364,17 +324,11 @@ func (n *NullBool) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullBool) EncodeText() (string, byte, error) {
if n.Valid {
return strconv.FormatBool(n.Bool), SafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullBool) FormatCode() int16 { return BinaryFormatCode }
func (n NullBool) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullBool) Encode(w *WriteBuf, oid Oid) error {
if oid != BoolOid {
return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", oid))
return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid))
}
if !n.Valid {
@@ -412,17 +366,11 @@ func (n *NullTime) Scan(vr *ValueReader) error {
return vr.Err()
}
func (n NullTime) EncodeText() (string, byte, error) {
if n.Valid {
return n.Time.Format("2006-01-02 15:04:05.999999 -0700"), UnsafeText, nil
} else {
return "", NullText, nil
}
}
func (n NullTime) FormatCode() int16 { return BinaryFormatCode }
func (n NullTime) EncodeBinary(w *WriteBuf, oid Oid) error {
func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
if oid != TimestampTzOid {
return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", oid))
return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid))
}
if !n.Valid {
@@ -433,99 +381,6 @@ func (n NullTime) EncodeBinary(w *WriteBuf, oid Oid) error {
return encodeTimestampTz(w, n.Time)
}
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
// QuoteString escapes and quotes a string making it safe for interpolation
// into an SQL string.
func QuoteString(input string) (output string) {
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
return
}
// QuoteIdentifier escapes and quotes an identifier making it safe for
// interpolation into an SQL string
func QuoteIdentifier(input string) (output string) {
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
return
}
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
// appropriate.
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
replacer := func(match string) (replacement string) {
if err != nil {
return ""
}
n, _ := strconv.ParseInt(match[1:], 10, 0)
if int(n-1) >= len(args) {
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
return
}
var s string
s, err = sanitizeArg(args[n-1])
return s
}
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
return
}
func sanitizeArg(arg interface{}) (string, error) {
switch arg := arg.(type) {
case string:
return QuoteString(arg), nil
case int:
return strconv.FormatInt(int64(arg), 10), nil
case int8:
return strconv.FormatInt(int64(arg), 10), nil
case int16:
return strconv.FormatInt(int64(arg), 10), nil
case int32:
return strconv.FormatInt(int64(arg), 10), nil
case int64:
return strconv.FormatInt(int64(arg), 10), nil
case time.Time:
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
case uint:
return strconv.FormatUint(uint64(arg), 10), nil
case uint8:
return strconv.FormatUint(uint64(arg), 10), nil
case uint16:
return strconv.FormatUint(uint64(arg), 10), nil
case uint32:
return strconv.FormatUint(uint64(arg), 10), nil
case uint64:
return strconv.FormatUint(uint64(arg), 10), nil
case float32:
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
case float64:
return strconv.FormatFloat(arg, 'f', -1, 64), nil
case bool:
return strconv.FormatBool(arg), nil
case []byte:
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
case nil:
return "null", nil
case TextEncoder:
s, status, err := arg.EncodeText()
switch status {
case NullText:
return "null", err
case UnsafeText:
return QuoteString(s), err
case SafeText:
return s, err
default:
return "", SerializationError("Received invalid status from EncodeText")
}
default:
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
}
}
func decodeBool(vr *ValueReader) bool {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into bool"))