2
0

Split batch command and result

This commit is contained in:
Jack Christensen
2019-04-24 16:39:06 -05:00
parent 7b4e145e7c
commit aed6b822d9
7 changed files with 161 additions and 186 deletions
+54 -81
View File
@@ -23,7 +23,7 @@ func TestConnBeginBatch(t *testing.T) {
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
@@ -50,12 +50,9 @@ func TestConnBeginBatch(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
ct, err := batch.ExecResults()
ct, err := br.ExecResults()
if err != nil {
t.Error(err)
}
@@ -63,7 +60,7 @@ func TestConnBeginBatch(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = batch.ExecResults()
ct, err = br.ExecResults()
if err != nil {
t.Error(err)
}
@@ -71,7 +68,7 @@ func TestConnBeginBatch(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = batch.ExecResults()
ct, err = br.ExecResults()
if err != nil {
t.Error(err)
}
@@ -79,7 +76,7 @@ func TestConnBeginBatch(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
rows, err := batch.QueryResults()
rows, err := br.QueryResults()
if err != nil {
t.Error(err)
}
@@ -143,7 +140,7 @@ func TestConnBeginBatch(t *testing.T) {
t.Fatal(rows.Err())
}
err = batch.QueryRowResults().Scan(&amount)
err = br.QueryRowResults().Scan(&amount)
if err != nil {
t.Error(err)
}
@@ -151,7 +148,7 @@ func TestConnBeginBatch(t *testing.T) {
t.Errorf("amount => %v, want %v", amount, 6)
}
err = batch.Close()
err = br.Close()
if err != nil {
t.Fatal(err)
}
@@ -170,7 +167,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
t.Fatal(err)
}
batch := conn.BeginBatch()
batch := &pgx.Batch{}
queryCount := 3
for i := 0; i < queryCount; i++ {
@@ -181,13 +178,10 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
)
}
err = batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < queryCount; i++ {
rows, err := batch.QueryResults()
rows, err := br.QueryResults()
if err != nil {
t.Fatal(err)
}
@@ -207,7 +201,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
}
}
err = batch.Close()
err = br.Close()
if err != nil {
t.Fatal(err)
}
@@ -221,7 +215,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
batch := conn.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n",
nil,
nil,
@@ -233,12 +227,9 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
rows, err := batch.QueryResults()
rows, err := br.QueryResults()
if err != nil {
t.Error(err)
}
@@ -259,7 +250,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
rows.Close()
rows, err = batch.QueryResults()
rows, err = br.QueryResults()
if err != nil {
t.Error(err)
}
@@ -278,7 +269,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
t.Error(rows.Err())
}
err = batch.Close()
err = br.Close()
if err != nil {
t.Fatal(err)
}
@@ -292,7 +283,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
batch := conn.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
nil,
nil,
@@ -304,12 +295,9 @@ func TestConnBeginBatchQueryError(t *testing.T) {
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
rows, err := batch.QueryResults()
rows, err := br.QueryResults()
if err != nil {
t.Error(err)
}
@@ -328,7 +316,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
}
err = batch.Close()
err = br.Close()
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
}
@@ -342,25 +330,22 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
batch := conn.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("select 1 1",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
var n int32
err = batch.QueryRowResults().Scan(&n)
err := br.QueryRowResults().Scan(&n)
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
}
err = batch.Close()
err = br.Close()
if err == nil {
t.Error("Expected error")
}
@@ -381,7 +366,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("select 1",
nil,
nil,
@@ -393,18 +378,15 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
nil,
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
var value int
err = batch.QueryRowResults().Scan(&value)
err := br.QueryRowResults().Scan(&value)
if err != nil {
t.Error(err)
}
ct, err := batch.ExecResults()
ct, err := br.ExecResults()
if err != nil {
t.Error(err)
}
@@ -412,7 +394,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
}
batch.Close()
br.Close()
ensureConnValid(t, conn)
}
@@ -430,7 +412,7 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("select 1 union all select 2 union all select 3",
nil,
nil,
@@ -442,18 +424,15 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
nil,
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := conn.SendBatch(context.Background(), batch)
rows, err := batch.QueryResults()
rows, err := br.QueryResults()
if err != nil {
t.Error(err)
}
rows.Close()
ct, err := batch.ExecResults()
ct, err := br.ExecResults()
if err != nil {
t.Error(err)
}
@@ -461,12 +440,12 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
}
batch.Close()
br.Close()
ensureConnValid(t, conn)
}
func TestTxBeginBatch(t *testing.T) {
func TestTxSendBatch(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -485,25 +464,23 @@ func TestTxBeginBatch(t *testing.T) {
mustExec(t, conn, sql)
tx, _ := conn.Begin(context.Background(), nil)
batch := tx.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("insert into ledger1(description) values($1) returning id",
[]interface{}{"q1"},
[]pgtype.OID{pgtype.VarcharOID},
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := tx.SendBatch(context.Background(), batch)
var id int
err = batch.QueryRowResults().Scan(&id)
err := br.QueryRowResults().Scan(&id)
if err != nil {
t.Error(err)
}
batch.Close()
br.Close()
batch = tx.BeginBatch()
batch = &pgx.Batch{}
batch.Queue("insert into ledger2(id,amount) values($1, $2)",
[]interface{}{id, 2},
[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
@@ -516,11 +493,9 @@ func TestTxBeginBatch(t *testing.T) {
nil,
)
err = batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
ct, err := batch.ExecResults()
br = tx.SendBatch(context.Background(), batch)
ct, err := br.ExecResults()
if err != nil {
t.Error(err)
}
@@ -529,12 +504,12 @@ func TestTxBeginBatch(t *testing.T) {
}
var amount int
err = batch.QueryRowResults().Scan(&amount)
err = br.QueryRowResults().Scan(&amount)
if err != nil {
t.Error(err)
}
batch.Close()
br.Close()
tx.Commit(context.Background())
var count int
@@ -543,7 +518,7 @@ func TestTxBeginBatch(t *testing.T) {
t.Errorf("count => %v, want %v", count, 1)
}
err = batch.Close()
err = br.Close()
if err != nil {
t.Fatal(err)
}
@@ -551,7 +526,7 @@ func TestTxBeginBatch(t *testing.T) {
ensureConnValid(t, conn)
}
func TestTxBeginBatchRollback(t *testing.T) {
func TestTxSendBatchRollback(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -564,23 +539,21 @@ func TestTxBeginBatchRollback(t *testing.T) {
mustExec(t, conn, sql)
tx, _ := conn.Begin(context.Background(), nil)
batch := tx.BeginBatch()
batch := &pgx.Batch{}
batch.Queue("insert into ledger1(description) values($1) returning id",
[]interface{}{"q1"},
[]pgtype.OID{pgtype.VarcharOID},
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
br := tx.SendBatch(context.Background(), batch)
var id int
err = batch.QueryRowResults().Scan(&id)
err := br.QueryRowResults().Scan(&id)
if err != nil {
t.Error(err)
}
batch.Close()
br.Close()
tx.Rollback(context.Background())
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)