Split batch command and result
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
package pgx
|
package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgtype"
|
||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
@@ -18,21 +16,7 @@ type batchItem struct {
|
|||||||
// Batch queries are a way of bundling multiple queries together to avoid
|
// Batch queries are a way of bundling multiple queries together to avoid
|
||||||
// unnecessary network round trips.
|
// unnecessary network round trips.
|
||||||
type Batch struct {
|
type Batch struct {
|
||||||
conn *Conn
|
|
||||||
items []*batchItem
|
items []*batchItem
|
||||||
err error
|
|
||||||
|
|
||||||
mrr *pgconn.MultiResultReader
|
|
||||||
}
|
|
||||||
|
|
||||||
// BeginBatch returns a *Batch query for c.
|
|
||||||
func (c *Conn) BeginBatch() *Batch {
|
|
||||||
return &Batch{conn: c}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn returns the underlying connection that b will or was performed on.
|
|
||||||
func (b *Batch) Conn() *Conn {
|
|
||||||
return b.conn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. parameterOIDs and
|
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. parameterOIDs and
|
||||||
@@ -47,92 +31,43 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends all queued queries to the server at once. All queries are run in an implicit transaction unless explicit
|
type BatchResults struct {
|
||||||
// transaction control statements are executed.
|
conn *Conn
|
||||||
func (b *Batch) Send(ctx context.Context) error {
|
mrr *pgconn.MultiResultReader
|
||||||
if b.err != nil {
|
err error
|
||||||
return b.err
|
|
||||||
}
|
|
||||||
|
|
||||||
batch := &pgconn.Batch{}
|
|
||||||
|
|
||||||
for _, bi := range b.items {
|
|
||||||
var parameterOIDs []pgtype.OID
|
|
||||||
ps := b.conn.preparedStatements[bi.query]
|
|
||||||
|
|
||||||
if ps != nil {
|
|
||||||
parameterOIDs = ps.ParameterOIDs
|
|
||||||
} else {
|
|
||||||
parameterOIDs = bi.parameterOIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
args, err := convertDriverValuers(bi.arguments)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
paramFormats := make([]int16, len(args))
|
|
||||||
paramValues := make([][]byte, len(args))
|
|
||||||
for i := range args {
|
|
||||||
paramFormats[i] = chooseParameterFormatCode(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
|
||||||
paramValues[i], err = newencodePreparedStatementArgument(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if ps != nil {
|
|
||||||
resultFormats := bi.resultFormatCodes
|
|
||||||
if resultFormats == nil {
|
|
||||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
|
||||||
for i := range resultFormats {
|
|
||||||
if dt, ok := b.conn.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
|
||||||
resultFormats[i] = BinaryFormatCode
|
|
||||||
} else {
|
|
||||||
resultFormats[i] = TextFormatCode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats)
|
|
||||||
} else {
|
|
||||||
oids := make([]uint32, len(parameterOIDs))
|
|
||||||
for i := 0; i < len(parameterOIDs); i++ {
|
|
||||||
oids[i] = uint32(parameterOIDs[i])
|
|
||||||
}
|
|
||||||
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b.mrr = b.conn.pgConn.ExecBatch(ctx, batch)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecResults reads the results from the next query in the batch as if the
|
// ExecResults reads the results from the next query in the batch as if the
|
||||||
// query has been sent with Exec.
|
// query has been sent with Exec.
|
||||||
func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) {
|
||||||
if !b.mrr.NextResult() {
|
if br.err != nil {
|
||||||
err := b.mrr.Close()
|
return nil, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !br.mrr.NextResult() {
|
||||||
|
err := br.mrr.Close()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = errors.New("no result")
|
err = errors.New("no result")
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.mrr.ResultReader().Close()
|
return br.mrr.ResultReader().Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryResults reads the results from the next query in the batch as if the
|
// QueryResults reads the results from the next query in the batch as if the
|
||||||
// query has been sent with Query.
|
// query has been sent with Query.
|
||||||
func (b *Batch) QueryResults() (Rows, error) {
|
func (br *BatchResults) QueryResults() (Rows, error) {
|
||||||
rows := b.conn.getRows("batch query", nil)
|
rows := br.conn.getRows("batch query", nil)
|
||||||
|
|
||||||
if !b.mrr.NextResult() {
|
if br.err != nil {
|
||||||
rows.err = b.mrr.Close()
|
rows.err = br.err
|
||||||
|
rows.closed = true
|
||||||
|
return rows, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !br.mrr.NextResult() {
|
||||||
|
rows.err = br.mrr.Close()
|
||||||
if rows.err == nil {
|
if rows.err == nil {
|
||||||
rows.err = errors.New("no result")
|
rows.err = errors.New("no result")
|
||||||
}
|
}
|
||||||
@@ -140,14 +75,14 @@ func (b *Batch) QueryResults() (Rows, error) {
|
|||||||
return rows, rows.err
|
return rows, rows.err
|
||||||
}
|
}
|
||||||
|
|
||||||
rows.resultReader = b.mrr.ResultReader()
|
rows.resultReader = br.mrr.ResultReader()
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryRowResults reads the results from the next query in the batch as if the
|
// QueryRowResults reads the results from the next query in the batch as if the
|
||||||
// query has been sent with QueryRow.
|
// query has been sent with QueryRow.
|
||||||
func (b *Batch) QueryRowResults() Row {
|
func (br *BatchResults) QueryRowResults() Row {
|
||||||
rows, _ := b.QueryResults()
|
rows, _ := br.QueryResults()
|
||||||
return (*connRow)(rows.(*connRows))
|
return (*connRow)(rows.(*connRows))
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -155,6 +90,10 @@ func (b *Batch) QueryRowResults() Row {
|
|||||||
// Close closes the batch operation. Any error that occured during a batch
|
// Close closes the batch operation. Any error that occured during a batch
|
||||||
// operation may have made it impossible to resyncronize the connection with the
|
// operation may have made it impossible to resyncronize the connection with the
|
||||||
// server. In this case the underlying connection will have been closed.
|
// server. In this case the underlying connection will have been closed.
|
||||||
func (b *Batch) Close() (err error) {
|
func (br *BatchResults) Close() error {
|
||||||
return b.mrr.Close()
|
if br.err != nil {
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
return br.mrr.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
+54
-81
@@ -23,7 +23,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||||
[]interface{}{"q1", 1},
|
[]interface{}{"q1", 1},
|
||||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||||
@@ -50,12 +50,9 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
[]int16{pgx.BinaryFormatCode},
|
[]int16{pgx.BinaryFormatCode},
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ct, err := batch.ExecResults()
|
ct, err := br.ExecResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -63,7 +60,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
ct, err = batch.ExecResults()
|
ct, err = br.ExecResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -71,7 +68,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
ct, err = batch.ExecResults()
|
ct, err = br.ExecResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -79,7 +76,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := batch.QueryResults()
|
rows, err := br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -143,7 +140,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
t.Fatal(rows.Err())
|
t.Fatal(rows.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.QueryRowResults().Scan(&amount)
|
err = br.QueryRowResults().Scan(&amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -151,7 +148,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||||||
t.Errorf("amount => %v, want %v", amount, 6)
|
t.Errorf("amount => %v, want %v", amount, 6)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err = br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -170,7 +167,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
|
|
||||||
queryCount := 3
|
queryCount := 3
|
||||||
for i := 0; i < queryCount; i++ {
|
for i := 0; i < queryCount; i++ {
|
||||||
@@ -181,13 +178,10 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < queryCount; i++ {
|
for i := 0; i < queryCount; i++ {
|
||||||
rows, err := batch.QueryResults()
|
rows, err := br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -207,7 +201,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err = br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -221,7 +215,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select n from generate_series(0,5) n",
|
batch.Queue("select n from generate_series(0,5) n",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
@@ -233,12 +227,9 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||||||
[]int16{pgx.BinaryFormatCode},
|
[]int16{pgx.BinaryFormatCode},
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := batch.QueryResults()
|
rows, err := br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -259,7 +250,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||||||
|
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
|
||||||
rows, err = batch.QueryResults()
|
rows, err = br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -278,7 +269,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||||||
t.Error(rows.Err())
|
t.Error(rows.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err = br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -292,7 +283,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
|
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
@@ -304,12 +295,9 @@ func TestConnBeginBatchQueryError(t *testing.T) {
|
|||||||
[]int16{pgx.BinaryFormatCode},
|
[]int16{pgx.BinaryFormatCode},
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := batch.QueryResults()
|
rows, err := br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -328,7 +316,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
|
|||||||
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
|
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err = br.Close()
|
||||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
||||||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
||||||
}
|
}
|
||||||
@@ -342,25 +330,22 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1 1",
|
batch.Queue("select 1 1",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
[]int16{pgx.BinaryFormatCode},
|
[]int16{pgx.BinaryFormatCode},
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var n int32
|
var n int32
|
||||||
err = batch.QueryRowResults().Scan(&n)
|
err := br.QueryRowResults().Scan(&n)
|
||||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
|
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
|
||||||
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
|
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err = br.Close()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error")
|
t.Error("Expected error")
|
||||||
}
|
}
|
||||||
@@ -381,7 +366,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
|||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1",
|
batch.Queue("select 1",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
@@ -393,18 +378,15 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var value int
|
var value int
|
||||||
err = batch.QueryRowResults().Scan(&value)
|
err := br.QueryRowResults().Scan(&value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ct, err := batch.ExecResults()
|
ct, err := br.ExecResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -412,7 +394,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
|||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Close()
|
br.Close()
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
@@ -430,7 +412,7 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
|||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1 union all select 2 union all select 3",
|
batch.Queue("select 1 union all select 2 union all select 3",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
@@ -442,18 +424,15 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := batch.QueryResults()
|
rows, err := br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
|
||||||
ct, err := batch.ExecResults()
|
ct, err := br.ExecResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -461,12 +440,12 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
|||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Close()
|
br.Close()
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTxBeginBatch(t *testing.T) {
|
func TestTxSendBatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
@@ -485,25 +464,23 @@ func TestTxBeginBatch(t *testing.T) {
|
|||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
tx, _ := conn.Begin(context.Background(), nil)
|
tx, _ := conn.Begin(context.Background(), nil)
|
||||||
batch := tx.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger1(description) values($1) returning id",
|
batch.Queue("insert into ledger1(description) values($1) returning id",
|
||||||
[]interface{}{"q1"},
|
[]interface{}{"q1"},
|
||||||
[]pgtype.OID{pgtype.VarcharOID},
|
[]pgtype.OID{pgtype.VarcharOID},
|
||||||
[]int16{pgx.BinaryFormatCode},
|
[]int16{pgx.BinaryFormatCode},
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := tx.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
var id int
|
var id int
|
||||||
err = batch.QueryRowResults().Scan(&id)
|
err := br.QueryRowResults().Scan(&id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
batch.Close()
|
br.Close()
|
||||||
|
|
||||||
batch = tx.BeginBatch()
|
batch = &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger2(id,amount) values($1, $2)",
|
batch.Queue("insert into ledger2(id,amount) values($1, $2)",
|
||||||
[]interface{}{id, 2},
|
[]interface{}{id, 2},
|
||||||
[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
|
[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
|
||||||
@@ -516,11 +493,9 @@ func TestTxBeginBatch(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = batch.Send(context.Background())
|
br = tx.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
ct, err := br.ExecResults()
|
||||||
}
|
|
||||||
ct, err := batch.ExecResults()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -529,12 +504,12 @@ func TestTxBeginBatch(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var amount int
|
var amount int
|
||||||
err = batch.QueryRowResults().Scan(&amount)
|
err = br.QueryRowResults().Scan(&amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Close()
|
br.Close()
|
||||||
tx.Commit(context.Background())
|
tx.Commit(context.Background())
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
@@ -543,7 +518,7 @@ func TestTxBeginBatch(t *testing.T) {
|
|||||||
t.Errorf("count => %v, want %v", count, 1)
|
t.Errorf("count => %v, want %v", count, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err = br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -551,7 +526,7 @@ func TestTxBeginBatch(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTxBeginBatchRollback(t *testing.T) {
|
func TestTxSendBatchRollback(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
@@ -564,23 +539,21 @@ func TestTxBeginBatchRollback(t *testing.T) {
|
|||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
tx, _ := conn.Begin(context.Background(), nil)
|
tx, _ := conn.Begin(context.Background(), nil)
|
||||||
batch := tx.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger1(description) values($1) returning id",
|
batch.Queue("insert into ledger1(description) values($1) returning id",
|
||||||
[]interface{}{"q1"},
|
[]interface{}{"q1"},
|
||||||
[]pgtype.OID{pgtype.VarcharOID},
|
[]pgtype.OID{pgtype.VarcharOID},
|
||||||
[]int16{pgx.BinaryFormatCode},
|
[]int16{pgx.BinaryFormatCode},
|
||||||
)
|
)
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := tx.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
var id int
|
var id int
|
||||||
err = batch.QueryRowResults().Scan(&id)
|
err := br.QueryRowResults().Scan(&id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
batch.Close()
|
br.Close()
|
||||||
tx.Rollback(context.Background())
|
tx.Rollback(context.Background())
|
||||||
|
|
||||||
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
|
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
|
||||||
|
|||||||
+4
-7
@@ -613,7 +613,7 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
batch := conn.BeginBatch()
|
batch := &pgx.Batch{}
|
||||||
for j := 0; j < queryCount; j++ {
|
for j := 0; j < queryCount; j++ {
|
||||||
batch.Queue("select n from generate_series(0,5) n",
|
batch.Queue("select n from generate_series(0,5) n",
|
||||||
nil,
|
nil,
|
||||||
@@ -622,13 +622,10 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := batch.Send(context.Background())
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for j := 0; j < queryCount; j++ {
|
for j := 0; j < queryCount; j++ {
|
||||||
rows, err := batch.QueryResults()
|
rows, err := br.QueryResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -648,7 +645,7 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = batch.Close()
|
err := br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -704,3 +704,67 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro
|
|||||||
rows, _ := c.Query(ctx, sql, args...)
|
rows, _ := c.Query(ctx, sql, args...)
|
||||||
return (*connRow)(rows.(*connRows))
|
return (*connRow)(rows.(*connRows))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
|
||||||
|
// explicit transaction control statements are executed.
|
||||||
|
func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
||||||
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
var parameterOIDs []pgtype.OID
|
||||||
|
ps := c.preparedStatements[bi.query]
|
||||||
|
|
||||||
|
if ps != nil {
|
||||||
|
parameterOIDs = ps.ParameterOIDs
|
||||||
|
} else {
|
||||||
|
parameterOIDs = bi.parameterOIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := convertDriverValuers(bi.arguments)
|
||||||
|
if err != nil {
|
||||||
|
return &BatchResults{err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
paramFormats := make([]int16, len(args))
|
||||||
|
paramValues := make([][]byte, len(args))
|
||||||
|
for i := range args {
|
||||||
|
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i])
|
||||||
|
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i])
|
||||||
|
if err != nil {
|
||||||
|
return &BatchResults{err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if ps != nil {
|
||||||
|
resultFormats := bi.resultFormatCodes
|
||||||
|
if resultFormats == nil {
|
||||||
|
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||||
|
for i := range resultFormats {
|
||||||
|
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
||||||
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
|
resultFormats[i] = BinaryFormatCode
|
||||||
|
} else {
|
||||||
|
resultFormats[i] = TextFormatCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats)
|
||||||
|
} else {
|
||||||
|
oids := make([]uint32, len(parameterOIDs))
|
||||||
|
for i := 0; i < len(parameterOIDs); i++ {
|
||||||
|
oids[i] = uint32(parameterOIDs[i])
|
||||||
|
}
|
||||||
|
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mrr := c.pgConn.ExecBatch(ctx, batch)
|
||||||
|
|
||||||
|
return &BatchResults{
|
||||||
|
conn: c,
|
||||||
|
mrr: mrr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ require (
|
|||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db
|
||||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0
|
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0
|
||||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b
|
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b
|
||||||
github.com/lib/pq v1.1.0
|
|
||||||
github.com/pkg/errors v0.8.1
|
|
||||||
github.com/rs/zerolog v1.13.0
|
github.com/rs/zerolog v1.13.0
|
||||||
github.com/satori/go.uuid v1.2.0
|
github.com/satori/go.uuid v1.2.0
|
||||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24
|
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaK
|
|||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0 h1:mX93v750WifMD1htCt7vqeolcnpaG1gz8URVGjSzcUM=
|
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0 h1:mX93v750WifMD1htCt7vqeolcnpaG1gz8URVGjSzcUM=
|
||||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
||||||
github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90=
|
|
||||||
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
||||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b h1:cIcUpcEP55F/QuZWEtXyqHoWk+IV4TBiLjtBkeq/Q1c=
|
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b h1:cIcUpcEP55F/QuZWEtXyqHoWk+IV4TBiLjtBkeq/Q1c=
|
||||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||||
@@ -31,6 +30,7 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
|
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
|
||||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
|
github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4=
|
||||||
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
|||||||
@@ -185,9 +185,13 @@ func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []
|
|||||||
return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
|
return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeginBatch returns a *Batch query for the tx's connection.
|
// SendBatch delegates to the underlying *Conn
|
||||||
func (tx *Tx) BeginBatch() *Batch {
|
func (tx *Tx) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
||||||
return &Batch{conn: tx.conn}
|
if tx.status != TxStatusInProgress {
|
||||||
|
return &BatchResults{err: ErrTxClosed}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.conn.SendBatch(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status returns the status of the transaction from the set of
|
// Status returns the status of the transaction from the set of
|
||||||
|
|||||||
Reference in New Issue
Block a user