2
0
Files
pgx/stdlib/sql_test.go
T
Jack Christensen 7a2b93323c Prevent prematurely closing statements in database/sql
This error was introduced by 0f0d236599.
If the same statement was prepared multiple times then whenever Close
was called on one of the statements the underlying prepared statement
would be closed even if other statements were still using it.

https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
2023-10-10 21:56:26 -05:00

1375 lines
35 KiB
Go

package stdlib_test
import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"os"
	"reflect"
	"regexp"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/stdlib"
	"github.com/jackc/pgx/v5/tracelog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// openDB returns a *sql.DB configured from the PGX_TEST_DATABASE environment variable.
// The test fails immediately if that connection string cannot be parsed.
func openDB(t testing.TB) *sql.DB {
	connString := os.Getenv("PGX_TEST_DATABASE")
	config, err := pgx.ParseConfig(connString)
	require.NoError(t, err)

	db := stdlib.OpenDB(*config)
	return db
}
// closeDB closes db and fails the test immediately if Close reports an error.
func closeDB(t testing.TB, db *sql.DB) {
	require.NoError(t, db.Close())
}
// skipCockroachDB skips t with msg when db is backed by CockroachDB. Detection is based on the
// server reporting a non-empty crdb_version parameter status.
//
// The parameter is read inside conn.Raw, but t.Skip is deliberately called only after Raw
// returns. t.Skip exits the goroutine via runtime.Goexit. If that happened inside the Raw
// callback, database/sql would see the callback as not having returned normally and would
// discard the connection as bad.
func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
	conn, err := db.Conn(context.Background())
	require.NoError(t, err)
	defer conn.Close()

	var isCockroachDB bool
	err = conn.Raw(func(driverConn any) error {
		pgxConn := driverConn.(*stdlib.Conn).Conn()
		isCockroachDB = pgxConn.PgConn().ParameterStatus("crdb_version") != ""
		return nil
	})
	require.NoError(t, err)

	if isCockroachDB {
		t.Skip(msg)
	}
}
// serverVersionMajorRegexp matches the leading run of digits (the major version) of a
// server_version parameter status. It is compiled once instead of on every call.
var serverVersionMajorRegexp = regexp.MustCompile(`^[0-9]+`)

// skipPostgreSQLVersionLessThan skips t when the server's major version, taken from the
// server_version parameter status, is lower than minVersion. If server_version has no leading
// number, the server is treated as not PostgreSQL and the test is not skipped.
//
// The parameter is read inside conn.Raw, but the version check and t.Skipf happen after Raw
// returns. t.Skipf exits the goroutine via runtime.Goexit. Doing that inside the callback would
// make database/sql discard the connection as bad.
func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) {
	conn, err := db.Conn(context.Background())
	require.NoError(t, err)
	defer conn.Close()

	var serverVersionStr string
	err = conn.Raw(func(driverConn any) error {
		pgxConn := driverConn.(*stdlib.Conn).Conn()
		serverVersionStr = pgxConn.PgConn().ParameterStatus("server_version")
		return nil
	})
	require.NoError(t, err)

	serverVersionStr = serverVersionMajorRegexp.FindString(serverVersionStr)
	// if not PostgreSQL do nothing
	if serverVersionStr == "" {
		return
	}

	serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
	require.NoError(t, err)

	if serverVersion < minVersion {
		t.Skipf("Test requires PostgreSQL v%d+", minVersion)
	}
}
// testWithAllQueryExecModes runs f once per pgx.QueryExecMode, each time as a subtest named
// after the mode. Every subtest gets its own *sql.DB whose DefaultQueryExecMode is set to that
// mode. After f returns, the DB is sanity-checked with ensureDBValid and then closed.
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
	modes := []pgx.QueryExecMode{
		pgx.QueryExecModeCacheStatement,
		pgx.QueryExecModeCacheDescribe,
		pgx.QueryExecModeDescribeExec,
		pgx.QueryExecModeExec,
		pgx.QueryExecModeSimpleProtocol,
	}

	for _, mode := range modes {
		runMode := func(t *testing.T) {
			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
			require.NoError(t, err)
			config.DefaultQueryExecMode = mode

			db := stdlib.OpenDB(*config)
			defer func() {
				require.NoError(t, db.Close())
			}()

			f(t, db)
			ensureDBValid(t, db)
		}
		t.Run(mode.String(), runMode)
	}
}
// ensureDBValid runs a simple query to ensure db is still usable. This is of less use in stdlib
// than in pgx itself, because the database/sql connection pool should already cover broken
// connections.
//
// It selects generate_series(1, 10) and expects exactly 10 rows that sum to 55.
func ensureDBValid(t testing.TB, db *sql.DB) {
	var sum, rowCount int32
	rows, err := db.Query("select generate_series(1,$1)", 10)
	require.NoError(t, err)
	defer rows.Close()

	for rows.Next() {
		var n int32
		// Check Scan's error so a decoding failure is reported directly, rather than surfacing
		// later as a misleading "Wrong values returned".
		require.NoError(t, rows.Scan(&n))
		sum += n
		rowCount++
	}
	require.NoError(t, rows.Err())

	if rowCount != 10 {
		t.Error("Select called onDataRow wrong number of times")
	}
	if sum != 55 {
		t.Error("Wrong values returned")
	}
}
// preparer is implemented by anything that can prepare a statement from a query string. In this
// file both *sql.DB and *sql.Tx are passed where a preparer is expected.
type preparer interface {
	Prepare(query string) (*sql.Stmt, error)
}

// prepareStmt prepares query on p and fails the test immediately if preparation fails.
func prepareStmt(t *testing.T, p preparer, query string) *sql.Stmt {
	stmt, err := p.Prepare(query)
	require.NoError(t, err)
	return stmt
}

// closeStmt closes stmt and fails the test immediately if Close returns an error.
func closeStmt(t *testing.T, stmt *sql.Stmt) {
	require.NoError(t, stmt.Close())
}
// TestSQLOpen checks that sql.Open accepts both driver names under which stdlib is registered.
func TestSQLOpen(t *testing.T) {
	for _, driverName := range []string{"pgx", "pgx/v5"} {
		driverName := driverName
		t.Run(driverName, func(t *testing.T) {
			db, err := sql.Open(driverName, os.Getenv("PGX_TEST_DATABASE"))
			require.NoError(t, err)
			closeDB(t, db)
		})
	}
}
// TestSQLOpenFromPool checks that a *sql.DB built on top of a pgxpool.Pool is usable.
func TestSQLOpenFromPool(t *testing.T) {
	pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	t.Cleanup(pool.Close)

	db := stdlib.OpenDBFromPool(pool)
	// Deferred so that db is closed even if ensureDBValid fails the test, and so that Close's
	// error is checked. The deferred call runs before the pool.Close cleanup registered above.
	defer closeDB(t, db)

	ensureDBValid(t, db)
}
// TestNormalLifeCycle prepares a statement on the DB, reads every row it returns, verifies the
// values, and then closes the rows and the statement.
func TestNormalLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	var rowCount int64
	for rows.Next() {
		rowCount++
		var (
			s string
			n int64
		)
		require.NoError(t, rows.Scan(&s, &n))
		if s != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, s)
		}
		// generate_series starts at 1, so the n-th row holds n.
		if n != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, n)
		}
	}
	require.NoError(t, rows.Err())
	require.EqualValues(t, 10, rowCount)

	require.NoError(t, rows.Close())
	ensureDBValid(t, db)
}
// TestStmtExec prepares and executes statements inside a transaction: first a DDL statement
// with no arguments, then an insert with one argument whose RowsAffected must be 1.
func TestStmtExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	tx, err := db.Begin()
	require.NoError(t, err)
	// Finish the transaction so its connection is released rather than held past closeDB.
	// Rolling back also discards the temporary table. Deferred calls run in LIFO order, so this
	// runs before closeDB.
	defer tx.Rollback()

	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
	_, err = createStmt.Exec()
	require.NoError(t, err)
	closeStmt(t, createStmt)

	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
	result, err := insertStmt.Exec("foo")
	require.NoError(t, err)
	n, err := result.RowsAffected()
	require.NoError(t, err)
	require.EqualValues(t, 1, n)
	closeStmt(t, insertStmt)

	ensureDBValid(t, db)
}
// TestQueryCloseRowsEarly closes a result set without reading any of it. It then runs the same
// prepared statement again to check that both the connection and the statement are still usable.
func TestQueryCloseRowsEarly(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	// Close rows immediately without having read them.
	abandoned, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)
	require.NoError(t, abandoned.Close())

	// Run the query again to ensure the connection and statement are still ok.
	rows, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	var rowCount int64
	for rows.Next() {
		rowCount++
		var (
			s string
			n int64
		)
		require.NoError(t, rows.Scan(&s, &n))
		if s != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, s)
		}
		if n != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, n)
		}
	}
	require.NoError(t, rows.Err())
	require.EqualValues(t, 10, rowCount)
	require.NoError(t, rows.Close())

	ensureDBValid(t, db)
}
// TestConnExec runs DDL and an insert directly on the DB and checks the insert's RowsAffected.
func TestConnExec(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create temporary table t(a varchar not null)")
		require.NoError(t, err)

		result, err := db.Exec("insert into t values('hey')")
		require.NoError(t, err)

		affected, err := result.RowsAffected()
		require.NoError(t, err)
		require.EqualValues(t, 1, affected)
	})
}
// TestConnQuery runs a parameterized query directly on the DB in every query exec mode and
// verifies all 10 returned rows.
func TestConnQuery(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
		require.NoError(t, err)

		var rowCount int64
		for rows.Next() {
			rowCount++
			var s string
			var n int64
			require.NoError(t, rows.Scan(&s, &n))
			if s != "foo" {
				t.Errorf(`Expected "foo", received "%v"`, s)
			}
			if n != rowCount {
				t.Errorf("Expected %d, received %d", rowCount, n)
			}
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 10, rowCount)
		require.NoError(t, rows.Close())
	})
}
// TestConnConcurrency exercises the DB from 50 goroutines at once in two phases. In the first
// phase each goroutine inserts its own row and updates it twice. In the second phase each
// goroutine reads its row back and verifies the stored values.
func TestConnConcurrency(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		// Not a temporary table: the goroutines may use different pooled connections, and a
		// temporary table is only visible to the session that created it.
		_, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)")
		require.NoError(t, err)
		defer func() {
			_, err := db.Exec("drop table t")
			require.NoError(t, err)
		}()

		const concurrency = 50

		// writeRow inserts row idx, then sets its str and dur_str columns.
		writeRow := func(idx int) error {
			ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
			defer cancel()

			str := strconv.Itoa(idx)
			duration := time.Duration(idx) * time.Second

			if _, err := db.ExecContext(ctx, "insert into t values($1)", idx); err != nil {
				return fmt.Errorf("insert failed: %d %w", idx, err)
			}
			if _, err := db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx); err != nil {
				return fmt.Errorf("update 1 failed: %d %w", idx, err)
			}
			if _, err := db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx); err != nil {
				return fmt.Errorf("update 2 failed: %d %w", idx, err)
			}
			return nil
		}

		// checkRow reads row idx back and compares every column with what writeRow stored.
		checkRow := func(idx int) error {
			ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
			defer cancel()

			var (
				id       int
				str      string
				duration pgtype.Interval
			)
			err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration)
			if err != nil {
				return fmt.Errorf("select failed: %d %w", idx, err)
			}
			if id != idx {
				return fmt.Errorf("id mismatch: %d %d", idx, id)
			}
			if str != strconv.Itoa(idx) {
				return fmt.Errorf("str mismatch: %d %s", idx, str)
			}
			expectedDuration := pgtype.Interval{
				Microseconds: int64(idx) * time.Second.Microseconds(),
				Valid:        true,
			}
			if duration != expectedDuration {
				return fmt.Errorf("duration mismatch: %d %v", idx, duration)
			}
			return nil
		}

		// The buffer holds one result per goroutine, so no send ever blocks.
		errChan := make(chan error, concurrency)

		// runAll runs work for ids 1..concurrency in parallel, waits for all of them, and then
		// requires that each one succeeded.
		runAll := func(work func(idx int) error) {
			var wg sync.WaitGroup
			for i := 1; i <= concurrency; i++ {
				wg.Add(1)
				go func(idx int) {
					defer wg.Done()
					errChan <- work(idx)
				}(i)
			}
			wg.Wait()

			for i := 1; i <= concurrency; i++ {
				require.NoError(t, <-errChan)
			}
		}

		runAll(writeRow)
		runAll(checkRow)
	})
}
// TestConnQueryDifferentScanPlansIssue781 scans a single row whose columns have different types
// (bool and text) in every query exec mode.
//
// https://github.com/jackc/pgx/issues/781
func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		var s string
		var b bool

		rows, err := db.Query("select true, 'foo'")
		require.NoError(t, err)
		// Only one row is read, so Next never returns false and rows never closes itself.
		// Close explicitly so the connection is returned to the pool.
		defer rows.Close()

		require.True(t, rows.Next())
		require.NoError(t, rows.Scan(&b, &s))
		assert.Equal(t, true, b)
		assert.Equal(t, "foo", s)
	})
}
// TestConnQueryNull passes a nil argument and checks that it comes back as a single NULL row,
// which scans into an invalid sql.NullInt64.
func TestConnQueryNull(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		rows, err := db.Query("select $1::int", nil)
		require.NoError(t, err)

		var rowCount int64
		for rows.Next() {
			rowCount++
			var n sql.NullInt64
			require.NoError(t, rows.Scan(&n))
			if n.Valid {
				t.Errorf("Expected n to be null, but it was %v", n)
			}
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 1, rowCount)
		require.NoError(t, rows.Close())
	})
}
// TestConnQueryRowByteSlice scans a bytea literal into a []byte.
func TestConnQueryRowByteSlice(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		var got []byte
		err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&got)
		require.NoError(t, err)

		want := []byte{0xde, 0xad, 0xbe, 0xef}
		require.EqualValues(t, want, got)
	})
}
// TestConnQueryFailure sends SQL with an unterminated string literal and expects the query to
// fail with a *pgconn.PgError.
func TestConnQueryFailure(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, queryErr := db.Query("select 'foo")
		require.Error(t, queryErr)
		require.IsType(t, new(pgconn.PgError), queryErr)
	})
}
// TestConnSimpleSlicePassThrough passes a Go []string as a text[] argument and checks that the
// server sees all three elements.
func TestConnSimpleSlicePassThrough(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support cardinality function")

		elems := []string{"a", "b", "c"}
		var cardinality int64
		err := db.QueryRow("select cardinality($1::text[])", elems).Scan(&cardinality)
		require.NoError(t, err)
		assert.EqualValues(t, 3, cardinality)
	})
}
// TestConnQueryScanGoArray scans a bigint[] into a plain Go []int64 using a pgtype.Map's
// SQLScanner adapter.
func TestConnQueryScanGoArray(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		typeMap := pgtype.NewMap()

		var got []int64
		require.NoError(t, db.QueryRow("select '{1,2,3}'::bigint[]").Scan(typeMap.SQLScanner(&got)))
		assert.Equal(t, []int64{1, 2, 3}, got)
	})
}
// TestConnQueryScanArray scans a bigint[] into a pgtype.Array[int64]. It then scans a NULL
// array into the same variable and expects the invalid zero value.
func TestConnQueryScanArray(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		typeMap := pgtype.NewMap()
		var a pgtype.Array[int64]

		require.NoError(t, db.QueryRow("select '{1,2,3}'::bigint[]").Scan(typeMap.SQLScanner(&a)))
		want := pgtype.Array[int64]{
			Elements: []int64{1, 2, 3},
			Dims:     []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}},
			Valid:    true,
		}
		assert.Equal(t, want, a)

		require.NoError(t, db.QueryRow("select null::bigint[]").Scan(typeMap.SQLScanner(&a)))
		assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
	})
}
// TestConnQueryScanRange scans an int4range into a pgtype.Range[pgtype.Int4].
func TestConnQueryScanRange(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support int4range")

		typeMap := pgtype.NewMap()
		var got pgtype.Range[pgtype.Int4]
		err := db.QueryRow("select int4range(1, 5)").Scan(typeMap.SQLScanner(&got))
		require.NoError(t, err)

		// Expected bounds: lower 1 inclusive, upper 5 exclusive.
		want := pgtype.Range[pgtype.Int4]{
			Lower:     pgtype.Int4{Int32: 1, Valid: true},
			Upper:     pgtype.Int4{Int32: 5, Valid: true},
			LowerType: pgtype.Inclusive,
			UpperType: pgtype.Exclusive,
			Valid:     true,
		}
		assert.Equal(t, want, got)
	})
}
// TestConnQueryRowPgxBinary uses a type (int4[]) that pgx would handle natively in binary. Since
// it is not a database/sql native type, it should be passed through as a string.
func TestConnQueryRowPgxBinary(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		const want = "{1,2,3}"

		var got string
		require.NoError(t, db.QueryRow("select $1::int4[]", want).Scan(&got))
		require.EqualValues(t, want, got)
	})
}
// TestConnQueryRowUnknownType sends a point value as text and checks that it scans back into a
// string unchanged.
func TestConnQueryRowUnknownType(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support point type")

		const want = "(1,2)"

		var got string
		require.NoError(t, db.QueryRow("select $1::point", want).Scan(&got))
		require.EqualValues(t, want, got)
	})
}
// TestConnQueryJSONIntoByteSlice stores a json document and checks that scanning it into a
// []byte yields the exact original text.
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		// Create and populate the table with a single multi-statement Exec.
		_, err := db.Exec(`
create temporary table docs(
body json not null
);
insert into docs(body) values('{"foo": "bar"}');
`)
		require.NoError(t, err)

		query := `select * from docs`
		want := []byte(`{"foo": "bar"}`)

		var got []byte
		if err := db.QueryRow(query).Scan(&got); err != nil {
			t.Errorf("Unexpected failure: %v (sql -> %v)", err, query)
		}
		if !bytes.Equal(got, want) {
			t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(want), string(got), query)
		}

		_, err = db.Exec(`drop table docs`)
		require.NoError(t, err)
	})
}
// TestConnExecInsertByteSliceIntoJSON inserts a []byte argument into a json column and reads it
// back unchanged.
//
// It deliberately does not use testWithAllQueryExecModes. With the simple protocol there is no
// way for this to work: a []byte is treated as binary data that must be escaped, and there is no
// way to know whether the destination is really text-compatible or bytea.
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec(`
create temporary table docs(
body json not null
);
`)
	require.NoError(t, err)

	want := []byte(`{"foo": "bar"}`)
	_, err = db.Exec(`insert into docs(body) values($1)`, want)
	require.NoError(t, err)

	var got []byte
	require.NoError(t, db.QueryRow(`select body from docs`).Scan(&got))
	if !bytes.Equal(got, want) {
		t.Errorf(`Expected "%v", got "%v"`, string(want), string(got))
	}

	_, err = db.Exec(`drop table docs`)
	require.NoError(t, err)
}
// TestTransactionLifeCycle inserts a row in a transaction that is rolled back (the row must not
// persist), then in one that is committed (the row must persist).
func TestTransactionLifeCycle(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create temporary table t(a varchar not null)")
		require.NoError(t, err)

		// beginAndInsert starts a transaction and inserts one row in it.
		beginAndInsert := func() *sql.Tx {
			tx, err := db.Begin()
			require.NoError(t, err)
			_, err = tx.Exec("insert into t values('hi')")
			require.NoError(t, err)
			return tx
		}

		// countRows returns the number of rows visible in t outside any transaction.
		countRows := func() int64 {
			var n int64
			require.NoError(t, db.QueryRow("select count(*) from t").Scan(&n))
			return n
		}

		require.NoError(t, beginAndInsert().Rollback())
		require.EqualValues(t, 0, countRows())

		require.NoError(t, beginAndInsert().Commit())
		require.EqualValues(t, 1, countRows())
	})
}
// TestConnBeginTxIsolation checks that every supported sql.IsolationLevel produces the expected
// PostgreSQL transaction_isolation, and that the unsupported levels make BeginTx fail.
func TestConnBeginTxIsolation(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server always uses serializable isolation level")

		var defaultIsoLevel string
		err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
		require.NoError(t, err)

		// checkIsolation begins a transaction at sqlIso and verifies the server reports wantPgIso.
		// The transaction is always rolled back.
		checkIsolation := func(i int, sqlIso sql.IsolationLevel, wantPgIso string) {
			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sqlIso})
			if err != nil {
				t.Errorf("%d. BeginTx failed: %v", i, err)
				return
			}
			defer tx.Rollback()

			var pgIso string
			if err := tx.QueryRow("show transaction_isolation").Scan(&pgIso); err != nil {
				t.Errorf("%d. QueryRow failed: %v", i, err)
			}
			if pgIso != wantPgIso {
				t.Errorf("%d. pgIso => %s, want %s", i, pgIso, wantPgIso)
			}
		}

		supportedTests := []struct {
			sqlIso sql.IsolationLevel
			pgIso  string
		}{
			{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
			{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
			{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
			{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
			{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
			{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
		}
		for i, tt := range supportedTests {
			checkIsolation(i, tt.sqlIso, tt.pgIso)
		}

		// BeginTx is expected to reject these isolation levels.
		unsupported := []sql.IsolationLevel{sql.LevelWriteCommitted, sql.LevelLinearizable}
		for i, sqlIso := range unsupported {
			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sqlIso})
			if err == nil {
				t.Errorf("%d. BeginTx should have failed", i)
				tx.Rollback()
			}
		}
	})
}
// TestConnBeginTxReadOnly checks that TxOptions.ReadOnly starts a read-only transaction.
func TestConnBeginTxReadOnly(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
		require.NoError(t, err)
		defer tx.Rollback()

		var pgReadOnly string
		if err := tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly); err != nil {
			t.Errorf("QueryRow failed: %v", err)
		}

		const want = "on"
		if pgReadOnly != want {
			t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, want)
		}
	})
}
// TestBeginTxContextCancel cancels the context passed to BeginTx before committing. Commit must
// then fail, and the table created inside the transaction must not exist afterwards.
func TestBeginTxContextCancel(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("drop table if exists t")
		require.NoError(t, err)

		ctx, cancelFn := context.WithCancel(context.Background())
		// Release the context on every path, including an early require failure. Calling cancel
		// more than once is harmless.
		defer cancelFn()

		tx, err := db.BeginTx(ctx, nil)
		require.NoError(t, err)

		_, err = tx.Exec("create table t(id serial)")
		require.NoError(t, err)

		cancelFn()

		err = tx.Commit()
		// Either error is accepted. Which one is returned presumably depends on whether
		// database/sql has already rolled the transaction back in response to the cancellation.
		// errors.Is also tolerates wrapped errors.
		if !errors.Is(err, context.Canceled) && !errors.Is(err, sql.ErrTxDone) {
			t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
		}

		var n int
		err = db.QueryRow("select count(*) from t").Scan(&n)
		// 42P01 is PostgreSQL's undefined_table: the create table must not have been committed.
		var pgErr *pgconn.PgError
		if !errors.As(err, &pgErr) || pgErr.Code != "42P01" {
			t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
		}
	})
}
// TestConnRaw checks that sql.Conn.Raw exposes the underlying *stdlib.Conn, and that its pgx
// connection can run a query.
func TestConnRaw(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		conn, err := db.Conn(context.Background())
		require.NoError(t, err)
		// A *sql.Conn holds its connection until it is closed. Close it so the connection goes
		// back to the pool instead of leaking.
		defer conn.Close()

		var n int
		err = conn.Raw(func(driverConn any) error {
			pgxConn := driverConn.(*stdlib.Conn).Conn()
			return pgxConn.QueryRow(context.Background(), "select 42").Scan(&n)
		})
		require.NoError(t, err)
		assert.EqualValues(t, 42, n)
	})
}
// TestConnPingContextSuccess checks that PingContext succeeds against a healthy server.
func TestConnPingContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		require.NoError(t, db.PingContext(context.Background()))
	})
}
// TestConnPrepareContextSuccess prepares a statement with PrepareContext and closes it.
func TestConnPrepareContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		stmt, prepErr := db.PrepareContext(context.Background(), "select now()")
		require.NoError(t, prepErr)
		require.NoError(t, stmt.Close())
	})
}
// TestConnMultiplePrepareAndDeallocate prepares the same SQL twice, then closes the statements
// one at a time. After the first Close, exactly one matching row must remain in
// pg_prepared_statements; after the second, none. This guards against closing the shared
// underlying prepared statement while another *sql.Stmt still uses it.
//
// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281
// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
func TestConnMultiplePrepareAndDeallocate(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		query := "select 42"

		// countPrepared returns how many server-side prepared statements have query as their text.
		countPrepared := func() int64 {
			var count int64
			err := db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", query).Scan(&count)
			require.NoError(t, err)
			return count
		}

		stmt1, err := db.PrepareContext(context.Background(), query)
		require.NoError(t, err)
		stmt2, err := db.PrepareContext(context.Background(), query)
		require.NoError(t, err)

		require.NoError(t, stmt1.Close())
		require.EqualValues(t, 1, countPrepared())

		// The error is less useful than it should be, because database/sql ignores errors from
		// Deallocate.
		require.NoError(t, stmt2.Close())
		require.EqualValues(t, 0, countPrepared())
	})
}
// TestConnExecContextSuccess runs DDL through ExecContext.
func TestConnExecContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		ctx := context.Background()
		_, execErr := db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)")
		require.NoError(t, execErr)
	})
}
// TestConnQueryContextSuccess runs a query through QueryContext and scans every row.
func TestConnQueryContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		ctx := context.Background()
		rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n")
		require.NoError(t, err)

		// Iterating until Next returns false closes rows automatically.
		for rows.Next() {
			var n int64
			require.NoError(t, rows.Scan(&n))
		}
		require.NoError(t, rows.Err())
	})
}
// TestRowsColumnTypeDatabaseTypeName checks that a bigint column reports "INT8" as its
// database type name.
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		rows, err := db.Query("select 42::bigint")
		require.NoError(t, err)

		columnTypes, err := rows.ColumnTypes()
		require.NoError(t, err)
		require.Len(t, columnTypes, 1)

		if got := columnTypes[0].DatabaseTypeName(); got != "INT8" {
			t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", got, "INT8")
		}

		require.NoError(t, rows.Close())
	})
}
// TestStmtExecContextSuccess executes a prepared insert through ExecContext.
func TestStmtExecContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(id int primary key)")
	require.NoError(t, err)

	insertStmt, err := db.Prepare("insert into t(id) values ($1::int4)")
	require.NoError(t, err)
	defer insertStmt.Close()

	_, execErr := insertStmt.ExecContext(context.Background(), 42)
	require.NoError(t, execErr)

	ensureDBValid(t, db)
}
// TestStmtExecContextCancel runs a prepared statement that sleeps for 5 seconds under a 100ms
// context timeout. It expects a timeout error, and the DB must remain usable afterwards.
func TestStmtExecContextCancel(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(id int primary key)")
	require.NoError(t, err)

	slowInsert, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
	require.NoError(t, err)
	defer slowInsert.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if _, err := slowInsert.ExecContext(ctx, 42); !pgconn.Timeout(err) {
		t.Errorf("expected timeout error, got %v", err)
	}

	ensureDBValid(t, db)
}
// TestStmtQueryContextSuccess runs a prepared query through QueryContext and scans every row.
func TestStmtQueryContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	seriesStmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
	require.NoError(t, err)
	defer seriesStmt.Close()

	rows, err := seriesStmt.QueryContext(context.Background(), 5)
	require.NoError(t, err)

	for rows.Next() {
		var n int64
		if scanErr := rows.Scan(&n); scanErr != nil {
			t.Error(scanErr)
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		t.Error(iterErr)
	}

	ensureDBValid(t, db)
}
// TestRowsColumnTypes checks the column metadata reported by sql.Rows.ColumnTypes in every
// query exec mode. It covers name, database type name, length, decimal size and scan type for a
// bigint, a text, a numeric(9, 2) and a timetz column.
func TestRowsColumnTypes(t *testing.T) {
	// lengthInfo is the expected result of ColumnType.Length.
	type lengthInfo struct {
		Len int64
		OK  bool
	}
	// decimalSizeInfo is the expected result of ColumnType.DecimalSize.
	type decimalSizeInfo struct {
		Precision int64
		Scale     int64
		OK        bool
	}

	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		// A zero-valued Length or DecimalSize means the driver must report (0, false) /
		// (0, 0, false) for that column.
		columnTypesTests := []struct {
			Name        string
			TypeName    string
			Length      lengthInfo
			DecimalSize decimalSizeInfo
			ScanType    reflect.Type
		}{
			{
				Name:     "a",
				TypeName: "INT8",
				ScanType: reflect.TypeOf(int64(0)),
			},
			{
				Name:     "bar",
				TypeName: "TEXT",
				Length:   lengthInfo{Len: math.MaxInt64, OK: true},
				ScanType: reflect.TypeOf(""),
			},
			{
				Name:        "dec",
				TypeName:    "NUMERIC",
				DecimalSize: decimalSizeInfo{Precision: 9, Scale: 2, OK: true},
				ScanType:    reflect.TypeOf(float64(0)),
			},
			{
				// timetz is reported by its OID (1266). Presumably it has no registered type name
				// in the driver's type map — TODO confirm.
				Name:     "d",
				TypeName: "1266",
				ScanType: reflect.TypeOf(""),
			},
		}

		rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
		require.NoError(t, err)
		// The rows are never iterated, so they would never close themselves. Close them so the
		// connection is returned to the pool.
		defer rows.Close()

		columns, err := rows.ColumnTypes()
		require.NoError(t, err)
		// require, not assert: the loop below indexes columns and would panic if it were short.
		require.Len(t, columns, len(columnTypesTests))

		for i, tt := range columnTypesTests {
			c := columns[i]
			if c.Name() != tt.Name {
				t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
			}
			if c.DatabaseTypeName() != tt.TypeName {
				t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
			}

			l, ok := c.Length()
			if l != tt.Length.Len {
				t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
			}
			if ok != tt.Length.OK {
				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
			}

			p, s, ok := c.DecimalSize()
			if p != tt.DecimalSize.Precision {
				t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
			}
			if s != tt.DecimalSize.Scale {
				t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
			}
			if ok != tt.DecimalSize.OK {
				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
			}

			if c.ScanType() != tt.ScanType {
				t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
			}
		}
	})
}
// TestQueryLifeCycle runs a three-parameter query and verifies all 10 rows. It then runs a query
// that returns no rows and checks that iterating it yields nothing.
func TestQueryLifeCycle(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
		require.NoError(t, err)

		rowCount := int64(0)
		for rows.Next() {
			rowCount++
			var (
				s string
				n int64
			)
			err := rows.Scan(&s, &n)
			require.NoError(t, err)
			if s != "foo" {
				t.Errorf(`Expected "foo", received "%v"`, s)
			}
			if n != rowCount {
				t.Errorf("Expected %d, received %d", rowCount, n)
			}
		}
		require.NoError(t, rows.Err())
		// Without this, a query that returned too few rows (or none) would pass silently,
		// because the per-row checks above would simply never run.
		require.EqualValues(t, 10, rowCount)
		err = rows.Close()
		require.NoError(t, err)

		rows, err = db.Query("select 1 where false")
		require.NoError(t, err)

		rowCount = int64(0)
		for rows.Next() {
			rowCount++
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 0, rowCount)
		err = rows.Close()
		require.NoError(t, err)
	})
}
// TestScanJSONIntoJSONRawMessage checks that a json value can be scanned into a json.RawMessage
// without alteration.
//
// https://github.com/jackc/pgx/issues/409
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		var raw json.RawMessage
		require.NoError(t, db.QueryRow("select '{}'::json").Scan(&raw))
		require.EqualValues(t, []byte("{}"), []byte(raw))
	})
}
type (
	// testLog is a single entry captured by testLogger.
	testLog struct {
		lvl  tracelog.LogLevel
		msg  string
		data map[string]any
	}

	// testLogger is a tracelog logger that appends every Log call to logs, in call order, so
	// tests can inspect what was traced. It performs no locking.
	testLogger struct {
		logs []testLog
	}
)

// Log records lvl, msg and data as a new entry; ctx is not used.
func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
	entry := testLog{
		lvl:  lvl,
		msg:  msg,
		data: data,
	}
	l.logs = append(l.logs, entry)
}
// TestRegisterConnConfig checks that a ConnConfig registered with stdlib.RegisterConnConfig
// can be opened via sql.Open using the returned name, and that the registered config (including
// its tracer) is the one actually used.
func TestRegisterConnConfig(t *testing.T) {
	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	logger := &testLogger{}
	connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}

	// Issue 947: Register and unregister a ConnConfig and ensure that the
	// returned connection string is not reused.
	connStr := stdlib.RegisterConnConfig(connConfig)
	require.Equal(t, "registeredConnConfig0", connStr)
	stdlib.UnregisterConnConfig(connStr)

	connStr = stdlib.RegisterConnConfig(connConfig)
	defer stdlib.UnregisterConnConfig(connStr)
	require.Equal(t, "registeredConnConfig1", connStr)

	db, err := sql.Open("pgx", connStr)
	require.NoError(t, err)
	defer closeDB(t, db)

	var n int64
	err = db.QueryRow("select 1").Scan(&n)
	require.NoError(t, err)

	// Fail with a clear message instead of an index-out-of-range panic if the tracer recorded
	// nothing (e.g. the registered config was not used).
	require.NotEmpty(t, logger.logs)
	l := logger.logs[len(logger.logs)-1]
	assert.Equal(t, "Query", l.msg)
	assert.Equal(t, "select 1", l.data["sql"])
}
// TestConnQueryRowConstraintErrors checks that an error raised by a deferred constraint trigger
// surfaces from QueryRow(...).Scan. The trigger raises when n = 4. Because it is
// "deferrable initially deferred", it fires only at the end of the (implicit) transaction, not
// during the insert itself.
//
// https://github.com/jackc/pgx/issues/958
func TestConnQueryRowConstraintErrors(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		// Presumably required for the "execute function" trigger syntax used below — TODO confirm.
		skipPostgreSQLVersionLessThan(t, db, 11)
		skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")

		// Setup statements, executed in order; the first failure stops the test.
		setup := []string{
			`create temporary table defer_test (
id text primary key,
n int not null, unique (n),
unique (n) deferrable initially deferred )`,
			`drop function if exists test_trigger cascade`,
			`create function test_trigger() returns trigger language plpgsql as $$
begin
if new.n = 4 then
raise exception 'n cant be 4!';
end if;
return new;
end$$`,
			`create constraint trigger test
after insert or update on defer_test
deferrable initially deferred
for each row
execute function test_trigger()`,
			`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`,
		}
		for _, stmt := range setup {
			_, err := db.Exec(stmt)
			require.NoError(t, err)
		}

		var id string
		err := db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
		assert.Error(t, err)
	})
}
// TestOptionBeforeAfterConnect checks that the BeforeConnect and AfterConnect hooks run once
// per newly opened connection. It also checks that BeforeConnect receives a distinct ConnConfig
// each time, never the caller's original.
func TestOptionBeforeAfterConnect(t *testing.T) {
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var (
		beforeConnConfigs []*pgx.ConnConfig
		afterConns        []*pgx.Conn
	)
	recordBefore := stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
		beforeConnConfigs = append(beforeConnConfigs, connConfig)
		return nil
	})
	recordAfter := stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
		afterConns = append(afterConns, conn)
		return nil
	})

	db := stdlib.OpenDB(*config, recordBefore, recordAfter)
	defer closeDB(t, db)

	// Force it to close and reopen a new connection after each query.
	db.SetMaxIdleConns(0)

	for i := 0; i < 2; i++ {
		_, err = db.Exec("select 1")
		require.NoError(t, err)
	}

	require.Len(t, beforeConnConfigs, 2)
	require.Len(t, afterConns, 2)

	// Note: BeforeConnect receives a shallow copy, so the config contents will be the same. We
	// want to ensure they are different objects, so require.NotEqual can't be used.
	require.False(t, config == beforeConnConfigs[0])
	require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
}
// TestRandomizeHostOrderFunc checks two things about RandomizeHostOrderFunc on a three-host
// config. First, every host eventually becomes the primary. Second, each randomized config
// still contains all three hosts, either as primary or as fallbacks.
func TestRandomizeHostOrderFunc(t *testing.T) {
	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
	require.NoError(t, err)

	allHosts := []string{"host1", "host2", "host3"}

	// Test that at some point we connect to all 3 hosts.
	hostsNotSeenYet := map[string]struct{}{}
	for _, h := range allHosts {
		hostsNotSeenYet[h] = struct{}{}
	}

	// containsHost reports whether h is the primary host of cfg or one of its fallbacks.
	containsHost := func(cfg *pgx.ConnConfig, h string) bool {
		if cfg.Host == h {
			return true
		}
		for _, f := range cfg.Fallbacks {
			if f.Host == h {
				return true
			}
		}
		return false
	}

	// If we don't succeed within this many iterations, something is certainly wrong.
	for i := 0; i < 100000; i++ {
		connCopy := *config
		stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)

		delete(hostsNotSeenYet, connCopy.Host)
		if len(hostsNotSeenYet) == 0 {
			return
		}

		for _, h := range allHosts {
			if !containsHost(&connCopy, h) {
				require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
			}
		}
	}

	require.Fail(t, "did not get all hosts as primaries after many randomizations")
}
// TestResetSessionHookCalled checks that the function passed to stdlib.OptionResetSession
// is invoked at some point across two consecutive pings.
func TestResetSessionHookCalled(t *testing.T) {
	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	hookCalled := false
	resetSession := stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
		hookCalled = true
		return nil
	})

	db := stdlib.OpenDB(*connConfig, resetSession)
	defer closeDB(t, db)

	// Two pings: presumably the second one reuses the pooled connection from the first, which is
	// when database/sql asks the driver to reset the session.
	for i := 0; i < 2; i++ {
		require.NoError(t, db.Ping())
	}

	require.True(t, hookCalled)
}
// TestCheckIdleConn terminates the server backends behind idle pooled connections. It then
// checks that the pool notices the dead connections, discards them, and serves new work from
// a freshly opened connection.
func TestCheckIdleConn(t *testing.T) {
	controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeDB(t, controllerConn)

	skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")

	db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeDB(t, db)

	// pidOf returns the server process ID of the backend behind c.
	pidOf := func(c *sql.Conn) uint32 {
		var pid uint32
		err := c.Raw(func(driverConn any) error {
			pid = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
			return nil
		})
		require.NoError(t, err)
		return pid
	}

	var conns []*sql.Conn
	for i := 0; i < 3; i++ {
		c, err := db.Conn(context.Background())
		require.NoError(t, err)
		conns = append(conns, c)
	}
	require.EqualValues(t, 3, db.Stats().OpenConnections)

	var pids []uint32
	for _, c := range conns {
		pids = append(pids, pidOf(c))
		require.NoError(t, c.Close())
	}

	// The database/sql connection pool seems to automatically close idle connections so that only
	// 2 stay alive, so the count is not asserted here.
	// require.EqualValues(t, 3, db.Stats().OpenConnections)

	_, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids)
	require.NoError(t, err)

	// All conns are now dead, but neither they nor the pool know it yet. Because database/sql
	// automatically closes idle connections, we can't be sure how many there should be:
	// require.EqualValues(t, 3, db.Stats().OpenConnections)

	// Wait long enough so the pool will realize it needs to check the connections.
	time.Sleep(time.Second)

	// The pool should try all existing connections and find them dead, then create a new
	// connection, which should ping successfully.
	err = db.PingContext(context.Background())
	require.NoError(t, err)

	// The original 3 conns should have been terminated and a new conn established for the ping.
	require.EqualValues(t, 1, db.Stats().OpenConnections)

	c, err := db.Conn(context.Background())
	require.NoError(t, err)
	cPID := pidOf(c)
	require.NoError(t, c.Close())

	require.NotContains(t, pids, cPID)
}