Enable all QueryExecModes for exec path
This commit is contained in:
@@ -405,48 +405,62 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||||
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
|
mode := c.config.DefaultQueryExecMode
|
||||||
|
|
||||||
optionLoop:
|
optionLoop:
|
||||||
for len(arguments) > 0 {
|
for len(arguments) > 0 {
|
||||||
switch arg := arguments[0].(type) {
|
switch arg := arguments[0].(type) {
|
||||||
case QueryExecMode:
|
case QueryExecMode:
|
||||||
simpleProtocol = arg == QueryExecModeSimpleProtocol
|
mode = arg
|
||||||
arguments = arguments[1:]
|
arguments = arguments[1:]
|
||||||
default:
|
default:
|
||||||
break optionLoop
|
break optionLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always use simple protocol when there are no arguments.
|
||||||
|
if len(arguments) == 0 {
|
||||||
|
mode = QueryExecModeSimpleProtocol
|
||||||
|
}
|
||||||
|
|
||||||
if sd, ok := c.preparedStatements[sql]; ok {
|
if sd, ok := c.preparedStatements[sql]; ok {
|
||||||
return c.execPrepared(ctx, sd, arguments)
|
return c.execPrepared(ctx, sd, arguments)
|
||||||
}
|
}
|
||||||
|
|
||||||
if simpleProtocol {
|
switch mode {
|
||||||
return c.execSimpleProtocol(ctx, sql, arguments)
|
case QueryExecModeCacheStatement:
|
||||||
}
|
if c.statementCache == nil {
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||||
if len(arguments) == 0 {
|
}
|
||||||
return c.execSimpleProtocol(ctx, sql, arguments)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.statementCache != nil {
|
|
||||||
sd, err := c.statementCache.Get(ctx, sql)
|
sd, err := c.statementCache.Get(ctx, sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pgconn.CommandTag{}, err
|
return pgconn.CommandTag{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.statementCache.Mode() == stmtcache.ModeDescribe {
|
return c.execPrepared(ctx, sd, arguments)
|
||||||
return c.execParams(ctx, sd, arguments)
|
case QueryExecModeCacheDescribe:
|
||||||
|
if c.descriptionCache == nil {
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||||
|
}
|
||||||
|
sd, err := c.descriptionCache.Get(ctx, sql)
|
||||||
|
if err != nil {
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.execParams(ctx, sd, arguments)
|
||||||
|
case QueryExecModeDescribeExec:
|
||||||
|
sd, err := c.Prepare(ctx, "", sql)
|
||||||
|
if err != nil {
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
}
|
}
|
||||||
return c.execPrepared(ctx, sd, arguments)
|
return c.execPrepared(ctx, sd, arguments)
|
||||||
|
case QueryExecModeExec:
|
||||||
|
return c.execSQLParams(ctx, sql, arguments)
|
||||||
|
case QueryExecModeSimpleProtocol:
|
||||||
|
return c.execSimpleProtocol(ctx, sql, arguments)
|
||||||
|
default:
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
sd, err := c.Prepare(ctx, "", sql)
|
|
||||||
if err != nil {
|
|
||||||
return pgconn.CommandTag{}, err
|
|
||||||
}
|
|
||||||
return c.execPrepared(ctx, sd, arguments)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||||
@@ -510,6 +524,38 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
|
|||||||
return result.CommandTag, result.Err
|
return result.CommandTag, result.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type unknownArgumentTypeQueryExecModeExecError struct {
|
||||||
|
arg interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
|
||||||
|
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}) (pgconn.CommandTag, error) {
|
||||||
|
c.eqb.Reset()
|
||||||
|
|
||||||
|
anynil.NormalizeSlice(args)
|
||||||
|
|
||||||
|
paramOIDs := make([]uint32, len(args))
|
||||||
|
|
||||||
|
for i := range args {
|
||||||
|
dt, ok := c.TypeMap().TypeForValue(args[i])
|
||||||
|
if !ok {
|
||||||
|
return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
|
||||||
|
}
|
||||||
|
err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i])
|
||||||
|
if err != nil {
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
|
paramOIDs[i] = dt.OID
|
||||||
|
}
|
||||||
|
|
||||||
|
result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
|
||||||
|
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
||||||
|
return result.CommandTag, result.Err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
|
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
|
||||||
r := &connRows{}
|
r := &connRows{}
|
||||||
|
|
||||||
|
|||||||
+9
-1
@@ -256,7 +256,15 @@ func TestExecFailureWithArguments(t *testing.T) {
|
|||||||
assert.False(t, pgconn.SafeToRetry(err))
|
assert.False(t, pgconn.SafeToRetry(err))
|
||||||
|
|
||||||
_, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2")
|
_, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2")
|
||||||
require.Error(t, err)
|
if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec {
|
||||||
|
// The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it
|
||||||
|
// locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check
|
||||||
|
// for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing
|
||||||
|
// the SQL.
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user