diff --git a/sanitize.go b/sanitize.go deleted file mode 100644 index 320af55b..00000000 --- a/sanitize.go +++ /dev/null @@ -1,106 +0,0 @@ -package pgx - -import ( - "encoding/hex" - "fmt" - "regexp" - "strconv" - "strings" - "time" -) - -type SerializationError string - -func (e SerializationError) Error() string { - return string(e) -} - -// TextEncoder is an interface used to encode values in text format for -// transmission to the PostgreSQL server. It is used by unprepared -// queries and for prepared queries when the type does not implement -// BinaryEncoder -type TextEncoder interface { - // EncodeText MUST sanitize (and quote, if necessary) the returned string. - // It will be interpolated directly into the SQL string. - EncodeText() (string, error) -} - -var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) - -// QuoteString escapes and quotes a string making it safe for interpolation -// into an SQL string. -func QuoteString(input string) (output string) { - output = "'" + strings.Replace(input, "'", "''", -1) + "'" - return -} - -// QuoteIdentifier escapes and quotes an identifier making it safe for -// interpolation into an SQL string -func QuoteIdentifier(input string) (output string) { - output = `"` + strings.Replace(input, `"`, `""`, -1) + `"` - return -} - -// SanitizeSql substitutely args positionaly into sql. Placeholder values are -// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as -// appropriate. -func SanitizeSql(sql string, args ...interface{}) (output string, err error) { - replacer := func(match string) (replacement string) { - if err != nil { - return "" - } - - n, _ := strconv.ParseInt(match[1:], 10, 0) - if int(n-1) >= len(args) { - err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args)) - return - } - - switch arg := args[n-1].(type) { - case string: - return QuoteString(arg) - case int: - return strconv.FormatInt(int64(arg), 10) - case int8: - return strconv.FormatInt(int64(arg), 10) - case int16: - return strconv.FormatInt(int64(arg), 10) - case int32: - return strconv.FormatInt(int64(arg), 10) - case int64: - return strconv.FormatInt(int64(arg), 10) - case time.Time: - return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")) - case uint: - return strconv.FormatUint(uint64(arg), 10) - case uint8: - return strconv.FormatUint(uint64(arg), 10) - case uint16: - return strconv.FormatUint(uint64(arg), 10) - case uint32: - return strconv.FormatUint(uint64(arg), 10) - case uint64: - return strconv.FormatUint(uint64(arg), 10) - case float32: - return strconv.FormatFloat(float64(arg), 'f', -1, 32) - case float64: - return strconv.FormatFloat(arg, 'f', -1, 64) - case bool: - return strconv.FormatBool(arg) - case []byte: - return `E'\\x` + hex.EncodeToString(arg) + `'` - case nil: - return "null" - case TextEncoder: - var s string - s, err = arg.EncodeText() - return s - default: - err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg)) - return "" - } - } - - output = literalPattern.ReplaceAllStringFunc(sql, replacer) - return -} diff --git a/sanitize_test.go b/sanitize_test.go deleted file mode 100644 index d3134dc3..00000000 --- a/sanitize_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package pgx_test - -import ( - "github.com/jackc/pgx" - "strings" - "testing" -) - -func TestQuoteString(t *testing.T) { - t.Parallel() - - if pgx.QuoteString("test") != "'test'" { - t.Error("Failed to quote string") - } - - if pgx.QuoteString("Jack's") != "'Jack''s'" { - t.Error("Failed to quote and escape string with embedded quote") - } -} - -func TestSanitizeSql(t *testing.T) { - t.Parallel() - - successTests := []struct { - sql string - args []interface{} - output string - }{ - {"select $1", []interface{}{nil}, "select null"}, - {"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"}, - {"select $1", []interface{}{42}, "select 42"}, - {"select $1", []interface{}{1.23}, "select 1.23"}, - {"select $1", []interface{}{true}, "select true"}, - {"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"}, - {"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`}, - {"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"}, - } - - for i, tt := range successTests { - san, err := pgx.SanitizeSql(tt.sql, tt.args...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args) - } - if san != tt.output { - t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args) - } - } - - errorTests := []struct { - sql string - args []interface{} - err string - }{ - {"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"}, - {"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"}, - } - - for i, tt := range errorTests { - _, err := pgx.SanitizeSql(tt.sql, tt.args...) - if err == nil { - t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err) - } - if !strings.Contains(err.Error(), tt.err) { - t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args) - } - } -} diff --git a/value_transcoder.go b/values.go similarity index 82% rename from value_transcoder.go rename to values.go index 3a9afc17..fdb84680 100644 --- a/value_transcoder.go +++ b/values.go @@ -6,6 +6,7 @@ import ( "math" "regexp" "strconv" + "strings" "time" "unsafe" ) @@ -29,10 +30,26 @@ const ( BinaryFormatCode = 1 ) +type SerializationError string + +func (e SerializationError) Error() string { + return string(e) +} + type Scanner interface { Scan(qr *QueryResult, fd *FieldDescription, size int32) error } +// TextEncoder is an interface used to encode values in text format for +// transmission to the PostgreSQL server. It is used by unprepared +// queries and for prepared queries when the type does not implement +// BinaryEncoder +type TextEncoder interface { + // EncodeText MUST sanitize (and quote, if necessary) the returned string. + // It will be interpolated directly into the SQL string. + EncodeText() (string, error) +} + // BinaryEncoder is an interface used to encode values in binary format for // transmission to the PostgreSQL server. It is used by prepared queries. type BinaryEncoder interface { @@ -46,6 +63,86 @@ type NullInt64 struct { Valid bool // Valid is true if Int64 is not NULL } +var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) + +// QuoteString escapes and quotes a string making it safe for interpolation +// into an SQL string. +func QuoteString(input string) (output string) { + output = "'" + strings.Replace(input, "'", "''", -1) + "'" + return +} + +// QuoteIdentifier escapes and quotes an identifier making it safe for +// interpolation into an SQL string +func QuoteIdentifier(input string) (output string) { + output = `"` + strings.Replace(input, `"`, `""`, -1) + `"` + return +} + +// SanitizeSql substitutely args positionaly into sql. Placeholder values are +// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as +// appropriate. +func SanitizeSql(sql string, args ...interface{}) (output string, err error) { + replacer := func(match string) (replacement string) { + if err != nil { + return "" + } + + n, _ := strconv.ParseInt(match[1:], 10, 0) + if int(n-1) >= len(args) { + err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args)) + return + } + + switch arg := args[n-1].(type) { + case string: + return QuoteString(arg) + case int: + return strconv.FormatInt(int64(arg), 10) + case int8: + return strconv.FormatInt(int64(arg), 10) + case int16: + return strconv.FormatInt(int64(arg), 10) + case int32: + return strconv.FormatInt(int64(arg), 10) + case int64: + return strconv.FormatInt(int64(arg), 10) + case time.Time: + return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")) + case uint: + return strconv.FormatUint(uint64(arg), 10) + case uint8: + return strconv.FormatUint(uint64(arg), 10) + case uint16: + return strconv.FormatUint(uint64(arg), 10) + case uint32: + return strconv.FormatUint(uint64(arg), 10) + case uint64: + return strconv.FormatUint(uint64(arg), 10) + case float32: + return strconv.FormatFloat(float64(arg), 'f', -1, 32) + case float64: + return strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + return strconv.FormatBool(arg) + case []byte: + return `E'\\x` + hex.EncodeToString(arg) + `'` + case nil: + return "null" + case TextEncoder: + var s string + s, err = arg.EncodeText() + return s + default: + err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg)) + return "" + } + } + + output = literalPattern.ReplaceAllStringFunc(sql, replacer) + return +} + func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error { if size == -1 { n.Int64, n.Valid = 0, false diff --git a/value_transcoder_test.go b/values_test.go similarity index 78% rename from value_transcoder_test.go rename to values_test.go index b86139c4..623d47ae 100644 --- a/value_transcoder_test.go +++ b/values_test.go @@ -1,10 +1,72 @@ package pgx_test import ( + "github.com/jackc/pgx" + "strings" "testing" "time" ) +func TestQuoteString(t *testing.T) { + t.Parallel() + + if pgx.QuoteString("test") != "'test'" { + t.Error("Failed to quote string") + } + + if pgx.QuoteString("Jack's") != "'Jack''s'" { + t.Error("Failed to quote and escape string with embedded quote") + } +} + +func TestSanitizeSql(t *testing.T) { + t.Parallel() + + successTests := []struct { + sql string + args []interface{} + output string + }{ + {"select $1", []interface{}{nil}, "select null"}, + {"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"}, + {"select $1", []interface{}{42}, "select 42"}, + {"select $1", []interface{}{1.23}, "select 1.23"}, + {"select $1", []interface{}{true}, "select true"}, + {"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"}, + {"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`}, + {"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"}, + } + + for i, tt := range successTests { + san, err := pgx.SanitizeSql(tt.sql, tt.args...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args) + } + if san != tt.output { + t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args) + } + } + + errorTests := []struct { + sql string + args []interface{} + err string + }{ + {"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"}, + {"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"}, + } + + for i, tt := range errorTests { + _, err := pgx.SanitizeSql(tt.sql, tt.args...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args) + } + } +} + func TestEncodeError(t *testing.T) { t.Parallel()