2
0

Remove Ex versions of Query and QueryRow

Always require context and prepend options to arguments if necessary.
This commit is contained in:
Jack Christensen
2019-04-10 12:12:22 -05:00
parent b69179cebb
commit 7718ee6207
26 changed files with 217 additions and 353 deletions
+59 -57
View File
@@ -28,7 +28,7 @@ func TestConnQueryScan(t *testing.T) {
var sum, rowCount int32
rows, err := conn.Query("select generate_series(1,$1)", 10)
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -73,7 +73,7 @@ func TestConnQueryScanWithManyColumns(t *testing.T) {
var rowCount int
rows, err := conn.Query(sql)
rows, err := conn.Query(context.Background(), sql)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -113,7 +113,7 @@ func TestConnQueryValues(t *testing.T) {
var rowCount int32
rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10)
rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -187,7 +187,7 @@ func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) {
var rowCount int32
rows, err := conn.Query("select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10)
rows, err := conn.Query(context.Background(), "select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -229,7 +229,7 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T)
var s string
err := conn.QueryRow("select 1").Scan(&s)
err := conn.QueryRow(context.Background(), "select 1").Scan(&s)
if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) {
t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err)
}
@@ -245,7 +245,7 @@ func TestConnQueryCloseEarly(t *testing.T) {
defer closeConn(t, conn)
// Immediately close query without reading any rows
rows, err := conn.Query("select generate_series(1,$1)", 10)
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -254,7 +254,7 @@ func TestConnQueryCloseEarly(t *testing.T) {
ensureConnValid(t, conn)
// Read partial response then close
rows, err = conn.Query("select generate_series(1,$1)", 10)
rows, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -281,7 +281,7 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
rows, err := conn.Query("select 1/(10-n) from generate_series(1,10) n")
rows, err := conn.Query(context.Background(), "select 1/(10-n) from generate_series(1,10) n")
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -301,7 +301,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
defer closeConn(t, conn)
// Read a single value incorrectly
rows, err := conn.Query("select generate_series(1,$1)", 10)
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -337,7 +337,7 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
defer closeConn(t, conn)
// Read too many values
rows, err := conn.Query("select generate_series(1,$1)", 10)
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -367,7 +367,7 @@ func TestConnQueryScanIgnoreColumn(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
rows, err := conn.Query("select 1::int8, 2::int8, 3::int8")
rows, err := conn.Query(context.Background(), "select 1::int8, 2::int8, 3::int8")
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -405,7 +405,7 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) {
func() {
sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
rows, err := conn.Query(sql)
rows, err := conn.Query(context.Background(), sql)
if err != nil {
t.Fatal(err)
}
@@ -432,7 +432,7 @@ func TestQueryEncodeError(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
rows, err := conn.Query("select $1::integer", "wrong")
rows, err := conn.Query(context.Background(), "select $1::integer", "wrong")
if err != nil {
t.Errorf("conn.Query failure: %v", err)
}
@@ -487,7 +487,7 @@ func TestQueryRowCoreTypes(t *testing.T) {
for i, tt := range tests {
actual = zero
err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
}
@@ -499,7 +499,7 @@ func TestQueryRowCoreTypes(t *testing.T) {
ensureConnValid(t, conn)
// Check that Scan errors when a core type is null
err = conn.QueryRow(tt.sql, nil).Scan(tt.scanArgs...)
err = conn.QueryRow(context.Background(), tt.sql, nil).Scan(tt.scanArgs...)
if err == nil {
t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql)
}
@@ -575,7 +575,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) {
for i, tt := range successfulEncodeTests {
actual = zero
err := conn.QueryRow(tt.sql, tt.queryArg).Scan(tt.scanArg)
err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(tt.scanArg)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
continue
@@ -612,7 +612,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) {
}
for i, tt := range failedEncodeTests {
err := conn.QueryRow(tt.sql, tt.queryArg).Scan(nil)
err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(nil)
if err == nil {
t.Errorf("%d. Expected failure to encode, but unexpectedly succeeded: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
} else if !strings.Contains(err.Error(), "is greater than") {
@@ -718,7 +718,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) {
for i, tt := range successfulDecodeTests {
actual = zero
err := conn.QueryRow(tt.sql).Scan(tt.scanArg)
err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
continue
@@ -787,7 +787,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) {
}
for i, tt := range failedDecodeTests {
err := conn.QueryRow(tt.sql).Scan(tt.scanArg)
err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg)
if err == nil {
t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql)
} else if !strings.Contains(err.Error(), tt.expectedErr) {
@@ -818,7 +818,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) {
for i, tt := range tests {
var actual []byte
err := conn.QueryRow(tt.sql, tt.queryArg).Scan(&actual)
err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
}
@@ -854,7 +854,7 @@ func TestQueryRowUnknownType(t *testing.T) {
expected := "(1,0)"
var actual string
err := conn.QueryRow(sql, expected).Scan(&actual)
err := conn.QueryRow(context.Background(), sql, expected).Scan(&actual)
if err != nil {
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
}
@@ -896,7 +896,7 @@ func TestQueryRowErrors(t *testing.T) {
for i, tt := range tests {
actual = zero
err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
if err == nil {
t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
}
@@ -921,17 +921,18 @@ func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) {
select some_int from t where some_text = $1`
paramOIDs := []pgtype.OID{pgtype.TextArrayOID}
queryArgs := []interface{}{"bar"}
queryOptions := &pgx.QueryExOptions{
ParameterOIDs: paramOIDs,
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
}
optionsAndArgs := append([]interface{}{queryOptions}, queryArgs...)
expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)"
var result int64
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
sql,
&pgx.QueryExOptions{
ParameterOIDs: paramOIDs,
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
},
queryArgs...,
optionsAndArgs...,
).Scan(&result)
if err == nil {
@@ -952,7 +953,7 @@ func TestQueryRowNoResults(t *testing.T) {
defer closeConn(t, conn)
var n int32
err := conn.QueryRow("select 1 where 1=0").Scan(&n)
err := conn.QueryRow(context.Background(), "select 1 where 1=0").Scan(&n)
if err != pgx.ErrNoRows {
t.Errorf("Expected pgx.ErrNoRows, got %v", err)
}
@@ -966,7 +967,7 @@ func TestReadingValueAfterEmptyArray(t *testing.T) {
var a []string
var b int32
err := conn.QueryRow("select '{}'::text[], 42::integer").Scan(&a, &b)
err := conn.QueryRow(context.Background(), "select '{}'::text[], 42::integer").Scan(&a, &b)
if err != nil {
t.Fatalf("conn.QueryRow failed: %v", err)
}
@@ -985,7 +986,7 @@ func TestReadingNullByteArray(t *testing.T) {
defer closeConn(t, conn)
var a []byte
err := conn.QueryRow("select null::text").Scan(&a)
err := conn.QueryRow(context.Background(), "select null::text").Scan(&a)
if err != nil {
t.Fatalf("conn.QueryRow failed: %v", err)
}
@@ -999,7 +1000,7 @@ func TestReadingNullByteArrays(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
rows, err := conn.Query("select null::text union all select null::text")
rows, err := conn.Query(context.Background(), "select null::text union all select null::text")
if err != nil {
t.Fatalf("conn.Query failed: %v", err)
}
@@ -1030,7 +1031,7 @@ func TestConnQueryDatabaseSQLScanner(t *testing.T) {
var num decimal.Decimal
err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
err := conn.QueryRow(context.Background(), "select '1234.567'::decimal").Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
@@ -1061,7 +1062,7 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
}
var num decimal.Decimal
err = conn.QueryRow("select $1::decimal", &expected).Scan(&num)
err = conn.QueryRow(context.Background(), "select $1::decimal", &expected).Scan(&num)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
@@ -1112,7 +1113,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *
}
var u2 uuid.UUID
err = conn.QueryRow("select $1::uuid", expected).Scan(&u2)
err = conn.QueryRow(context.Background(), "select $1::uuid", expected).Scan(&u2)
if err != nil {
t.Fatalf("Scan failed: %v", err)
}
@@ -1151,6 +1152,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
var actual row
err := conn.QueryRow(
context.Background(),
"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
expected.boolValid,
expected.boolNull,
@@ -1181,7 +1183,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryExContextSuccess(t *testing.T) {
func TestQueryContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -1190,7 +1192,7 @@ func TestQueryExContextSuccess(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryEx(ctx, "select 42::integer", nil)
rows, err := conn.Query(ctx, "select 42::integer")
if err != nil {
t.Fatal(err)
}
@@ -1221,7 +1223,7 @@ func TestQueryExContextSuccess(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryExContextErrorWhileReceivingRows(t *testing.T) {
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -1230,7 +1232,7 @@ func TestQueryExContextErrorWhileReceivingRows(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryEx(ctx, "select 10/(10-n) from generate_series(1, 100) n", nil)
rows, err := conn.Query(ctx, "select 10/(10-n) from generate_series(1, 100) n")
if err != nil {
t.Fatal(err)
}
@@ -1258,7 +1260,7 @@ func TestQueryExContextErrorWhileReceivingRows(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowExContextSuccess(t *testing.T) {
func TestQueryRowContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -1268,7 +1270,7 @@ func TestQueryRowExContextSuccess(t *testing.T) {
defer cancelFunc()
var result int
err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result)
err := conn.QueryRow(ctx, "select 42::integer").Scan(&result)
if err != nil {
t.Fatal(err)
}
@@ -1282,7 +1284,7 @@ func TestQueryRowExContextSuccess(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) {
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -1292,7 +1294,7 @@ func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) {
defer cancelFunc()
var result int
err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result)
err := conn.QueryRow(ctx, "select 10/0").Scan(&result)
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", err)
}
@@ -1300,14 +1302,14 @@ func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnQueryRowExSingleRoundTrip(t *testing.T) {
func TestConnQueryRowSingleRoundTrip(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
var result int32
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1 + $2",
&pgx.QueryExOptions{
@@ -1337,7 +1339,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := int64(42)
var actual int64
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1::int8",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1357,7 +1359,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := float64(1.23)
var actual float64
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1::float8",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1377,7 +1379,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := true
var actual bool
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1397,7 +1399,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
var actual []byte
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1::bytea",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1417,7 +1419,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := "test"
var actual string
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1::text",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1439,7 +1441,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
actual := expected
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1::circle",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1469,7 +1471,7 @@ func TestConnSimpleProtocol(t *testing.T) {
var actualBool bool
var actualBytes []byte
var actualString string
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1503,7 +1505,7 @@ func TestConnSimpleProtocol(t *testing.T) {
{
expected := "foo';drop table users;"
var actual string
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1532,7 +1534,7 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
var expected string
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1554,7 +1556,7 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
mustExec(t, conn, "set standard_conforming_strings to off")
var expected string
err := conn.QueryRowEx(
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
@@ -1567,13 +1569,13 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryExCloseBefore(t *testing.T) {
func TestQueryCloseBefore(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
closeConn(t, conn)
if _, err := conn.QueryEx(context.Background(), "select 1", nil); err == nil {
if _, err := conn.Query(context.Background(), "select 1"); err == nil {
t.Fatal("Expected network error")
}
if conn.LastStmtSent() {