c543134753
pgx v5 was not vulnerable to CVE-2024-27289 due to how the sanitizer was being called. But the sanitizer itself still had the underlying issue. This commit ports the fix from pgx v4 to v5 to ensure that the issue does not emerge if pgx uses the sanitizer differently in the future.
332 lines
6.9 KiB
Go
332 lines
6.9 KiB
Go
package sanitize
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
// Part is one element of a parsed query. Its dynamic type is either string or
// int: a string holds raw SQL text, an int holds the number of an argument
// placeholder (the N in $N).
type Part any
|
|
|
|
type Query struct {
|
|
Parts []Part
|
|
}
|
|
|
|
// replacementcharacterwidth is the encoded size in bytes of U+FFFD, the
// Unicode replacement character.
//
// The utf8 decoding functions report a decoding failure by returning
// utf8.RuneError, which is that very same rune. A utf8.RuneError is therefore
// only a real error when its reported width is something other than 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
|
|
|
|
func (q *Query) Sanitize(args ...any) (string, error) {
|
|
argUse := make([]bool, len(args))
|
|
buf := &bytes.Buffer{}
|
|
|
|
for _, part := range q.Parts {
|
|
var str string
|
|
switch part := part.(type) {
|
|
case string:
|
|
str = part
|
|
case int:
|
|
argIdx := part - 1
|
|
|
|
if argIdx < 0 {
|
|
return "", fmt.Errorf("first sql argument must be > 0")
|
|
}
|
|
|
|
if argIdx >= len(args) {
|
|
return "", fmt.Errorf("insufficient arguments")
|
|
}
|
|
arg := args[argIdx]
|
|
switch arg := arg.(type) {
|
|
case nil:
|
|
str = "null"
|
|
case int64:
|
|
str = strconv.FormatInt(arg, 10)
|
|
case float64:
|
|
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
|
case bool:
|
|
str = strconv.FormatBool(arg)
|
|
case []byte:
|
|
str = QuoteBytes(arg)
|
|
case string:
|
|
str = QuoteString(arg)
|
|
case time.Time:
|
|
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
|
default:
|
|
return "", fmt.Errorf("invalid arg type: %T", arg)
|
|
}
|
|
argUse[argIdx] = true
|
|
|
|
// Prevent SQL injection via Line Comment Creation
|
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
|
str = "(" + str + ")"
|
|
default:
|
|
return "", fmt.Errorf("invalid Part type: %T", part)
|
|
}
|
|
buf.WriteString(str)
|
|
}
|
|
|
|
for i, used := range argUse {
|
|
if !used {
|
|
return "", fmt.Errorf("unused argument: %d", i)
|
|
}
|
|
}
|
|
return buf.String(), nil
|
|
}
|
|
|
|
func NewQuery(sql string) (*Query, error) {
|
|
l := &sqlLexer{
|
|
src: sql,
|
|
stateFn: rawState,
|
|
}
|
|
|
|
for l.stateFn != nil {
|
|
l.stateFn = l.stateFn(l)
|
|
}
|
|
|
|
query := &Query{Parts: l.parts}
|
|
|
|
return query, nil
|
|
}
|
|
|
|
// QuoteString returns str wrapped in single quotes, with every single quote
// inside str doubled. No other character (backslashes included) is altered.
func QuoteString(str string) string {
	var b strings.Builder
	b.Grow(len(str) + 2)

	b.WriteByte('\'')
	for i := 0; i < len(str); i++ {
		c := str[i]
		if c == '\'' {
			// Emit the extra quote that escapes this one.
			b.WriteByte('\'')
		}
		b.WriteByte(c)
	}
	b.WriteByte('\'')

	return b.String()
}
|
|
|
|
// QuoteBytes returns buf as a single-quoted literal of the form '\x<hex>',
// where <hex> is the lowercase hexadecimal encoding of buf.
func QuoteBytes(buf []byte) string {
	const prefix = `'\x`

	out := make([]byte, len(prefix)+hex.EncodedLen(len(buf))+1)
	n := copy(out, prefix)
	n += hex.Encode(out[n:], buf)
	out[n] = '\''

	return string(out)
}
|
|
|
|
type sqlLexer struct {
|
|
src string
|
|
start int
|
|
pos int
|
|
nested int // multiline comment nesting level.
|
|
stateFn stateFn
|
|
parts []Part
|
|
}
|
|
|
|
// stateFn is one state of the lexer. It consumes input from the lexer and
// returns the state to run next, or nil to stop lexing.
type stateFn func(*sqlLexer) stateFn
|
|
|
|
func rawState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case 'e', 'E':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '\'' {
|
|
l.pos += width
|
|
return escapeStringState
|
|
}
|
|
case '\'':
|
|
return singleQuoteState
|
|
case '"':
|
|
return doubleQuoteState
|
|
case '$':
|
|
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if '0' <= nextRune && nextRune <= '9' {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
|
}
|
|
l.start = l.pos
|
|
return placeholderState
|
|
}
|
|
case '-':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '-' {
|
|
l.pos += width
|
|
return oneLineCommentState
|
|
}
|
|
case '/':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '*' {
|
|
l.pos += width
|
|
return multilineCommentState
|
|
}
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func singleQuoteState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '\'':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '\'' {
|
|
return rawState
|
|
}
|
|
l.pos += width
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func doubleQuoteState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '"':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '"' {
|
|
return rawState
|
|
}
|
|
l.pos += width
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// placeholderState consumes a placeholder value. The $ must have already has
|
|
// already been consumed. The first rune must be a digit.
|
|
func placeholderState(l *sqlLexer) stateFn {
|
|
num := 0
|
|
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
if '0' <= r && r <= '9' {
|
|
num *= 10
|
|
num += int(r - '0')
|
|
} else {
|
|
l.parts = append(l.parts, num)
|
|
l.pos -= width
|
|
l.start = l.pos
|
|
return rawState
|
|
}
|
|
}
|
|
}
|
|
|
|
func escapeStringState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '\\':
|
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
case '\'':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '\'' {
|
|
return rawState
|
|
}
|
|
l.pos += width
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func oneLineCommentState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '\\':
|
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
case '\n', '\r':
|
|
return rawState
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func multilineCommentState(l *sqlLexer) stateFn {
|
|
for {
|
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
l.pos += width
|
|
|
|
switch r {
|
|
case '/':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune == '*' {
|
|
l.pos += width
|
|
l.nested++
|
|
}
|
|
case '*':
|
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
|
if nextRune != '/' {
|
|
continue
|
|
}
|
|
|
|
l.pos += width
|
|
if l.nested == 0 {
|
|
return rawState
|
|
}
|
|
l.nested--
|
|
|
|
case utf8.RuneError:
|
|
if width != replacementcharacterwidth {
|
|
if l.pos-l.start > 0 {
|
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
|
l.start = l.pos
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
|
// as necessary. This function is only safe when standard_conforming_strings is
|
|
// on.
|
|
func SanitizeSQL(sql string, args ...any) (string, error) {
|
|
query, err := NewQuery(sql)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return query.Sanitize(args...)
|
|
}
|