Merge branch 'plopik-fixqueryrow'
* plopik-fixqueryrow: Fix incomplete selects during batch Fix queryRow leftover message on conn
This commit is contained in:
@@ -22,6 +22,7 @@ type Batch struct {
|
|||||||
items []*batchItem
|
items []*batchItem
|
||||||
resultsRead int
|
resultsRead int
|
||||||
sent bool
|
sent bool
|
||||||
|
pendingCommandComplete bool
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
@@ -145,8 +146,15 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := b.ensureCommandComplete(); err != nil {
|
||||||
|
b.die(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
b.resultsRead++
|
b.resultsRead++
|
||||||
|
|
||||||
|
b.pendingCommandComplete = true
|
||||||
|
|
||||||
for {
|
for {
|
||||||
msg, err := b.conn.rxMsg()
|
msg, err := b.conn.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -155,6 +163,7 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
|||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.CommandComplete:
|
case *pgproto3.CommandComplete:
|
||||||
|
b.pendingCommandComplete = false
|
||||||
return CommandTag(msg.CommandTag), nil
|
return CommandTag(msg.CommandTag), nil
|
||||||
default:
|
default:
|
||||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||||
@@ -182,8 +191,16 @@ func (b *Batch) QueryResults() (*Rows, error) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := b.ensureCommandComplete(); err != nil {
|
||||||
|
b.die(err)
|
||||||
|
rows.fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
b.resultsRead++
|
b.resultsRead++
|
||||||
|
|
||||||
|
b.pendingCommandComplete = true
|
||||||
|
|
||||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.die(err)
|
b.die(err)
|
||||||
@@ -244,3 +261,25 @@ func (b *Batch) die(err error) {
|
|||||||
b.connPool.Release(b.conn)
|
b.connPool.Release(b.conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Batch) ensureCommandComplete() error {
|
||||||
|
for b.pendingCommandComplete {
|
||||||
|
msg, err := b.conn.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.CommandComplete:
|
||||||
|
b.pendingCommandComplete = false
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
err = b.conn.processContextFreeMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -476,3 +476,101 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||||||
t.Error("conn should be dead, but was alive")
|
t.Error("conn should be dead, but was alive")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
sql := `create temporary table ledger(
|
||||||
|
id serial primary key,
|
||||||
|
description varchar not null,
|
||||||
|
amount int not null
|
||||||
|
);`
|
||||||
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
|
batch := conn.BeginBatch()
|
||||||
|
batch.Queue("select 1",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
[]int16{pgx.BinaryFormatCode},
|
||||||
|
)
|
||||||
|
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||||
|
[]interface{}{"q1", 1},
|
||||||
|
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
err := batch.Send(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var value int
|
||||||
|
err = batch.QueryRowResults().Scan(&value)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct, err := batch.ExecResults()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if ct.RowsAffected() != 2 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.Close()
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
sql := `create temporary table ledger(
|
||||||
|
id serial primary key,
|
||||||
|
description varchar not null,
|
||||||
|
amount int not null
|
||||||
|
);`
|
||||||
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
|
batch := conn.BeginBatch()
|
||||||
|
batch.Queue("select 1 union all select 2 union all select 3",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
[]int16{pgx.BinaryFormatCode},
|
||||||
|
)
|
||||||
|
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||||
|
[]interface{}{"q1", 1},
|
||||||
|
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
err := batch.Send(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := batch.QueryResults()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rows.Close()
|
||||||
|
|
||||||
|
ct, err := batch.ExecResults()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if ct.RowsAffected() != 2 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.Close()
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|||||||
@@ -149,6 +149,9 @@ func (rows *Rows) Next() bool {
|
|||||||
rows.values = msg.Values
|
rows.values = msg.Values
|
||||||
return true
|
return true
|
||||||
case *pgproto3.CommandComplete:
|
case *pgproto3.CommandComplete:
|
||||||
|
if rows.batch != nil {
|
||||||
|
rows.batch.pendingCommandComplete = false
|
||||||
|
}
|
||||||
rows.Close()
|
rows.Close()
|
||||||
return false
|
return false
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user