Merge branch 'plopik-fixqueryrow'
* plopik-fixqueryrow: Fix incomplete selects during batch Fix queryRow leftover message on conn
This commit is contained in:
@@ -17,13 +17,14 @@ type batchItem struct {
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips.
|
||||
type Batch struct {
|
||||
conn *Conn
|
||||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
sent bool
|
||||
ctx context.Context
|
||||
err error
|
||||
conn *Conn
|
||||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
sent bool
|
||||
pendingCommandComplete bool
|
||||
ctx context.Context
|
||||
err error
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
@@ -145,8 +146,15 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
@@ -155,6 +163,7 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
@@ -182,8 +191,16 @@ func (b *Batch) QueryResults() (*Rows, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||
if err != nil {
|
||||
b.die(err)
|
||||
@@ -244,3 +261,25 @@ func (b *Batch) die(err error) {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batch) ensureCommandComplete() error {
|
||||
for b.pendingCommandComplete {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return nil
|
||||
default:
|
||||
err = b.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -476,3 +476,101 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
||||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select 1",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var value int
|
||||
err = batch.QueryRowResults().Scan(&value)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 2 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select 1 union all select 2 union all select 3",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 2 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user