3ab8941921
Previously, stdlib.RegisterConnConfig could return the same connection string for different ConnConfig values (for example, when one ConnConfig was registered and then unregistered, and a new, different ConnConfig was registered afterwards). This breaks callers that expect two connections with the same data source name to connect to the same backend database in the same way. This fix makes stdlib.RegisterConnConfig use an incrementing sequence counter so that every returned connection string is unique. Fixes #947
1132 lines
28 KiB
Go
1132 lines
28 KiB
Go
package stdlib_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"math"
|
|
"os"
|
|
"reflect"
|
|
"regexp"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Masterminds/semver/v3"
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/jackc/pgx/v4/stdlib"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func openDB(t testing.TB) *sql.DB {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
return stdlib.OpenDB(*config)
|
|
}
|
|
|
|
func closeDB(t testing.TB, db *sql.DB) {
|
|
err := db.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
err = conn.Raw(func(driverConn interface{}) error {
|
|
conn := driverConn.(*stdlib.Conn).Conn()
|
|
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
|
t.Skip(msg)
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) {
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
err = conn.Raw(func(driverConn interface{}) error {
|
|
conn := driverConn.(*stdlib.Conn).Conn()
|
|
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
|
|
serverVersionStr = regexp.MustCompile(`^[0-9.]+`).FindString(serverVersionStr)
|
|
// if not PostgreSQL do nothing
|
|
if serverVersionStr == "" {
|
|
return nil
|
|
}
|
|
|
|
serverVersion, err := semver.NewVersion(serverVersionStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c, err := semver.NewConstraint(constraintStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if c.Check(serverVersion) {
|
|
t.Skip(msg)
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) {
|
|
t.Run("SimpleProto",
|
|
func(t *testing.T) {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.PreferSimpleProtocol = true
|
|
db := stdlib.OpenDB(*config)
|
|
defer func() {
|
|
err := db.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
f(t, db)
|
|
|
|
ensureDBValid(t, db)
|
|
},
|
|
)
|
|
|
|
t.Run("DefaultProto",
|
|
func(t *testing.T) {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
db := stdlib.OpenDB(*config)
|
|
defer func() {
|
|
err := db.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
f(t, db)
|
|
|
|
ensureDBValid(t, db)
|
|
},
|
|
)
|
|
}
|
|
|
|
// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
|
|
// cover an broken connections.
|
|
func ensureDBValid(t testing.TB, db *sql.DB) {
|
|
var sum, rowCount int32
|
|
|
|
rows, err := db.Query("select generate_series(1,$1)", 10)
|
|
require.NoError(t, err)
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var n int32
|
|
rows.Scan(&n)
|
|
sum += n
|
|
rowCount++
|
|
}
|
|
|
|
require.NoError(t, rows.Err())
|
|
|
|
if rowCount != 10 {
|
|
t.Error("Select called onDataRow wrong number of times")
|
|
}
|
|
if sum != 55 {
|
|
t.Error("Wrong values returned")
|
|
}
|
|
}
|
|
|
|
// preparer abstracts over values that can prepare a statement. In this file,
// both *sql.DB and *sql.Tx are passed to prepareStmt through it.
type preparer interface {
	Prepare(sqlText string) (*sql.Stmt, error)
}
|
|
|
|
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
|
|
stmt, err := p.Prepare(sql)
|
|
require.NoError(t, err)
|
|
return stmt
|
|
}
|
|
|
|
func closeStmt(t *testing.T, stmt *sql.Stmt) {
|
|
err := stmt.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestSQLOpen(t *testing.T) {
|
|
db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
closeDB(t, db)
|
|
}
|
|
|
|
func TestNormalLifeCycle(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
|
defer closeStmt(t, stmt)
|
|
|
|
rows, err := stmt.Query(int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
|
|
require.EqualValues(t, 10, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestStmtExec(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
tx, err := db.Begin()
|
|
require.NoError(t, err)
|
|
|
|
createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
|
|
_, err = createStmt.Exec()
|
|
require.NoError(t, err)
|
|
closeStmt(t, createStmt)
|
|
|
|
insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
|
|
result, err := insertStmt.Exec("foo")
|
|
require.NoError(t, err)
|
|
|
|
n, err := result.RowsAffected()
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, n)
|
|
closeStmt(t, insertStmt)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestQueryCloseRowsEarly(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
|
defer closeStmt(t, stmt)
|
|
|
|
rows, err := stmt.Query(int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
// Close rows immediately without having read them
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
// Run the query again to ensure the connection and statement are still ok
|
|
rows, err = stmt.Query(int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 10, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestConnExec(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("create temporary table t(a varchar not null)")
|
|
require.NoError(t, err)
|
|
|
|
result, err := db.Exec("insert into t values('hey')")
|
|
require.NoError(t, err)
|
|
|
|
n, err := result.RowsAffected()
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, n)
|
|
})
|
|
}
|
|
|
|
func TestConnQuery(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var s string
|
|
var n int64
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 10, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/781
|
|
func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
var s string
|
|
var b bool
|
|
|
|
rows, err := db.Query("select true, 'foo'")
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, rows.Next())
|
|
require.NoError(t, rows.Scan(&b, &s))
|
|
assert.Equal(t, true, b)
|
|
assert.Equal(t, "foo", s)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryNull(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
rows, err := db.Query("select $1::int", nil)
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
var n sql.NullInt64
|
|
err := rows.Scan(&n)
|
|
require.NoError(t, err)
|
|
if n.Valid != false {
|
|
t.Errorf("Expected n to be null, but it was %v", n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 1, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryRowByteSlice(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
expected := []byte{222, 173, 190, 239}
|
|
var actual []byte
|
|
|
|
err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryFailure(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Query("select 'foo")
|
|
require.Error(t, err)
|
|
require.IsType(t, new(pgconn.PgError), err)
|
|
})
|
|
}
|
|
|
|
func TestConnSimpleSlicePassThrough(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server does not support cardinality function")
|
|
|
|
var n int64
|
|
err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
|
|
require.NoError(t, err)
|
|
assert.EqualValues(t, 3, n)
|
|
})
|
|
}
|
|
|
|
// Test type that pgx would handle natively in binary, but since it is not a
|
|
// database/sql native type should be passed through as a string
|
|
func TestConnQueryRowPgxBinary(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
sql := "select $1::int4[]"
|
|
expected := "{1,2,3}"
|
|
var actual string
|
|
|
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryRowUnknownType(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server does not support point type")
|
|
|
|
sql := "select $1::point"
|
|
expected := "(1,2)"
|
|
var actual string
|
|
|
|
err := db.QueryRow(sql, expected).Scan(&actual)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec(`
|
|
create temporary table docs(
|
|
body json not null
|
|
);
|
|
|
|
insert into docs(body) values('{"foo": "bar"}');
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
sql := `select * from docs`
|
|
expected := []byte(`{"foo": "bar"}`)
|
|
var actual []byte
|
|
|
|
err = db.QueryRow(sql).Scan(&actual)
|
|
if err != nil {
|
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
|
}
|
|
|
|
if bytes.Compare(actual, expected) != 0 {
|
|
t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
|
|
}
|
|
|
|
_, err = db.Exec(`drop table docs`)
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
|
|
// Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data
|
|
// that needs to escape. No way to know whether the destination is really a text compatible or a bytea.
|
|
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec(`
|
|
create temporary table docs(
|
|
body json not null
|
|
);
|
|
`)
|
|
require.NoError(t, err)
|
|
|
|
expected := []byte(`{"foo": "bar"}`)
|
|
|
|
_, err = db.Exec(`insert into docs(body) values($1)`, expected)
|
|
require.NoError(t, err)
|
|
|
|
var actual []byte
|
|
err = db.QueryRow(`select body from docs`).Scan(&actual)
|
|
require.NoError(t, err)
|
|
|
|
if bytes.Compare(actual, expected) != 0 {
|
|
t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
|
|
}
|
|
|
|
_, err = db.Exec(`drop table docs`)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestTransactionLifeCycle(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("create temporary table t(a varchar not null)")
|
|
require.NoError(t, err)
|
|
|
|
tx, err := db.Begin()
|
|
require.NoError(t, err)
|
|
|
|
_, err = tx.Exec("insert into t values('hi')")
|
|
require.NoError(t, err)
|
|
|
|
err = tx.Rollback()
|
|
require.NoError(t, err)
|
|
|
|
var n int64
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 0, n)
|
|
|
|
tx, err = db.Begin()
|
|
require.NoError(t, err)
|
|
|
|
_, err = tx.Exec("insert into t values('hi')")
|
|
require.NoError(t, err)
|
|
|
|
err = tx.Commit()
|
|
require.NoError(t, err)
|
|
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, n)
|
|
})
|
|
}
|
|
|
|
func TestConnBeginTxIsolation(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server always uses serializable isolation level")
|
|
|
|
var defaultIsoLevel string
|
|
err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
|
|
require.NoError(t, err)
|
|
|
|
supportedTests := []struct {
|
|
sqlIso sql.IsolationLevel
|
|
pgIso string
|
|
}{
|
|
{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
|
|
{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
|
|
{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
|
|
{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
|
|
{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
|
|
{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
|
|
}
|
|
for i, tt := range supportedTests {
|
|
func() {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
|
if err != nil {
|
|
t.Errorf("%d. BeginTx failed: %v", i, err)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var pgIso string
|
|
err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
|
|
if err != nil {
|
|
t.Errorf("%d. QueryRow failed: %v", i, err)
|
|
}
|
|
|
|
if pgIso != tt.pgIso {
|
|
t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
|
|
}
|
|
}()
|
|
}
|
|
|
|
unsupportedTests := []struct {
|
|
sqlIso sql.IsolationLevel
|
|
}{
|
|
{sqlIso: sql.LevelWriteCommitted},
|
|
{sqlIso: sql.LevelLinearizable},
|
|
}
|
|
for i, tt := range unsupportedTests {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
|
if err == nil {
|
|
t.Errorf("%d. BeginTx should have failed", i)
|
|
tx.Rollback()
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConnBeginTxReadOnly(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
|
|
require.NoError(t, err)
|
|
defer tx.Rollback()
|
|
|
|
var pgReadOnly string
|
|
err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
|
|
if err != nil {
|
|
t.Errorf("QueryRow failed: %v", err)
|
|
}
|
|
|
|
if pgReadOnly != "on" {
|
|
t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBeginTxContextCancel(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.Exec("drop table if exists t")
|
|
require.NoError(t, err)
|
|
|
|
ctx, cancelFn := context.WithCancel(context.Background())
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = tx.Exec("create table t(id serial)")
|
|
require.NoError(t, err)
|
|
|
|
cancelFn()
|
|
|
|
err = tx.Commit()
|
|
if err != context.Canceled && err != sql.ErrTxDone {
|
|
t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
|
|
}
|
|
|
|
var n int
|
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
|
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
|
|
t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAcquireConn(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
var conns []*pgx.Conn
|
|
|
|
for i := 1; i < 6; i++ {
|
|
conn, err := stdlib.AcquireConn(db)
|
|
if err != nil {
|
|
t.Errorf("%d. AcquireConn failed: %v", i, err)
|
|
continue
|
|
}
|
|
|
|
var n int32
|
|
err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
|
|
if err != nil {
|
|
t.Errorf("%d. QueryRow failed: %v", i, err)
|
|
}
|
|
if n != 1 {
|
|
t.Errorf("%d. n => %d, want %d", i, n, 1)
|
|
}
|
|
|
|
stats := db.Stats()
|
|
if stats.OpenConnections != i {
|
|
t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i)
|
|
}
|
|
|
|
conns = append(conns, conn)
|
|
}
|
|
|
|
for i, conn := range conns {
|
|
if err := stdlib.ReleaseConn(db, conn); err != nil {
|
|
t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConnRaw(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
var n int
|
|
err = conn.Raw(func(driverConn interface{}) error {
|
|
conn := driverConn.(*stdlib.Conn).Conn()
|
|
return conn.QueryRow(context.Background(), "select 42").Scan(&n)
|
|
})
|
|
require.NoError(t, err)
|
|
assert.EqualValues(t, 42, n)
|
|
})
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/673
|
|
func TestReleaseConnWithTxInProgress(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server does not support backend PID")
|
|
|
|
c1, err := stdlib.AcquireConn(db)
|
|
require.NoError(t, err)
|
|
|
|
_, err = c1.Exec(context.Background(), "begin")
|
|
require.NoError(t, err)
|
|
|
|
c1PID := c1.PgConn().PID()
|
|
|
|
err = stdlib.ReleaseConn(db, c1)
|
|
require.NoError(t, err)
|
|
|
|
c2, err := stdlib.AcquireConn(db)
|
|
require.NoError(t, err)
|
|
|
|
c2PID := c2.PgConn().PID()
|
|
|
|
err = stdlib.ReleaseConn(db, c2)
|
|
require.NoError(t, err)
|
|
|
|
require.NotEqual(t, c1PID, c2PID)
|
|
|
|
// Releasing a conn with a tx in progress should close the connection
|
|
stats := db.Stats()
|
|
require.Equal(t, 1, stats.OpenConnections)
|
|
})
|
|
}
|
|
|
|
func TestConnPingContextSuccess(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
err := db.PingContext(context.Background())
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnPrepareContextSuccess(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
stmt, err := db.PrepareContext(context.Background(), "select now()")
|
|
require.NoError(t, err)
|
|
err = stmt.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnExecContextSuccess(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnExecContextFailureRetry(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
// We get a connection, immediately close it, and then get it back;
|
|
// DB.Conn along with Conn.ResetSession does the retry for us.
|
|
{
|
|
conn, err := stdlib.AcquireConn(db)
|
|
require.NoError(t, err)
|
|
conn.Close(context.Background())
|
|
stdlib.ReleaseConn(db, conn)
|
|
}
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
_, err = conn.ExecContext(context.Background(), "select 1")
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConnQueryContextSuccess(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
|
|
require.NoError(t, err)
|
|
|
|
for rows.Next() {
|
|
var n int64
|
|
err := rows.Scan(&n)
|
|
require.NoError(t, err)
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
})
|
|
}
|
|
|
|
func TestConnQueryContextFailureRetry(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
// We get a connection, immediately close it, and then get it back;
|
|
// DB.Conn along with Conn.ResetSession does the retry for us.
|
|
{
|
|
conn, err := stdlib.AcquireConn(db)
|
|
require.NoError(t, err)
|
|
conn.Close(context.Background())
|
|
stdlib.ReleaseConn(db, conn)
|
|
}
|
|
conn, err := db.Conn(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
_, err = conn.QueryContext(context.Background(), "select 1")
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
rows, err := db.Query("select 42::bigint")
|
|
require.NoError(t, err)
|
|
|
|
columnTypes, err := rows.ColumnTypes()
|
|
require.NoError(t, err)
|
|
require.Len(t, columnTypes, 1)
|
|
|
|
if columnTypes[0].DatabaseTypeName() != "INT8" {
|
|
t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
|
|
}
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestStmtExecContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(id int primary key)")
|
|
require.NoError(t, err)
|
|
|
|
stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
|
|
require.NoError(t, err)
|
|
defer stmt.Close()
|
|
|
|
_, err = stmt.ExecContext(context.Background(), 42)
|
|
require.NoError(t, err)
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestStmtExecContextCancel(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
_, err := db.Exec("create temporary table t(id int primary key)")
|
|
require.NoError(t, err)
|
|
|
|
stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
|
|
require.NoError(t, err)
|
|
defer stmt.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = stmt.ExecContext(ctx, 42)
|
|
if !pgconn.Timeout(err) {
|
|
t.Errorf("expected timeout error, got %v", err)
|
|
}
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestStmtQueryContextSuccess(t *testing.T) {
|
|
db := openDB(t)
|
|
defer closeDB(t, db)
|
|
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
|
|
require.NoError(t, err)
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.QueryContext(context.Background(), 5)
|
|
require.NoError(t, err)
|
|
|
|
for rows.Next() {
|
|
var n int64
|
|
if err := rows.Scan(&n); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
t.Error(rows.Err())
|
|
}
|
|
|
|
ensureDBValid(t, db)
|
|
}
|
|
|
|
func TestRowsColumnTypes(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
columnTypesTests := []struct {
|
|
Name string
|
|
TypeName string
|
|
Length struct {
|
|
Len int64
|
|
OK bool
|
|
}
|
|
DecimalSize struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}
|
|
ScanType reflect.Type
|
|
}{
|
|
{
|
|
Name: "a",
|
|
TypeName: "INT8",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(int64(0)),
|
|
}, {
|
|
Name: "bar",
|
|
TypeName: "TEXT",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: math.MaxInt64,
|
|
OK: true,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(""),
|
|
}, {
|
|
Name: "dec",
|
|
TypeName: "NUMERIC",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 9,
|
|
Scale: 2,
|
|
OK: true,
|
|
},
|
|
ScanType: reflect.TypeOf(float64(0)),
|
|
}, {
|
|
Name: "d",
|
|
TypeName: "1266",
|
|
Length: struct {
|
|
Len int64
|
|
OK bool
|
|
}{
|
|
Len: 0,
|
|
OK: false,
|
|
},
|
|
DecimalSize: struct {
|
|
Precision int64
|
|
Scale int64
|
|
OK bool
|
|
}{
|
|
Precision: 0,
|
|
Scale: 0,
|
|
OK: false,
|
|
},
|
|
ScanType: reflect.TypeOf(""),
|
|
},
|
|
}
|
|
|
|
rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
|
|
require.NoError(t, err)
|
|
|
|
columns, err := rows.ColumnTypes()
|
|
require.NoError(t, err)
|
|
assert.Len(t, columns, 4)
|
|
|
|
for i, tt := range columnTypesTests {
|
|
c := columns[i]
|
|
if c.Name() != tt.Name {
|
|
t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
|
|
}
|
|
if c.DatabaseTypeName() != tt.TypeName {
|
|
t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
|
|
}
|
|
l, ok := c.Length()
|
|
if l != tt.Length.Len {
|
|
t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
|
|
}
|
|
if ok != tt.Length.OK {
|
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
|
|
}
|
|
p, s, ok := c.DecimalSize()
|
|
if p != tt.DecimalSize.Precision {
|
|
t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
|
|
}
|
|
if s != tt.DecimalSize.Scale {
|
|
t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
|
|
}
|
|
if ok != tt.DecimalSize.OK {
|
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
|
|
}
|
|
if c.ScanType() != tt.ScanType {
|
|
t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestQueryLifeCycle(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
|
|
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
|
require.NoError(t, err)
|
|
|
|
rowCount := int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
var (
|
|
s string
|
|
n int64
|
|
)
|
|
|
|
err := rows.Scan(&s, &n)
|
|
require.NoError(t, err)
|
|
|
|
if s != "foo" {
|
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
|
}
|
|
|
|
if n != rowCount {
|
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
|
}
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
|
|
rows, err = db.Query("select 1 where false")
|
|
require.NoError(t, err)
|
|
|
|
rowCount = int64(0)
|
|
|
|
for rows.Next() {
|
|
rowCount++
|
|
}
|
|
require.NoError(t, rows.Err())
|
|
require.EqualValues(t, 0, rowCount)
|
|
|
|
err = rows.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/409
|
|
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
var msg json.RawMessage
|
|
|
|
err := db.QueryRow("select '{}'::json").Scan(&msg)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, []byte("{}"), []byte(msg))
|
|
})
|
|
}
|
|
|
|
type testLog struct {
|
|
lvl pgx.LogLevel
|
|
msg string
|
|
data map[string]interface{}
|
|
}
|
|
|
|
type testLogger struct {
|
|
logs []testLog
|
|
}
|
|
|
|
func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
|
}
|
|
|
|
func TestRegisterConnConfig(t *testing.T) {
|
|
connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
logger := &testLogger{}
|
|
connConfig.Logger = logger
|
|
|
|
// Issue 947: Register and unregister a ConnConfig and ensure that the
|
|
// returned connection string is not reused.
|
|
connStr := stdlib.RegisterConnConfig(connConfig)
|
|
require.Equal(t, "registeredConnConfig0", connStr)
|
|
stdlib.UnregisterConnConfig(connStr)
|
|
|
|
connStr = stdlib.RegisterConnConfig(connConfig)
|
|
defer stdlib.UnregisterConnConfig(connStr)
|
|
require.Equal(t, "registeredConnConfig1", connStr)
|
|
|
|
db, err := sql.Open("pgx", connStr)
|
|
require.NoError(t, err)
|
|
defer closeDB(t, db)
|
|
|
|
var n int64
|
|
err = db.QueryRow("select 1").Scan(&n)
|
|
require.NoError(t, err)
|
|
|
|
l := logger.logs[len(logger.logs)-1]
|
|
assert.Equal(t, "Query", l.msg)
|
|
assert.Equal(t, "select 1", l.data["sql"])
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/958
|
|
func TestConnQueryRowConstraintErrors(t *testing.T) {
|
|
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
|
|
skipPostgreSQLVersion(t, db, "< 11", "Test requires PG 11+")
|
|
skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
|
|
_, err := db.Exec(`create temporary table defer_test (
|
|
id text primary key,
|
|
n int not null, unique (n),
|
|
unique (n) deferrable initially deferred )`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`drop function if exists test_trigger cascade`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$
|
|
begin
|
|
if new.n = 4 then
|
|
raise exception 'n cant be 4!';
|
|
end if;
|
|
return new;
|
|
end$$`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`create constraint trigger test
|
|
after insert or update on defer_test
|
|
deferrable initially deferred
|
|
for each row
|
|
execute function test_trigger()`)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
|
|
require.NoError(t, err)
|
|
|
|
var id string
|
|
err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|