Restore simple protocol support
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v4/internal/sanitize"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,6 +25,14 @@ type ConnConfig struct {
|
||||
pgconn.Config
|
||||
Logger Logger
|
||||
LogLevel LogLevel
|
||||
|
||||
// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
|
||||
// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
|
||||
// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
|
||||
// and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
|
||||
// used by default. The same functionality can be controlled on a per query basis by setting
|
||||
// QueryExOptions.SimpleProtocol.
|
||||
PreferSimpleProtocol bool
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
@@ -390,6 +399,36 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
||||
}
|
||||
|
||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
simpleProtocol := c.config.PreferSimpleProtocol
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
switch arg := arguments[0].(type) {
|
||||
case QuerySimpleProtocol:
|
||||
simpleProtocol = bool(arg)
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if simpleProtocol {
|
||||
sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mrr := c.pgConn.Exec(ctx, sql)
|
||||
if mrr.NextResult() {
|
||||
result := mrr.ResultReader().Read()
|
||||
err = mrr.Close()
|
||||
return result.CommandTag, err
|
||||
} else {
|
||||
err = mrr.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
|
||||
if ps, ok := c.preparedStatements[sql]; ok {
|
||||
@@ -495,6 +534,9 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
||||
return r
|
||||
}
|
||||
|
||||
// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
|
||||
type QuerySimpleProtocol bool
|
||||
|
||||
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
|
||||
type QueryResultFormats []int16
|
||||
|
||||
@@ -506,6 +548,7 @@ type QueryResultFormatsByOID map[pgtype.OID]int16
|
||||
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
|
||||
var resultFormats QueryResultFormats
|
||||
var resultFormatsByOID QueryResultFormatsByOID
|
||||
simpleProtocol := c.config.PreferSimpleProtocol
|
||||
|
||||
optionLoop:
|
||||
for len(args) > 0 {
|
||||
@@ -516,14 +559,39 @@ optionLoop:
|
||||
case QueryResultFormatsByOID:
|
||||
resultFormatsByOID = arg
|
||||
args = args[1:]
|
||||
case QuerySimpleProtocol:
|
||||
simpleProtocol = bool(arg)
|
||||
args = args[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
rows := c.getRows(sql, args)
|
||||
|
||||
var err error
|
||||
if simpleProtocol {
|
||||
sql, err = c.sanitizeForSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
mrr := c.pgConn.Exec(ctx, sql)
|
||||
if mrr.NextResult() {
|
||||
rows.resultReader = mrr.ResultReader()
|
||||
rows.multiResultReader = mrr
|
||||
} else {
|
||||
err = mrr.Close()
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
|
||||
@@ -550,7 +618,6 @@ optionLoop:
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
|
||||
var err error
|
||||
args, err = convertDriverValuers(args)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
@@ -663,3 +730,24 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
mrr: mrr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
|
||||
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
||||
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
}
|
||||
|
||||
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
|
||||
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
|
||||
}
|
||||
|
||||
var err error
|
||||
valueArgs := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user