2
0
Files
pgx/batch_test.go
T
Jack Christensen 66625e6489 Prepare takes context
Also remove PrepareEx. It's primary usage was for context. Supplying
parameter OIDs is unnecessary when you can type cast in the query SQL.
If it does become necessary or desirable to add options back it can be
added in a backwards compatible way by adding a varargs as last
argument.
2019-04-20 11:47:16 -05:00

595 lines
12 KiB
Go

package pgx_test
import (
"context"
"os"
"testing"
"github.com/jackc/pgconn"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
)
func TestConnBeginBatch(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q2", 2},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q3", 3},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
batch.Queue("select id, description, amount from ledger order by id",
nil,
nil,
[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode},
)
batch.Queue("select sum(amount) from ledger",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
var id int32
var description string
var amount int32
if !rows.Next() {
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 1 {
t.Errorf("id => %v, want %v", id, 1)
}
if description != "q1" {
t.Errorf("description => %v, want %v", description, "q1")
}
if amount != 1 {
t.Errorf("amount => %v, want %v", amount, 1)
}
if !rows.Next() {
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 2 {
t.Errorf("id => %v, want %v", id, 2)
}
if description != "q2" {
t.Errorf("description => %v, want %v", description, "q2")
}
if amount != 2 {
t.Errorf("amount => %v, want %v", amount, 2)
}
if !rows.Next() {
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 3 {
t.Errorf("id => %v, want %v", id, 3)
}
if description != "q3" {
t.Errorf("description => %v, want %v", description, "q3")
}
if amount != 3 {
t.Errorf("amount => %v, want %v", amount, 3)
}
if rows.Next() {
t.Fatal("did not expect a row to be available")
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
err = batch.QueryRowResults().Scan(&amount)
if err != nil {
t.Error(err)
}
if amount != 6 {
t.Errorf("amount => %v, want %v", amount, 6)
}
err = batch.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
if err != nil {
t.Fatal(err)
}
batch := conn.BeginBatch()
queryCount := 3
for i := 0; i < queryCount; i++ {
batch.Queue("ps1",
[]interface{}{5},
nil,
nil,
)
}
err = batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
for i := 0; i < queryCount; i++ {
rows, err := batch.QueryResults()
if err != nil {
t.Fatal(err)
}
for k := 0; rows.Next(); k++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Fatal(err)
}
if n != k {
t.Fatalf("n => %v, want %v", n, k)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
}
err = batch.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
batch := conn.BeginBatch()
batch.Queue("select n from generate_series(0,5) n",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("select n from generate_series(0,5) n",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
for i := 0; i < 3; i++ {
if !rows.Next() {
t.Error("expected a row to be available")
}
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
}
rows.Close()
rows, err = batch.QueryResults()
if err != nil {
t.Error(err)
}
for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
}
if rows.Err() != nil {
t.Error(rows.Err())
}
err = batch.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchQueryError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
batch := conn.BeginBatch()
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("select n from generate_series(0,5) n",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
}
if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
}
err = batch.Close()
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
batch := conn.BeginBatch()
batch.Queue("select 1 1",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
var n int32
err = batch.QueryRowResults().Scan(&n)
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
}
err = batch.Close()
if err == nil {
t.Error("Expected error")
}
ensureConnValid(t, conn)
}
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("select 1",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
var value int
err = batch.QueryRowResults().Scan(&value)
if err != nil {
t.Error(err)
}
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
}
batch.Close()
ensureConnValid(t, conn)
}
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("select 1 union all select 2 union all select 3",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
rows.Close()
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
}
batch.Close()
ensureConnValid(t, conn)
}
func TestTxBeginBatch(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
mustExec(t, conn, sql)
sql = `create temporary table ledger2(
id int primary key,
amount int not null
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin()
batch := tx.BeginBatch()
batch.Queue("insert into ledger1(description) values($1) returning id",
[]interface{}{"q1"},
[]pgtype.OID{pgtype.VarcharOID},
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
var id int
err = batch.QueryRowResults().Scan(&id)
if err != nil {
t.Error(err)
}
batch.Close()
batch = tx.BeginBatch()
batch.Queue("insert into ledger2(id,amount) values($1, $2)",
[]interface{}{id, 2},
[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
nil,
)
batch.Queue("select amount from ledger2 where id = $1",
[]interface{}{id},
[]pgtype.OID{pgtype.Int4OID},
nil,
)
err = batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
var amount int
err = batch.QueryRowResults().Scan(&amount)
if err != nil {
t.Error(err)
}
batch.Close()
tx.Commit(context.Background())
var count int
conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
if count != 1 {
t.Errorf("count => %v, want %v", count, 1)
}
err = batch.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestTxBeginBatchRollback(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin()
batch := tx.BeginBatch()
batch.Queue("insert into ledger1(description) values($1) returning id",
[]interface{}{"q1"},
[]pgtype.OID{pgtype.VarcharOID},
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background())
if err != nil {
t.Fatal(err)
}
var id int
err = batch.QueryRowResults().Scan(&id)
if err != nil {
t.Error(err)
}
batch.Close()
tx.Rollback(context.Background())
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
var count int
row.Scan(&count)
if count != 0 {
t.Errorf("count => %v, want %v", count, 0)
}
ensureConnValid(t, conn)
}