2
0

Fix query sanitizer

...when query text contains the Unicode replacement character.
utf8.RuneError is actually a valid character.
This commit is contained in:
Jack Christensen
2022-11-14 18:22:57 -06:00
parent b4d2eae777
commit ba4bbf92af
2 changed files with 52 additions and 24 deletions
+18
View File
@@ -18,6 +18,12 @@ type Query struct {
Parts []Part Parts []Part
} }
// utf8.DecodeRune returns utf8.RuneError for errors. But that is actually rune U+FFFD -- the Unicode replacement
// character. utf8.RuneError is not an error if it is also width 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
func (q *Query) Sanitize(args ...any) (string, error) { func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args)) argUse := make([]bool, len(args))
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@@ -138,6 +144,7 @@ func rawState(l *sqlLexer) stateFn {
return multilineCommentState return multilineCommentState
} }
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@@ -146,6 +153,7 @@ func rawState(l *sqlLexer) stateFn {
} }
} }
} }
}
func singleQuoteState(l *sqlLexer) stateFn { func singleQuoteState(l *sqlLexer) stateFn {
for { for {
@@ -160,6 +168,7 @@ func singleQuoteState(l *sqlLexer) stateFn {
} }
l.pos += width l.pos += width
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@@ -168,6 +177,7 @@ func singleQuoteState(l *sqlLexer) stateFn {
} }
} }
} }
}
func doubleQuoteState(l *sqlLexer) stateFn { func doubleQuoteState(l *sqlLexer) stateFn {
for { for {
@@ -182,6 +192,7 @@ func doubleQuoteState(l *sqlLexer) stateFn {
} }
l.pos += width l.pos += width
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@@ -190,6 +201,7 @@ func doubleQuoteState(l *sqlLexer) stateFn {
} }
} }
} }
}
// placeholderState consumes a placeholder value. The $ must have already // placeholderState consumes a placeholder value. The $ must have already
// already been consumed. The first rune must be a digit. // already been consumed. The first rune must be a digit.
@@ -228,6 +240,7 @@ func escapeStringState(l *sqlLexer) stateFn {
} }
l.pos += width l.pos += width
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@@ -236,6 +249,7 @@ func escapeStringState(l *sqlLexer) stateFn {
} }
} }
} }
}
func oneLineCommentState(l *sqlLexer) stateFn { func oneLineCommentState(l *sqlLexer) stateFn {
for { for {
@@ -249,6 +263,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
case '\n', '\r': case '\n', '\r':
return rawState return rawState
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@@ -257,6 +272,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
} }
} }
} }
}
func multilineCommentState(l *sqlLexer) stateFn { func multilineCommentState(l *sqlLexer) stateFn {
for { for {
@@ -283,6 +299,7 @@ func multilineCommentState(l *sqlLexer) stateFn {
l.nested-- l.nested--
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@@ -291,6 +308,7 @@ func multilineCommentState(l *sqlLexer) stateFn {
} }
} }
} }
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is // as necessary. This function is only safe when standard_conforming_strings is
+10
View File
@@ -88,6 +88,16 @@ func TestNewQuery(t *testing.T) {
sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}},
}, },
{
// https://github.com/jackc/pgx/issues/1380
sql: "select 'hello wrld'",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello wrld'"}},
},
{
// Unterminated quoted string
sql: "select 'hello world",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}},
},
} }
for i, tt := range successTests { for i, tt := range successTests {