diff --git a/pgconn.go b/pgconn.go index 7d437434..4f3cdd66 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1257,15 +1257,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - _, err := pgConn.conn.Write(batch.buf) - if err != nil { - pgConn.hardClose() - pgConn.doneChanToDeadline.cleanup() - multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) - pgConn.unlock() - return multiResult - } + + // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is + // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication + // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. + // The error the code reading the batch results receives will be a closed connection error. + // + // See https://github.com/jackc/pgx/issues/374. + go func() { + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + pgConn.conn.Close() + } + }() return multiResult } diff --git a/pgconn_test.go b/pgconn_test.go index d31e8cc9..25cc3ee3 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -682,6 +682,60 @@ func TestConnExecBatchPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +// Without concurrent reading and writing large batches can deadlock. +// +// See https://github.com/jackc/pgx/issues/374. +func TestConnExecBatchHuge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + + queryCount := 100000 + args := make([]string, queryCount) + + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } + + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) + + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } +} + +func TestConnExecBatchImplicitTransaction(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) + + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) +} + func TestConnLocking(t *testing.T) { t.Parallel()