2
0

Restore simple protocol support

This commit is contained in:
Jack Christensen
2019-05-20 20:36:03 -05:00
parent 6d23b58b01
commit 29f02807b0
7 changed files with 861 additions and 3 deletions
+90 -2
View File
@@ -10,6 +10,7 @@ import (
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4/internal/sanitize"
)
const (
@@ -24,6 +25,14 @@ type ConnConfig struct {
pgconn.Config
Logger Logger
LogLevel LogLevel
// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
// and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
// used by default. The same functionality can be controlled on a per query basis by setting
// QueryExOptions.SimpleProtocol.
PreferSimpleProtocol bool
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@@ -390,6 +399,36 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
}
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
simpleProtocol := c.config.PreferSimpleProtocol
optionLoop:
for len(arguments) > 0 {
switch arg := arguments[0].(type) {
case QuerySimpleProtocol:
simpleProtocol = bool(arg)
arguments = arguments[1:]
default:
break optionLoop
}
}
if simpleProtocol {
sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
if err != nil {
return nil, err
}
mrr := c.pgConn.Exec(ctx, sql)
if mrr.NextResult() {
result := mrr.ResultReader().Read()
err = mrr.Close()
return result.CommandTag, err
} else {
err = mrr.Close()
return nil, err
}
}
c.eqb.Reset()
if ps, ok := c.preparedStatements[sql]; ok {
@@ -495,6 +534,9 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
return r
}
// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
type QuerySimpleProtocol bool
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
type QueryResultFormats []int16
@@ -506,6 +548,7 @@ type QueryResultFormatsByOID map[pgtype.OID]int16
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
simpleProtocol := c.config.PreferSimpleProtocol
optionLoop:
for len(args) > 0 {
@@ -516,14 +559,39 @@ optionLoop:
case QueryResultFormatsByOID:
resultFormatsByOID = arg
args = args[1:]
case QuerySimpleProtocol:
simpleProtocol = bool(arg)
args = args[1:]
default:
break optionLoop
}
}
c.eqb.Reset()
rows := c.getRows(sql, args)
var err error
if simpleProtocol {
sql, err = c.sanitizeForSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
return rows, err
}
mrr := c.pgConn.Exec(ctx, sql)
if mrr.NextResult() {
rows.resultReader = mrr.ResultReader()
rows.multiResultReader = mrr
} else {
err = mrr.Close()
rows.fatal(err)
return rows, err
}
return rows, nil
}
c.eqb.Reset()
ps, ok := c.preparedStatements[sql]
if !ok {
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
@@ -550,7 +618,6 @@ optionLoop:
}
rows.sql = ps.SQL
var err error
args, err = convertDriverValuers(args)
if err != nil {
rows.fatal(err)
@@ -663,3 +730,24 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mrr: mrr,
}
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
}
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
}
var err error
valueArgs := make([]interface{}, len(args))
for i, a := range args {
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
if err != nil {
return "", err
}
}
return sanitize.SanitizeSQL(sql, valueArgs...)
}