diff --git a/batch.go b/batch.go index 8c924e8d..7f5422dc 100644 --- a/batch.go +++ b/batch.go @@ -268,6 +268,23 @@ func (b *Batch) Close() (err error) { } } + for b.conn.pendingReadyForQueryCount > 0 { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return b.conn.rxErrorResponse(msg) + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { return err } diff --git a/batch_test.go b/batch_test.go index 61bbe357..d0e26875 100644 --- a/batch_test.go +++ b/batch_test.go @@ -701,3 +701,55 @@ func TestTxBeginBatchRollback(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnBeginBatchDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + batch := conn.BeginBatch() + batch.Queue(`update t set n=n+1 where id='b' returning *`, + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + err = batch.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/conn_test.go b/conn_test.go index fea3b659..c6ce50cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1370,6 +1370,32 @@ func TestExecFailure(t *testing.T) { } } +func TestExecDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + _, err := conn.Exec(`update t set n=n+1 where id='b'`) + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestExecFailureWithArguments(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 5c6cbf7f..bf4ec561 100644 --- a/query.go +++ b/query.go @@ -69,6 +69,25 @@ func (rows *Rows) Close() { return } + // If there is no error and a batch operation is not in progress read until we get the ReadyForQuery message or the + // ErrorResponse. This is necessary to detect a deferred constraint violation where the ErrorResponse is sent after + // CommandComplete. + if rows.err == nil && rows.batch == nil && rows.conn.pendingReadyForQueryCount == 1 { + for rows.conn.pendingReadyForQueryCount > 0 { + msg, err := rows.conn.rxMsg() + if err != nil { + rows.err = err + break + } + + err = rows.conn.processContextFreeMsg(msg) + if err != nil { + rows.err = err + break + } + } + } + if rows.unlockConn { rows.conn.unlock() rows.unlockConn = false diff --git a/query_test.go b/query_test.go index 06b7b8b7..ea1fd66e 100644 --- a/query_test.go +++ b/query_test.go @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/satori/go.uuid" + uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -424,6 +424,47 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } +// https://github.com/jackc/pgx/issues/570 +func TestConnQueryDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + rows, err := conn.Query(`update t set n=n+1 where id='b' returning *`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + if rows.Err() == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := rows.Err().(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestQueryEncodeError(t *testing.T) { t.Parallel()