From 107196ab0c90528a41a9c2910e158763ac320cd4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 18:43:04 -0500 Subject: [PATCH] Add NamedArgs https://github.com/jackc/pgx/issues/1186 https://github.com/jackc/pgx/issues/387 --- conn.go | 2 +- conn_test.go | 2 +- named_args.go | 266 +++++++++++++++++++++++++++++++++++++++++++++ named_args_test.go | 96 ++++++++++++++++ 4 files changed, 364 insertions(+), 2 deletions(-) create mode 100644 named_args.go create mode 100644 named_args_test.go diff --git a/conn.go b/conn.go index dd4a7301..72154325 100644 --- a/conn.go +++ b/conn.go @@ -692,7 +692,7 @@ type QueryResultFormatsByOID map[uint32]int16 // QueryRewriter rewrites a query when used as the first arguments to a query method. type QueryRewriter interface { - RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any) + RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) } // Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The diff --git a/conn_test.go b/conn_test.go index 392ea623..675fba17 100644 --- a/conn_test.go +++ b/conn_test.go @@ -235,7 +235,7 @@ type testQueryRewriter struct { args []any } -func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) { +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) { return qr.sql, qr.args } diff --git a/named_args.go b/named_args.go new file mode 100644 index 00000000..e6906b3b --- /dev/null +++ b/named_args.go @@ -0,0 +1,266 @@ +package pgx + +import ( + "context" + "strconv" + "strings" + "unicode/utf8" +) + +// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' +// ordinal placeholder and construct the appropriate arguments. +// +// For example, the following two queries are equivalent: +// +// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})) +// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2})) +type NamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + nameToOrdinal: make(map[namedArg]int, len(na)), + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + sb := strings.Builder{} + for _, p := range l.parts { + switch p := p.(type) { + case string: + sb.WriteString(p) + case namedArg: + sb.WriteRune('$') + sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) + } + } + + newArgs = make([]any, len(l.nameToOrdinal)) + for name, ordinal := range l.nameToOrdinal { + newArgs[ordinal-1] = na[string(name)] + } + + return sb.String(), newArgs +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '@': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if isLetter(nextRune) { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return namedArgState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func isLetter(r rune) bool { + return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') +} + +func namedArgState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if r == utf8.RuneError { + if l.pos-l.start > 0 { + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, na) + l.start = l.pos + } + return nil + } else if !(isLetter(r) || (r >= '0' && r <= '9')) { + l.pos -= width + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, namedArg(na)) + l.start = l.pos + return rawState + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} diff --git a/named_args_test.go b/named_args_test.go new file mode 100644 index 00000000..fea3b897 --- /dev/null +++ b/named_args_test.go @@ -0,0 +1,96 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" +) + +func TestNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + args []any + namedArgs pgx.NamedArgs + expectedSQL string + expectedArgs []any + }{ + { + sql: "select * from users where id = @id", + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: "select * from users where id = $1", + expectedArgs: []any{int32(42)}, + }, + { + sql: "select * from t where foo < @abc and baz = @def and bar < @abc", + namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)}, + expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1", + expectedArgs: []any{int32(42), int32(1)}, + }, + { + sql: "select @a::int, @b::text", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, + { + sql: "at end @", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "at end @", + expectedArgs: []any{}, + }, + { + sql: "ignores without letter after @ foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "ignores without letter after @ foo bar", + expectedArgs: []any{}, + }, + { + sql: "name must start with letter @1 foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "name must start with letter @1 foo bar", + expectedArgs: []any{}, + }, + { + sql: `select *, '@foo' as "@bar" from users where id = @id`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * -- @foo + from users -- @single line comments + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * -- @foo + from users -- @single line comments + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + + // test comments and quotes + } { + sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } +}