2
0

Add simple protocol suuport with (Query|Exec)Ex

This commit is contained in:
Jack Christensen
2017-04-10 08:58:51 -05:00
parent 54d9cbc743
commit 7b1f461ec3
16 changed files with 999 additions and 326 deletions
+229 -281
View File
@@ -797,275 +797,6 @@ func TestQueryRowNoResults(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowCoreInt16Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []int16
tests := []struct {
sql string
expected []int16
}{
{"select $1::int2[]", []int16{1, 2, 3, 4, 5}},
{"select $1::int2[]", []int16{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) {
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreInt32Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []int32
tests := []struct {
sql string
expected []int32
}{
{"select $1::int4[]", []int32{1, 2, 3, 4, 5}},
{"select $1::int4[]", []int32{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreInt64Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []int64
tests := []struct {
sql string
expected []int64
}{
{"select $1::int8[]", []int64{1, 2, 3, 4, 5}},
{"select $1::int8[]", []int64{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreFloat32Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []float32
tests := []struct {
sql string
expected []float32
}{
{"select $1::float4[]", []float32{1.5, 2.0, 3.5}},
{"select $1::float4[]", []float32{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreFloat64Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []float64
tests := []struct {
sql string
expected []float64
}{
{"select $1::float8[]", []float64{1.5, 2.0, 3.5}},
{"select $1::float8[]", []float64{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreStringSlice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []string
tests := []struct {
sql string
expected []string
}{
{"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
{"select $1::text[]", []string{}},
{"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
{"select $1::varchar[]", []string{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestReadingValueAfterEmptyArray(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
@@ -1236,7 +967,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextSuccess(t *testing.T) {
func TestQueryExContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@@ -1245,7 +976,7 @@ func TestQueryContextSuccess(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 42::integer")
rows, err := conn.QueryEx(ctx, "select 42::integer", nil)
if err != nil {
t.Fatal(err)
}
@@ -1273,7 +1004,7 @@ func TestQueryContextSuccess(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
func TestQueryExContextErrorWhileReceivingRows(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@@ -1282,7 +1013,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
rows, err := conn.QueryEx(ctx, "select 10/(10-n) from generate_series(1, 100) n", nil)
if err != nil {
t.Fatal(err)
}
@@ -1310,7 +1041,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
func TestQueryExContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@@ -1322,7 +1053,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
cancelFunc()
}()
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil)
if err != nil {
t.Fatal(err)
}
@@ -1338,7 +1069,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowContextSuccess(t *testing.T) {
func TestQueryRowExContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@@ -1348,7 +1079,7 @@ func TestQueryRowContextSuccess(t *testing.T) {
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result)
if err != nil {
t.Fatal(err)
}
@@ -1359,7 +1090,7 @@ func TestQueryRowContextSuccess(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@@ -1369,7 +1100,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result)
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", err)
}
@@ -1377,7 +1108,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@@ -1390,10 +1121,227 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
}()
var result []byte
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result)
if err != context.Canceled {
t.Fatalf("Expected context.Canceled error, got %v", err)
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocol(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
// Test all supported low-level types
{
expected := int64(42)
var actual int64
err := conn.QueryRowEx(
context.Background(),
"select $1::int8",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := float64(1.23)
var actual float64
err := conn.QueryRowEx(
context.Background(),
"select $1::float8",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := true
var actual bool
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
var actual []byte
err := conn.QueryRowEx(
context.Background(),
"select $1::bytea",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if bytes.Compare(actual, expected) != 0 {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := "test"
var actual string
err := conn.QueryRowEx(
context.Background(),
"select $1::text",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
// Test high-level type
{
expected := pgtype.Line{A: 1, B: 2, C: 1.5, Status: pgtype.Present}
actual := expected
err := conn.QueryRowEx(
context.Background(),
"select $1::line",
&pgx.QueryExOptions{SimpleProtocol: true},
&expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
// Test multiple args in single query
{
expectedInt64 := int64(234423)
expectedFloat64 := float64(-0.2312)
expectedBool := true
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
expectedString := "test"
var actualInt64 int64
var actualFloat64 float64
var actualBool bool
var actualBytes []byte
var actualString string
err := conn.QueryRowEx(
context.Background(),
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
&pgx.QueryExOptions{SimpleProtocol: true},
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
if err != nil {
t.Error(err)
}
if expectedInt64 != actualInt64 {
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
}
if expectedFloat64 != actualFloat64 {
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
}
if expectedBool != actualBool {
t.Errorf("expected %v got %v", expectedBool, actualBool)
}
if bytes.Compare(expectedBytes, actualBytes) != 0 {
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
}
if expectedString != actualString {
t.Errorf("expected %v got %v", expectedString, actualString)
}
}
// Test dangerous cases
{
expected := "foo';drop table users;"
var actual string
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
var expected string
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
"test",
).Scan(&expected)
if err == nil {
t.Error("expected error when client_encoding not UTF8, but no error occurred")
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, "set standard_conforming_strings to off")
var expected string
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
`\'; drop table users; --`,
).Scan(&expected)
if err == nil {
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
}
ensureConnValid(t, conn)
}