2
0

Remove simple protocol and one round trip query options

It is impossible to guarantee that the a query executed with the simple
protocol will behave the same as with the extended protocol. This is
because the normal pgx path relies on knowing the OID of query
parameters. Without this encoding a value can only be determined by the
value instead of the combination of value and PostgreSQL type. For
example, how should a []int32 be encoded? It might be encoded into a
PostgreSQL int4[] or json.

Removal also simplifies the core query path.

The primary reason for the simple protocol is for servers like PgBouncer
that may not be able to support normal prepared statements. After
further research it appears that issuing a "flush" instead "sync" after
preparing the unnamed statement would allow PgBouncer to work.

The one round trip mode can be better handled with prepared statements.

As a last resort, all original server functionality can still be accessed by
dropping down to PgConn.
This commit is contained in:
Jack Christensen
2019-04-13 11:39:01 -05:00
parent 5a374c467f
commit c53c9e6eb5
16 changed files with 32 additions and 1021 deletions
-237
View File
@@ -1,237 +0,0 @@
package sanitize
import (
"bytes"
"encoding/hex"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/pkg/errors"
)
// Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder.
type Part interface{}
type Query struct {
Parts []Part
}
func (q *Query) Sanitize(args ...interface{}) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
case int:
argIdx := part - 1
if argIdx >= len(args) {
return "", errors.Errorf("insufficient arguments")
}
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
case int64:
str = strconv.FormatInt(arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
case []byte:
str = QuoteBytes(arg)
case string:
str = QuoteString(arg)
case time.Time:
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", errors.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
default:
return "", errors.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
if !used {
return "", errors.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
src string
start int
pos int
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}
-175
View File
@@ -1,175 +0,0 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/internal/sanitize"
)
func TestNewQuery(t *testing.T) {
successTests := []struct {
sql string
expected sanitize.Query
}{
{
sql: "select 42",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
},
{
sql: "select $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
},
{
sql: "select 'quoted $42', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
},
{
sql: `select "doubled quoted $42", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
},
{
sql: "select 'foo''bar', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
},
{
sql: `select "foo""bar", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
},
{
sql: "select '''', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
},
{
sql: `select """", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
},
{
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
},
{
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
},
{
sql: `select E'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
},
{
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
}
for i, tt := range successTests {
query, err := sanitize.NewQuery(tt.sql)
if err != nil {
t.Errorf("%d. %v", i, err)
}
if len(query.Parts) == len(tt.expected.Parts) {
for j := range query.Parts {
if query.Parts[j] != tt.expected.Parts[j] {
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
}
}
} else {
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
}
}
}
func TestQuerySanitize(t *testing.T) {
successfulTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
args: []interface{}{},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{int64(42)},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{float64(1.23)},
expected: `select 1.23`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{true},
expected: `select true`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{nil},
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foobar"},
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foo'bar"},
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{`foo\'bar`},
expected: `select 'foo\''bar'`,
},
}
for i, tt := range successfulTests {
actual, err := tt.query.Sanitize(tt.args...)
if err != nil {
t.Errorf("%d. %v", i, err)
continue
}
if tt.expected != actual {
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
}
}
errorTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
args: []interface{}{int64(42)},
expected: `insufficient arguments`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
args: []interface{}{int64(42)},
expected: `unused argument: 0`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{42},
expected: `invalid arg type: int`,
},
}
for i, tt := range errorTests {
_, err := tt.query.Sanitize(tt.args...)
if err == nil || err.Error() != tt.expected {
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
}
}
}