7a2b93323c
This error was introduced by 0f0d236599.
If the same statement was prepared multiple times then whenever Close
was called on one of the statements the underlying prepared statement
would be closed even if other statements were still using it.
https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
1375 lines
35 KiB
Go
1375 lines
35 KiB
Go
package stdlib_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"os"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/jackc/pgx/v5/stdlib"
|
|
"github.com/jackc/pgx/v5/tracelog"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func openDB(t testing.TB) *sql.DB {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
return stdlib.OpenDB(*config)
|
|
}
|
|
|
|
func closeDB(t testing.TB, db *sql.DB) {
|
|
err := db.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
err = conn.Raw(func(driverConn any) error {
|
|
conn := driverConn.(*stdlib.Conn).Conn()
|
|
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
|
t.Skip(msg)
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) {
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
err = conn.Raw(func(driverConn any) error {
|
|
conn := driverConn.(*stdlib.Conn).Conn()
|
|
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
|
|
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
|
|
// if not PostgreSQL do nothing
|
|
if serverVersionStr == "" {
|
|
return nil
|
|
}
|
|
|
|
serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if serverVersion < minVersion {
|
|
t.Skipf("Test requires PostgreSQL v%d+", minVersion)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
|
|
for _, mode := range []pgx.QueryExecMode{
|
|
pgx.QueryExecModeCacheStatement,
|
|
pgx.QueryExecModeCacheDescribe,
|
|
pgx.QueryExecModeDescribeExec,
|
|
pgx.QueryExecModeExec,
|
|
pgx.QueryExecModeSimpleProtocol,
|
|
} {
|
|
t.Run(mode.String(),
|
|
func(t *testing.T) {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.DefaultQueryExecMode = mode
|
|
db := stdlib.OpenDB(*config)
|
|
defer func() {
|
|
err := db.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
f(t, db)
|
|
|
|
ensureDBValid(t, db)
|
|
},
|
|
)
|
|
}
|
|
}
|
|
|
|
// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
|
|
// cover broken connections.
|
|
func ensureDBValid(t testing.TB, db *sql.DB) {
|
|
var sum, rowCount int32
|
|
|
|
rows, err := db.Query("select generate_series(1,$1)", 10)
|
|
require.NoError(t, err)
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var n int32
|
|
rows.Scan(&n)
|
|
sum += n
|
|
rowCount++
|
|
}
|
|
|
|
require.NoError(t, rows.Err())
|
|
|
|
if rowCount != 10 {
|
|
t.Error("Select called onDataRow wrong number of times")
|
|
}
|
|
if sum != 55 {
|
|
t.Error("Wrong values returned")
|
|
}
|
|
}
|
|
|
|
// preparer is anything that can prepare a statement from a query string. The
// tests in this file pass both *sql.DB and *sql.Tx values as a preparer.
type preparer interface {
	// Prepare creates a prepared statement for query.
	Prepare(query string) (*sql.Stmt, error)
}
|
|
|
|
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
|
|
stmt, err := p.Prepare(sql)
|
|
require.NoError(t, err)
|
|
return stmt
|
|
}
|
|
|
|
func closeStmt(t *testing.T, stmt *sql.Stmt) {
|
|
err := stmt.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestSQLOpen(t *testing.T) {
|
|
tests := []struct {
|
|
driverName string
|
|
}{
|
|
{driverName: "pgx"},
|
|
{driverName: "pgx/v5"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.driverName, func(t *testing.T) {
|
|
db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
closeDB(t, db)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSQLOpenFromPool(t *testing.T) {
|
|
pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
t.Cleanup(pool.Close)
|
|
|
|
db := stdlib.OpenDBFromPool(pool)
|
|
ensureDBValid(t, db)
|
|
|
|
db.Close()
|
|
}
|
|
|
|
func TestNormalLifeCycle(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
|
defer closeStmt(t, stmt)
|
|
|
|
rows, err := stmt.Query(int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
|
|
require.EqualValues(t, 10, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestStmtExec(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
tx, err := db.Begin()
|
|
require.NoError(t, err)
|
|
|
|
createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
|
|
_, err = createStmt.Exec()
|
|
require.NoError(t, err)
|
|
closeStmt(t, createStmt)
|
|
|
|
insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
|
|
result, err := insertStmt.Exec("foo")
|
|
require.NoError(t, err)
|
|
|
|
n, err := result.RowsAffected()
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, n)
|
|
closeStmt(t, insertStmt)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestQueryCloseRowsEarly(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
|
defer closeStmt(t, stmt)
|
|
|
|
rows, err := stmt.Query(int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
// Close rows immediately without having read them
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
// Run the query again to ensure the connection and statement are still ok
|
|
rows, err = stmt.Query(int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 10, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestConnExec(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("create temporary table t(a varchar not null)")
|
|
require.NoError(t, err)
|
|
|
|
result, err := db.Exec("insert into t values('hey')")
|
|
require.NoError(t, err)
|
|
|
|
n, err := result.RowsAffected()
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, n)
|
|
})
|
|
}
|
|
|
|
func TestConnQuery(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 10, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnConcurrency(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)")
|
|
require.NoError(t, err)
|
|
|
|
defer func() {
|
|
_, err := db.Exec("drop table t")
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
concurrency := 50
|
|
errChan := make(chan error, concurrency)
|
|
|
|
for i := 1; i <= concurrency; i++ {
|
|
wg.Add(1)
|
|
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
str := strconv.Itoa(idx)
|
|
duration := time.Duration(idx) * time.Second
|
|
_, err := db.ExecContext(ctx, "insert into t values($1)", idx)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("insert failed: %d %w", idx, err)
|
|
return
|
|
}
|
|
_, err = db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("update 1 failed: %d %w", idx, err)
|
|
return
|
|
}
|
|
_, err = db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("update 2 failed: %d %w", idx, err)
|
|
return
|
|
}
|
|
|
|
errChan <- nil
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
for i := 1; i <= concurrency; i++ {
|
|
err := <-errChan
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
for i := 1; i <= concurrency; i++ {
|
|
wg.Add(1)
|
|
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
var id int
|
|
var str string
|
|
var duration pgtype.Interval
|
|
err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("select failed: %d %w", idx, err)
|
|
return
|
|
}
|
|
if id != idx {
|
|
errChan <- fmt.Errorf("id mismatch: %d %d", idx, id)
|
|
return
|
|
}
|
|
if str != strconv.Itoa(idx) {
|
|
errChan <- fmt.Errorf("str mismatch: %d %s", idx, str)
|
|
return
|
|
}
|
|
expectedDuration := pgtype.Interval{
|
|
Microseconds: int64(idx) * time.Second.Microseconds(),
|
|
Valid: true,
|
|
}
|
|
if duration != expectedDuration {
|
|
errChan <- fmt.Errorf("duration mismatch: %d %v", idx, duration)
|
|
return
|
|
}
|
|
|
|
errChan <- nil
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
for i := 1; i <= concurrency; i++ {
|
|
err := <-errChan
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/781
|
|
func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
var s string
|
|
var b bool
|
|
|
|
rows, err := db.Query("select true, 'foo'")
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, rows.Next())
|
|
require.NoError(t, rows.Scan(&b, &s))
|
|
assert.Equal(t, true, b)
|
|
assert.Equal(t, "foo", s)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryNull(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
rows, err := db.Query("select $1::int", nil)
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var n sql.NullInt64
|
|
err := rows.Scan(&n)
|
|
require.NoError(t, err)
|
|
if n.Valid != false {
|
|
t.Errorf("Expected n to be null, but it was %v", n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 1, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryRowByteSlice(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
expected := []byte{222, 173, 190, 239}
|
|
var actual []byte
|
|
|
|
err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryFailure(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Query("select 'foo")
|
|
require.Error(t, err)
|
|
require.IsType(t, new(pgconn.PgError), err)
|
|
})
|
|
}
|
|
|
|
func TestConnSimpleSlicePassThrough(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server does not support cardinality function")
|
|
|
|
var n int64
|
|
err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
|
|
require.NoError(t, err)
|
|
assert.EqualValues(t, 3, n)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryScanGoArray(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
m := pgtype.NewMap()
|
|
|
|
var a []int64
|
|
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, []int64{1, 2, 3}, a)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryScanArray(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
m := pgtype.NewMap()
|
|
|
|
var a pgtype.Array[int64]
|
|
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
|
|
|
|
err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryScanRange(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server does not support int4range")
|
|
|
|
m := pgtype.NewMap()
|
|
|
|
var r pgtype.Range[pgtype.Int4]
|
|
err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
|
|
require.NoError(t, err)
|
|
assert.Equal(
|
|
t,
|
|
pgtype.Range[pgtype.Int4]{
|
|
Lower: pgtype.Int4{Int32: 1, Valid: true},
|
|
Upper: pgtype.Int4{Int32: 5, Valid: true},
|
|
LowerType: pgtype.Inclusive,
|
|
UpperType: pgtype.Exclusive,
|
|
Valid: true,
|
|
},
|
|
r)
|
|
})
|
|
}
|
|
|
|
// Test type that pgx would handle natively in binary, but since it is not a
|
|
// database/sql native type should be passed through as a string
|
|
func TestConnQueryRowPgxBinary(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
sql := "select $1::int4[]"
|
|
expected := "{1,2,3}"
|
|
var actual string
|
|
|
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryRowUnknownType(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server does not support point type")
|
|
|
|
sql := "select $1::point"
|
|
expected := "(1,2)"
|
|
var actual string
|
|
|
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec(`
|
|
create temporary table docs(
|
|
body json not null
|
|
);
|
|
|
|
insert into docs(body) values('{"foo": "bar"}');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
sql := `select * from docs`
|
|
expected := []byte(`{"foo": "bar"}`)
|
|
var actual []byte
|
|
|
|
err = db.QueryRow(sql).Scan(&actual)
|
|
if err != nil {
|
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
|
}
|
|
|
|
if !bytes.Equal(actual, expected) {
|
|
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
|
|
}
|
|
|
|
_, err = db.Exec(`drop table docs`)
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
|
|
// Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data
|
|
// that needs to escape. No way to know whether the destination is really a text compatible or a bytea.
|
|
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec(`
|
|
create temporary table docs(
|
|
body json not null
|
|
);
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
expected := []byte(`{"foo": "bar"}`)
|
|
|
|
_, err = db.Exec(`insert into docs(body) values($1)`, expected)
|
|
require.NoError(t, err)
|
|
|
|
var actual []byte
|
|
err = db.QueryRow(`select body from docs`).Scan(&actual)
|
|
require.NoError(t, err)
|
|
|
|
if !bytes.Equal(actual, expected) {
|
|
t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
|
|
}
|
|
|
|
_, err = db.Exec(`drop table docs`)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestTransactionLifeCycle(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("create temporary table t(a varchar not null)")
|
|
require.NoError(t, err)
|
|
|
|
tx, err := db.Begin()
|
|
require.NoError(t, err)
|
|
|
|
_, err = tx.Exec("insert into t values('hi')")
|
|
require.NoError(t, err)
|
|
|
|
err = tx.Rollback()
|
|
require.NoError(t, err)
|
|
|
|
var n int64
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 0, n)
|
|
|
|
tx, err = db.Begin()
|
|
require.NoError(t, err)
|
|
|
|
_, err = tx.Exec("insert into t values('hi')")
|
|
require.NoError(t, err)
|
|
|
|
err = tx.Commit()
|
|
require.NoError(t, err)
|
|
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, n)
|
|
})
|
|
}
|
|
|
|
func TestConnBeginTxIsolation(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server always uses serializable isolation level")
|
|
|
|
var defaultIsoLevel string
|
|
err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
|
|
require.NoError(t, err)
|
|
|
|
supportedTests := []struct {
|
|
sqlIso sql.IsolationLevel
|
|
pgIso string
|
|
}{
|
|
{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
|
|
{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
|
|
{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
|
|
{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
|
|
{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
|
|
{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
|
|
}
|
|
for i, tt := range supportedTests {
|
|
func() {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
|
if err != nil {
|
|
t.Errorf("%d. BeginTx failed: %v", i, err)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var pgIso string
|
|
err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
|
|
if err != nil {
|
|
t.Errorf("%d. QueryRow failed: %v", i, err)
|
|
}
|
|
|
|
if pgIso != tt.pgIso {
|
|
t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
|
|
}
|
|
}()
|
|
}
|
|
|
|
unsupportedTests := []struct {
|
|
sqlIso sql.IsolationLevel
|
|
}{
|
|
{sqlIso: sql.LevelWriteCommitted},
|
|
{sqlIso: sql.LevelLinearizable},
|
|
}
|
|
for i, tt := range unsupportedTests {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
|
if err == nil {
|
|
t.Errorf("%d. BeginTx should have failed", i)
|
|
tx.Rollback()
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConnBeginTxReadOnly(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
|
|
require.NoError(t, err)
|
|
defer tx.Rollback()
|
|
|
|
var pgReadOnly string
|
|
err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
|
|
if err != nil {
|
|
t.Errorf("QueryRow failed: %v", err)
|
|
}
|
|
|
|
if pgReadOnly != "on" {
|
|
t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBeginTxContextCancel(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("drop table if exists t")
|
|
require.NoError(t, err)
|
|
|
|
ctx, cancelFn := context.WithCancel(context.Background())
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = tx.Exec("create table t(id serial)")
|
|
require.NoError(t, err)
|
|
|
|
cancelFn()
|
|
|
|
err = tx.Commit()
|
|
if err != context.Canceled && err != sql.ErrTxDone {
|
|
t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
|
|
}
|
|
|
|
var n int
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
|
|
t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConnRaw(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
var n int
|
|
err = conn.Raw(func(driverConn any) error {
|
|
conn := driverConn.(*stdlib.Conn).Conn()
|
|
return conn.QueryRow(context.Background(), "select 42").Scan(&n)
|
|
})
|
|
require.NoError(t, err)
|
|
assert.EqualValues(t, 42, n)
|
|
})
|
|
}
|
|
|
|
func TestConnPingContextSuccess(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
err := db.PingContext(context.Background())
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnPrepareContextSuccess(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
stmt, err := db.PrepareContext(context.Background(), "select now()")
|
|
require.NoError(t, err)
|
|
err = stmt.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281
|
|
// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
|
|
func TestConnMultiplePrepareAndDeallocate(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
sql := "select 42"
|
|
stmt1, err := db.PrepareContext(context.Background(), sql)
|
|
require.NoError(t, err)
|
|
stmt2, err := db.PrepareContext(context.Background(), sql)
|
|
require.NoError(t, err)
|
|
err = stmt1.Close()
|
|
require.NoError(t, err)
|
|
|
|
var preparedStmtCount int64
|
|
err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, preparedStmtCount)
|
|
|
|
err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate.
|
|
require.NoError(t, err)
|
|
|
|
err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 0, preparedStmtCount)
|
|
})
|
|
}
|
|
|
|
func TestConnExecContextSuccess(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryContextSuccess(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
|
|
require.NoError(t, err)
|
|
|
|
for rows.Next() {
|
|
var n int64
|
|
err := rows.Scan(&n)
|
|
require.NoError(t, err)
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
})
|
|
}
|
|
|
|
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
rows, err := db.Query("select 42::bigint")
|
|
require.NoError(t, err)
|
|
|
|
columnTypes, err := rows.ColumnTypes()
|
|
require.NoError(t, err)
|
|
require.Len(t, columnTypes, 1)
|
|
|
|
if columnTypes[0].DatabaseTypeName() != "INT8" {
|
|
t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
|
|
}
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestStmtExecContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(id int primary key)")
|
|
require.NoError(t, err)
|
|
|
|
stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
|
|
require.NoError(t, err)
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.ExecContext(context.Background(), 42)
|
|
require.NoError(t, err)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestStmtExecContextCancel(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(id int primary key)")
|
|
require.NoError(t, err)
|
|
|
|
stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
|
|
require.NoError(t, err)
|
|
defer stmt.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = stmt.ExecContext(ctx, 42)
|
|
if !pgconn.Timeout(err) {
|
|
t.Errorf("expected timeout error, got %v", err)
|
|
}
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestStmtQueryContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
|
|
require.NoError(t, err)
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.QueryContext(context.Background(), 5)
|
|
require.NoError(t, err)
|
|
|
|
for rows.Next() {
|
|
var n int64
|
|
if err := rows.Scan(&n); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
t.Error(rows.Err())
|
|
}
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestRowsColumnTypes(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
columnTypesTests := []struct {
|
|
Name string
|
|
TypeName string
|
|
Length struct {
|
|
Len int64
|
|
OK bool
|
|
}
|
|
DecimalSize struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}
|
|
ScanType reflect.Type
|
|
}{
|
|
{
|
|
Name: "a",
|
|
TypeName: "INT8",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(int64(0)),
|
|
}, {
|
|
Name: "bar",
|
|
TypeName: "TEXT",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: math.MaxInt64,
|
|
OK: true,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(""),
|
|
}, {
|
|
Name: "dec",
|
|
TypeName: "NUMERIC",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 9,
|
|
Scale: 2,
|
|
OK: true,
|
|
},
|
|
ScanType: reflect.TypeOf(float64(0)),
|
|
}, {
|
|
Name: "d",
|
|
TypeName: "1266",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(""),
|
|
},
|
|
}
|
|
|
|
rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
|
|
require.NoError(t, err)
|
|
|
|
columns, err := rows.ColumnTypes()
|
|
require.NoError(t, err)
|
|
assert.Len(t, columns, 4)
|
|
|
|
for i, tt := range columnTypesTests {
|
|
c := columns[i]
|
|
if c.Name() != tt.Name {
|
|
t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
|
|
}
|
|
if c.DatabaseTypeName() != tt.TypeName {
|
|
t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
|
|
}
|
|
l, ok := c.Length()
|
|
if l != tt.Length.Len {
|
|
t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
|
|
}
|
|
if ok != tt.Length.OK {
|
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
|
|
}
|
|
p, s, ok := c.DecimalSize()
|
|
if p != tt.DecimalSize.Precision {
|
|
t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
|
|
}
|
|
if s != tt.DecimalSize.Scale {
|
|
t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
|
|
}
|
|
if ok != tt.DecimalSize.OK {
|
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
|
|
}
|
|
if c.ScanType() != tt.ScanType {
|
|
t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestQueryLifeCycle(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
var (
|
|
s string
|
|
n int64
|
|
)
|
|
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
rows, err = db.Query("select 1 where false")
|
|
require.NoError(t, err)
|
|
|
|
rowCount = int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 0, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/409
|
|
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
var msg json.RawMessage
|
|
|
|
err := db.QueryRow("select '{}'::json").Scan(&msg)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, []byte("{}"), []byte(msg))
|
|
})
|
|
}
|
|
|
|
type testLog struct {
|
|
lvl tracelog.LogLevel
|
|
msg string
|
|
data map[string]any
|
|
}
|
|
|
|
type testLogger struct {
|
|
logs []testLog
|
|
}
|
|
|
|
func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
|
|
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
|
}
|
|
|
|
func TestRegisterConnConfig(t *testing.T) {
|
|
connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
logger := &testLogger{}
|
|
connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}
|
|
|
|
// Issue 947: Register and unregister a ConnConfig and ensure that the
|
|
// returned connection string is not reused.
|
|
connStr := stdlib.RegisterConnConfig(connConfig)
|
|
require.Equal(t, "registeredConnConfig0", connStr)
|
|
stdlib.UnregisterConnConfig(connStr)
|
|
|
|
connStr = stdlib.RegisterConnConfig(connConfig)
|
|
defer stdlib.UnregisterConnConfig(connStr)
|
|
require.Equal(t, "registeredConnConfig1", connStr)
|
|
|
|
db, err := sql.Open("pgx", connStr)
|
|
require.NoError(t, err)
|
|
defer closeDB(t, db)
|
|
|
|
var n int64
|
|
err = db.QueryRow("select 1").Scan(&n)
|
|
require.NoError(t, err)
|
|
|
|
l := logger.logs[len(logger.logs)-1]
|
|
assert.Equal(t, "Query", l.msg)
|
|
assert.Equal(t, "select 1", l.data["sql"])
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/958
|
|
func TestConnQueryRowConstraintErrors(t *testing.T) {
|
|
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
|
skipPostgreSQLVersionLessThan(t, db, 11)
|
|
skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
|
|
_, err := db.Exec(`create temporary table defer_test (
|
|
id text primary key,
|
|
n int not null, unique (n),
|
|
unique (n) deferrable initially deferred )`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`drop function if exists test_trigger cascade`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$
|
|
begin
|
|
if new.n = 4 then
|
|
raise exception 'n cant be 4!';
|
|
end if;
|
|
return new;
|
|
end$$`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`create constraint trigger test
|
|
after insert or update on defer_test
|
|
deferrable initially deferred
|
|
for each row
|
|
execute function test_trigger()`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
|
|
require.NoError(t, err)
|
|
|
|
var id string
|
|
err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestOptionBeforeAfterConnect(t *testing.T) {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var beforeConnConfigs []*pgx.ConnConfig
|
|
var afterConns []*pgx.Conn
|
|
db := stdlib.OpenDB(*config,
|
|
stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
|
|
beforeConnConfigs = append(beforeConnConfigs, connConfig)
|
|
return nil
|
|
}),
|
|
stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
|
|
afterConns = append(afterConns, conn)
|
|
return nil
|
|
}))
|
|
defer closeDB(t, db)
|
|
|
|
// Force it to close and reopen a new connection after each query
|
|
db.SetMaxIdleConns(0)
|
|
|
|
_, err = db.Exec("select 1")
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec("select 1")
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, beforeConnConfigs, 2)
|
|
require.Len(t, afterConns, 2)
|
|
|
|
// Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we wean to ensure they
|
|
// are different objects, so can't use require.NotEqual
|
|
require.False(t, config == beforeConnConfigs[0])
|
|
require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
|
|
}
|
|
|
|
func TestRandomizeHostOrderFunc(t *testing.T) {
|
|
config, err := pgx.ParseConfig("postgres://host1,host2,host3")
|
|
require.NoError(t, err)
|
|
|
|
// Test that at some point we connect to all 3 hosts
|
|
hostsNotSeenYet := map[string]struct{}{
|
|
"host1": {},
|
|
"host2": {},
|
|
"host3": {},
|
|
}
|
|
|
|
// If we don't succeed within this many iterations, something is certainly wrong
|
|
for i := 0; i < 100000; i++ {
|
|
connCopy := *config
|
|
stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)
|
|
|
|
delete(hostsNotSeenYet, connCopy.Host)
|
|
if len(hostsNotSeenYet) == 0 {
|
|
return
|
|
}
|
|
|
|
hostCheckLoop:
|
|
for _, h := range []string{"host1", "host2", "host3"} {
|
|
if connCopy.Host == h {
|
|
continue
|
|
}
|
|
for _, f := range connCopy.Fallbacks {
|
|
if f.Host == h {
|
|
continue hostCheckLoop
|
|
}
|
|
}
|
|
require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
|
|
}
|
|
}
|
|
|
|
require.Fail(t, "did not get all hosts as primaries after many randomizations")
|
|
}
|
|
|
|
func TestResetSessionHookCalled(t *testing.T) {
|
|
var mockCalled bool
|
|
|
|
connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
|
|
mockCalled = true
|
|
|
|
return nil
|
|
}))
|
|
|
|
defer closeDB(t, db)
|
|
|
|
err = db.Ping()
|
|
require.NoError(t, err)
|
|
|
|
err = db.Ping()
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, mockCalled)
|
|
}
|
|
|
|
func TestCheckIdleConn(t *testing.T) {
|
|
controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeDB(t, controllerConn)
|
|
|
|
skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
|
|
|
db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeDB(t, db)
|
|
|
|
var conns []*sql.Conn
|
|
for i := 0; i < 3; i++ {
|
|
c, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
conns = append(conns, c)
|
|
}
|
|
|
|
require.EqualValues(t, 3, db.Stats().OpenConnections)
|
|
|
|
var pids []uint32
|
|
for _, c := range conns {
|
|
err := c.Raw(func(driverConn any) error {
|
|
pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID())
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
err = c.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// The database/sql connection pool seems to automatically close idle connections to only keep 2 alive.
|
|
// require.EqualValues(t, 3, db.Stats().OpenConnections)
|
|
|
|
_, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids)
|
|
require.NoError(t, err)
|
|
|
|
// All conns are dead they don't know it and neither does the pool. But because of database/sql automatically closing
|
|
// idle connections we can't be sure how many we should have. require.EqualValues(t, 3, db.Stats().OpenConnections)
|
|
|
|
// Wait long enough so the pool will realize it needs to check the connections.
|
|
time.Sleep(time.Second)
|
|
|
|
// Pool should try all existing connections and find them dead, then create a new connection which should successfully ping.
|
|
err = db.PingContext(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
// The original 3 conns should have been terminated and the a new conn established for the ping.
|
|
require.EqualValues(t, 1, db.Stats().OpenConnections)
|
|
c, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
var cPID uint32
|
|
err = c.Raw(func(driverConn any) error {
|
|
cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
err = c.Close()
|
|
require.NoError(t, err)
|
|
|
|
require.NotContains(t, pids, cPID)
|
|
}
|