2
0

Add QueryRewriter interface

This commit is contained in:
Jack Christensen
2022-04-23 17:26:42 -05:00
parent f9857b73d9
commit b72b0daa5a
5 changed files with 115 additions and 1 deletions
+45
View File
@@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
mode := c.config.DefaultQueryExecMode
var queryRewriter QueryRewriter
optionLoop:
for len(arguments) > 0 {
@@ -411,11 +412,18 @@ optionLoop:
case QueryExecMode:
mode = arg
arguments = arguments[1:]
case QueryRewriter:
queryRewriter = arg
arguments = arguments[1:]
default:
break optionLoop
}
}
if queryRewriter != nil {
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
}
// Always use simple protocol when there are no arguments.
if len(arguments) == 0 {
mode = QueryExecModeSimpleProtocol
@@ -682,6 +690,11 @@ type QueryResultFormats []int16
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
type QueryResultFormatsByOID map[uint32]int16
// QueryRewriter rewrites a query when used as the first arguments to a query method.
type QueryRewriter interface {
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
}
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from
// Query and handle it in Rows.
@@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error)
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
mode := c.config.DefaultQueryExecMode
var queryRewriter QueryRewriter
optionLoop:
for len(args) > 0 {
@@ -709,11 +723,18 @@ optionLoop:
case QueryExecMode:
mode = arg
args = args[1:]
case QueryRewriter:
queryRewriter = arg
args = args[1:]
default:
break optionLoop
}
}
if queryRewriter != nil {
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
}
c.eqb.Reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
@@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mode := c.config.DefaultQueryExecMode
for _, bi := range b.items {
var queryRewriter QueryRewriter
sql := bi.query
arguments := bi.arguments
optionLoop:
for len(arguments) > 0 {
switch arg := arguments[0].(type) {
case QueryRewriter:
queryRewriter = arg
arguments = arguments[1:]
default:
break optionLoop
}
}
if queryRewriter != nil {
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
}
bi.query = sql
bi.arguments = arguments
}
if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder
for i, bi := range b.items {