Partial conversion of pgx to use pgconn
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type batchItem struct {
|
||||
@@ -26,6 +27,8 @@ type Batch struct {
|
||||
ctx context.Context
|
||||
err error
|
||||
inTx bool
|
||||
|
||||
mrr *pgconn.MultiResultReader
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
@@ -56,10 +59,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
||||
})
|
||||
}
|
||||
|
||||
// Send sends all queued queries to the server at once.
|
||||
// If the batch is created from a conn Object then All queries are wrapped
|
||||
// in a transaction. The transaction can optionally be configured with
|
||||
// txOptions. The context is in effect until the Batch is closed.
|
||||
// Send sends all queued queries to the server at once. All queries are run in an implicit transaction unless explicit
|
||||
// transaction control statements are executed.
|
||||
//
|
||||
// Warning: Send writes all queued queries before reading any results. This can
|
||||
// cause a deadlock if an excessive number of queries are queued. It is highly
|
||||
@@ -78,7 +79,7 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
||||
// able to finish sending the responses.
|
||||
//
|
||||
// See https://github.com/jackc/pgx/issues/374.
|
||||
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
func (b *Batch) Send(ctx context.Context) error {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
@@ -94,112 +95,62 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := b.conn.wbuf
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, txOptions.beginSQL())
|
||||
}
|
||||
|
||||
err = b.conn.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
var psName string
|
||||
var psParameterOIDs []pgtype.OID
|
||||
var parameterOIDs []pgtype.OID
|
||||
ps := b.conn.preparedStatements[bi.query]
|
||||
|
||||
if ps, ok := b.conn.preparedStatements[bi.query]; ok {
|
||||
psName = ps.Name
|
||||
psParameterOIDs = ps.ParameterOIDs
|
||||
if ps != nil {
|
||||
parameterOIDs = ps.ParameterOIDs
|
||||
} else {
|
||||
psParameterOIDs = bi.parameterOIDs
|
||||
buf = appendParse(buf, "", bi.query, psParameterOIDs)
|
||||
parameterOIDs = bi.parameterOIDs
|
||||
}
|
||||
|
||||
var err error
|
||||
buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes)
|
||||
args, err := convertDriverValuers(bi.arguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf = appendDescribe(buf, 'P', "")
|
||||
buf = appendExecute(buf, "", 0)
|
||||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, "commit")
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
}
|
||||
|
||||
n, err := b.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
if fatalWriteErr(n, err) {
|
||||
b.conn.die(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for !b.inTx {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
paramFormats := make([]int16, len(args))
|
||||
paramValues := make([][]byte, len(args))
|
||||
for i := range args {
|
||||
paramFormats[i] = chooseParameterFormatCode(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
||||
paramValues[i], err = newencodePreparedStatementArgument(b.conn.ConnInfo, parameterOIDs[i], args[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if ps != nil {
|
||||
batch.ExecPrepared(ps.Name, paramValues, paramFormats, bi.resultFormatCodes)
|
||||
} else {
|
||||
oids := make([]uint32, len(parameterOIDs))
|
||||
for i := 0; i < len(parameterOIDs); i++ {
|
||||
oids[i] = uint32(parameterOIDs[i])
|
||||
}
|
||||
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
|
||||
}
|
||||
}
|
||||
|
||||
b.mrr = b.conn.pgConn.ExecBatch(ctx, batch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
||||
if b.err != nil {
|
||||
return "", b.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
return "", b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
if !b.mrr.NextResult() {
|
||||
err := b.mrr.Close()
|
||||
if err == nil {
|
||||
err = errors.New("no result")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.mrr.ResultReader().Close()
|
||||
}
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the
|
||||
@@ -207,38 +158,16 @@ func (b *Batch) ExecResults() (pgconn.CommandTag, error) {
|
||||
func (b *Batch) QueryResults() (*Rows, error) {
|
||||
rows := b.conn.getRows("batch query", nil)
|
||||
|
||||
if b.err != nil {
|
||||
rows.fatal(b.err)
|
||||
return rows, b.err
|
||||
if !b.mrr.NextResult() {
|
||||
rows.err = b.mrr.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = errors.New("no result")
|
||||
}
|
||||
rows.closed = true
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.die(b.ctx.Err())
|
||||
rows.fatal(b.err)
|
||||
return rows, b.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||
if err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(b.err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.batch = b
|
||||
rows.fields = fieldDescriptions
|
||||
rows.resultReader = b.mrr.ResultReader()
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -254,28 +183,7 @@ func (b *Batch) QueryRowResults() *Row {
|
||||
// operation may have made it impossible to resyncronize the connection with the
|
||||
// server. In this case the underlying connection will have been closed.
|
||||
func (b *Batch) Close() (err error) {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = b.conn.termContext(err)
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := b.resultsRead; i < len(b.items); i++ {
|
||||
if _, err = b.ExecResults(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return b.mrr.Close()
|
||||
}
|
||||
|
||||
func (b *Batch) die(err error) {
|
||||
|
||||
Reference in New Issue
Block a user