From e8f959e0e1a31026af1895c39b99c54f19fc2756 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Dec 2020 09:39:58 -0600 Subject: [PATCH] Add QueryFunc refs #821 --- conn.go | 43 +++++++++++++++++++++ doc.go | 17 ++++++++ pgxpool/conn.go | 4 ++ pgxpool/pool.go | 10 +++++ pgxpool/tx.go | 4 ++ query_test.go | 100 ++++++++++++++++++++++++++++++++++++++++++++++++ tx.go | 19 +++++++++ 7 files changed, 197 insertions(+) diff --git a/conn.go b/conn.go index 6c6d545f..ea3d7117 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" + "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/internal/sanitize" ) @@ -666,6 +667,48 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro return (*connRow)(rows.(*connRows)) } +// QueryFuncRow is the argument to the QueryFunc callback function. +// +// QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an +// interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from +// semantic version requirements. Methods will not be removed or changed, but new methods may be added. +type QueryFuncRow interface { + FieldDescriptions() []pgproto3.FieldDescription + + // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current + // function call. However, the underlying byte data is safe to retain a reference to and mutate. + RawValues() [][]byte +} + +// QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of +// scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error +// will be returned. +func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + rows, err := c.Query(ctx, sql, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(scans...) + if err != nil { + return nil, err + } + + err = f(rows) + if err != nil { + return nil, err + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return rows.CommandTag(), nil +} + // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. diff --git a/doc.go b/doc.go index 2c8f3b0d..8dce15e7 100644 --- a/doc.go +++ b/doc.go @@ -82,6 +82,23 @@ Use Exec to execute a query that does not return a result set. return errors.New("No row found to delete") } +QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query. + + var sum, n int32 + _, err = conn.QueryFunc( + context.Background(), + "select generate_series(1,$1)", + []interface{}{10}, + []interface{}{&n}, + func(pgx.QueryFuncRow) error { + sum += n + return nil + }, + ) + if err != nil { + return err + } + Base Type Mapping pgx maps between all common base types directly between Go and PostgreSQL. In particular: diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 1172fbcb..afc75ced 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -58,6 +58,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return c.Conn().QueryRow(ctx, sql, args...) } +func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { + return c.Conn().QueryFunc(ctx, sql, args, scans, f) +} + func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return c.Conn().SendBatch(ctx, b) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index e288a86c..a9d1df65 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -430,6 +430,16 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return c.getPoolRow(row) } +func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { + c, err := p.Acquire(ctx) + if err != nil { + return nil, err + } + defer c.Release() + + return c.QueryFunc(ctx, sql, args, scans, f) +} + func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { c, err := p.Acquire(ctx) if err != nil { diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 3ff5cb95..15e0ee2d 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -62,6 +62,10 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx return tx.t.QueryRow(ctx, sql, args...) } +func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { + return tx.t.QueryFunc(ctx, sql, args, scans, f) +} + func (tx *Tx) Conn() *pgx.Conn { return tx.t.Conn() } diff --git a/query_test.go b/query_test.go index c850b5fe..f04a7a77 100644 --- a/query_test.go +++ b/query_test.go @@ -22,6 +22,7 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + errors "golang.org/x/xerrors" ) func TestConnQueryScan(t *testing.T) { @@ -1971,3 +1972,102 @@ func TestQueryStatementCacheModes(t *testing.T) { }() } } + +func TestConnQueryFunc(t *testing.T) { + t.Parallel() + + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + var actualResults []interface{} + + var a, b int + ct, err := conn.QueryFunc( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + []interface{}{3}, + []interface{}{&a, &b}, + func(pgx.QueryFuncRow) error { + actualResults = append(actualResults, []interface{}{a, b}) + return nil + }, + ) + require.NoError(t, err) + + expectedResults := []interface{}{ + []interface{}{1, 2}, + []interface{}{2, 4}, + []interface{}{3, 6}, + } + require.Equal(t, expectedResults, actualResults) + require.EqualValues(t, 3, ct.RowsAffected()) + }) +} + +func TestConnQueryFuncScanError(t *testing.T) { + t.Parallel() + + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + var actualResults []interface{} + + var a, b int + ct, err := conn.QueryFunc( + context.Background(), + "select 'foo', 'bar' from generate_series(1, $1) n", + []interface{}{3}, + []interface{}{&a, &b}, + func(pgx.QueryFuncRow) error { + actualResults = append(actualResults, []interface{}{a, b}) + return nil + }, + ) + require.EqualError(t, err, "can't scan into dest[0]: unable to assign to *int") + require.Nil(t, ct) + }) +} + +func TestConnQueryFuncAbort(t *testing.T) { + t.Parallel() + + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + var a, b int + ct, err := conn.QueryFunc( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + []interface{}{3}, + []interface{}{&a, &b}, + func(pgx.QueryFuncRow) error { + return errors.New("abort") + }, + ) + require.EqualError(t, err, "abort") + require.Nil(t, ct) + }) +} + +func ExampleConn_QueryFunc() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + var a, b int + _, err = conn.QueryFunc( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + []interface{}{3}, + []interface{}{&a, &b}, + func(pgx.QueryFuncRow) error { + fmt.Printf("%v, %v\n", a, b) + return nil + }, + ) + if err != nil { + fmt.Printf("QueryFunc error: %v", err) + return + } + + // Output: + // 1, 2 + // 2, 4 + // 3, 6 +} diff --git a/tx.go b/tx.go index f19a65a1..43f8aa3e 100644 --- a/tx.go +++ b/tx.go @@ -117,6 +117,7 @@ type Tx interface { Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) QueryRow(ctx context.Context, sql string, args ...interface{}) Row + QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) // Conn returns the underlying *Conn that on which this transaction is executing. Conn() *Conn @@ -220,6 +221,15 @@ func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) R return (*connRow)(rows.(*connRows)) } +// QueryFunc delegates to the underlying *Conn. +func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + if tx.closed { + return nil, ErrTxClosed + } + + return tx.conn.QueryFunc(ctx, sql, args, scans, f) +} + // CopyFrom delegates to the underlying *Conn func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if tx.closed { @@ -322,6 +332,15 @@ func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interfa return (*connRow)(rows.(*connRows)) } +// QueryFunc delegates to the underlying Tx. +func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.QueryFunc(ctx, sql, args, scans, f) +} + // CopyFrom delegates to the underlying *Conn func (sp *dbSavepoint) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if sp.closed {