Split batch command and result
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgtype"
|
||||
errors "golang.org/x/xerrors"
|
||||
@@ -18,21 +16,7 @@ type batchItem struct {
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips.
|
||||
type Batch struct {
|
||||
conn *Conn
|
||||
items []*batchItem
|
||||
err error
|
||||
|
||||
mrr *pgconn.MultiResultReader
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
func (c *Conn) BeginBatch() *Batch {
|
||||
return &Batch{conn: c}
|
||||
}
|
||||
|
||||
// Conn returns the underlying connection that b will or was performed on.
|
||||
func (b *Batch) Conn() *Conn {
|
||||
return b.conn
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. parameterOIDs and
|
||||
@@ -47,92 +31,43 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
||||
})
|
||||
}
|
||||
|
||||
// Send sends all queued queries to the server at once. All queries are run in an implicit transaction unless explicit
|
||||
// transaction control statements are executed.
|
||||
func (b *Batch) Send(ctx context.Context) error {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
var parameterOIDs []pgtype.OID
|
||||
ps := b.conn.preparedStatements[bi.query]
|
||||
|
||||
if ps != nil {
|
||||
parameterOIDs = ps.ParameterOIDs
|
||||
} else {
|
||||
parameterOIDs = bi.parameterOIDs
|
||||
}
|
||||
|
||||
args, err := convertDriverValuers(bi.arguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
paramFormats := make([]int16, len(args))
|
||||
paramValues := make([][]byte, len(args))
|
||||
for i := range args {
|
||||
paramFormats[i] = chooseParameterFormatCode(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
||||
paramValues[i], err = newencodePreparedStatementArgument(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if ps != nil {
|
||||
resultFormats := bi.resultFormatCodes
|
||||
if resultFormats == nil {
|
||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||
for i := range resultFormats {
|
||||
if dt, ok := b.conn.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
resultFormats[i] = BinaryFormatCode
|
||||
} else {
|
||||
resultFormats[i] = TextFormatCode
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats)
|
||||
} else {
|
||||
oids := make([]uint32, len(parameterOIDs))
|
||||
for i := 0; i < len(parameterOIDs); i++ {
|
||||
oids[i] = uint32(parameterOIDs[i])
|
||||
}
|
||||
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
|
||||
}
|
||||
}
|
||||
|
||||
b.mrr = b.conn.pgConn.ExecBatch(ctx, batch)
|
||||
|
||||
return nil
|
||||
type BatchResults struct {
|
||||
conn *Conn
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
}
|
||||
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
||||
if !b.mrr.NextResult() {
|
||||
err := b.mrr.Close()
|
||||
func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return nil, br.err
|
||||
}
|
||||
|
||||
if !br.mrr.NextResult() {
|
||||
err := br.mrr.Close()
|
||||
if err == nil {
|
||||
err = errors.New("no result")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.mrr.ResultReader().Close()
|
||||
return br.mrr.ResultReader().Close()
|
||||
}
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Query.
|
||||
func (b *Batch) QueryResults() (Rows, error) {
|
||||
rows := b.conn.getRows("batch query", nil)
|
||||
func (br *BatchResults) QueryResults() (Rows, error) {
|
||||
rows := br.conn.getRows("batch query", nil)
|
||||
|
||||
if !b.mrr.NextResult() {
|
||||
rows.err = b.mrr.Close()
|
||||
if br.err != nil {
|
||||
rows.err = br.err
|
||||
rows.closed = true
|
||||
return rows, br.err
|
||||
}
|
||||
|
||||
if !br.mrr.NextResult() {
|
||||
rows.err = br.mrr.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = errors.New("no result")
|
||||
}
|
||||
@@ -140,14 +75,14 @@ func (b *Batch) QueryResults() (Rows, error) {
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
rows.resultReader = b.mrr.ResultReader()
|
||||
rows.resultReader = br.mrr.ResultReader()
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryRowResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with QueryRow.
|
||||
func (b *Batch) QueryRowResults() Row {
|
||||
rows, _ := b.QueryResults()
|
||||
func (br *BatchResults) QueryRowResults() Row {
|
||||
rows, _ := br.QueryResults()
|
||||
return (*connRow)(rows.(*connRows))
|
||||
|
||||
}
|
||||
@@ -155,6 +90,10 @@ func (b *Batch) QueryRowResults() Row {
|
||||
// Close closes the batch operation. Any error that occured during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the
|
||||
// server. In this case the underlying connection will have been closed.
|
||||
func (b *Batch) Close() (err error) {
|
||||
return b.mrr.Close()
|
||||
func (br *BatchResults) Close() error {
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
||||
return br.mrr.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user