2
0

Add SendBatch to pool

This commit is contained in:
Jack Christensen
2019-04-25 15:07:35 -05:00
parent 00d123a944
commit 7b1272d254
13 changed files with 170 additions and 29 deletions
+52
View File
@@ -0,0 +1,52 @@
package pool
import (
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
)
type errBatchResults struct {
err error
}
func (br errBatchResults) ExecResults() (pgconn.CommandTag, error) {
return nil, br.err
}
func (br errBatchResults) QueryResults() (pgx.Rows, error) {
return errRows{err: br.err}, br.err
}
func (br errBatchResults) QueryRowResults() pgx.Row {
return errRow{err: br.err}
}
func (br errBatchResults) Close() error {
return br.err
}
type poolBatchResults struct {
br pgx.BatchResults
c *Conn
}
func (br *poolBatchResults) ExecResults() (pgconn.CommandTag, error) {
return br.br.ExecResults()
}
func (br *poolBatchResults) QueryResults() (pgx.Rows, error) {
return br.br.QueryResults()
}
func (br *poolBatchResults) QueryRowResults() pgx.Row {
return br.br.QueryRowResults()
}
func (br *poolBatchResults) Close() error {
err := br.br.Close()
if br.c != nil {
br.c.Release()
br.c = nil
}
return err
}
+25
View File
@@ -69,3 +69,28 @@ func testQueryRow(t *testing.T, db queryRower) {
assert.Equal(t, "hello", what)
assert.Equal(t, "world", who)
}
type sendBatcher interface {
SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
}
func testSendBatch(t *testing.T, db sendBatcher) {
batch := &pgx.Batch{}
batch.Queue("select 1", nil, nil, nil)
batch.Queue("select 2", nil, nil, nil)
br := db.SendBatch(context.Background(), batch)
var err error
var n int32
err = br.QueryRowResults().Scan(&n)
assert.NoError(t, err)
assert.EqualValues(t, 1, n)
err = br.QueryRowResults().Scan(&n)
assert.NoError(t, err)
assert.EqualValues(t, 2, n)
err = br.Close()
assert.NoError(t, err)
}
+4
View File
@@ -61,6 +61,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
return c.Conn().QueryRow(ctx, sql, args...)
}
func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return c.Conn().SendBatch(ctx, b)
}
func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
return c.Conn().Begin(ctx, txOptions)
}
+12
View File
@@ -44,3 +44,15 @@ func TestConnQueryRow(t *testing.T) {
testQueryRow(t, c)
}
func TestConnSendBatch(t *testing.T) {
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()
c, err := pool.Acquire(context.Background())
require.NoError(t, err)
defer c.Release()
testSendBatch(t, c)
}
+10
View File
@@ -127,6 +127,16 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
return &poolRow{r: row, c: c}
}
func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
c, err := p.Acquire(ctx)
if err != nil {
return errBatchResults{err: err}
}
br := c.SendBatch(ctx, b)
return &poolBatchResults{br: br, c: c}
}
func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) {
c, err := p.Acquire(ctx)
if err != nil {
+13
View File
@@ -90,6 +90,19 @@ func TestPoolQueryRow(t *testing.T) {
assert.EqualValues(t, 1, stats.TotalConns())
}
func TestPoolSendBatch(t *testing.T) {
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()
testSendBatch(t, pool)
waitForReleaseToComplete()
stats := pool.Stat()
assert.EqualValues(t, 0, stats.AcquiredConns())
assert.EqualValues(t, 1, stats.TotalConns())
}
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
-2
View File
@@ -1,5 +1,3 @@
func (p *ConnPool) BeginBatch() *Batch
func (p *ConnPool) Close()
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error)
func (p *ConnPool) Deallocate(name string) (err error)
func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
+4
View File
@@ -45,3 +45,7 @@ func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.R
func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
return tx.c.QueryRow(ctx, sql, args...)
}
func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return tx.c.SendBatch(ctx, b)
}
+12
View File
@@ -44,3 +44,15 @@ func TestTxQueryRow(t *testing.T) {
testQueryRow(t, tx)
}
func TestTxSendBatch(t *testing.T) {
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback(context.Background())
testSendBatch(t, tx)
}