Add QueryRewriter interface
This commit is contained in:
@@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C
|
||||
|
||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
var queryRewriter QueryRewriter
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
@@ -411,11 +412,18 @@ optionLoop:
|
||||
case QueryExecMode:
|
||||
mode = arg
|
||||
arguments = arguments[1:]
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
}
|
||||
|
||||
// Always use simple protocol when there are no arguments.
|
||||
if len(arguments) == 0 {
|
||||
mode = QueryExecModeSimpleProtocol
|
||||
@@ -682,6 +690,11 @@ type QueryResultFormats []int16
|
||||
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
|
||||
type QueryResultFormatsByOID map[uint32]int16
|
||||
|
||||
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||
type QueryRewriter interface {
|
||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
|
||||
}
|
||||
|
||||
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
||||
// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from
|
||||
// Query and handle it in Rows.
|
||||
@@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error)
|
||||
var resultFormats QueryResultFormats
|
||||
var resultFormatsByOID QueryResultFormatsByOID
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
var queryRewriter QueryRewriter
|
||||
|
||||
optionLoop:
|
||||
for len(args) > 0 {
|
||||
@@ -709,11 +723,18 @@ optionLoop:
|
||||
case QueryExecMode:
|
||||
mode = arg
|
||||
args = args[1:]
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
args = args[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
anynil.NormalizeSlice(args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
@@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
for _, bi := range b.items {
|
||||
var queryRewriter QueryRewriter
|
||||
sql := bi.query
|
||||
arguments := bi.arguments
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
switch arg := arguments[0].(type) {
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
}
|
||||
|
||||
bi.query = sql
|
||||
bi.arguments = arguments
|
||||
}
|
||||
|
||||
if mode == QueryExecModeSimpleProtocol {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
|
||||
Reference in New Issue
Block a user