Restore simple protocol support
This commit is contained in:
+217
@@ -1305,3 +1305,220 @@ func TestRowsFromResultReader(t *testing.T) {
|
||||
t.Error("Wrong values returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Test all supported low-level types
|
||||
|
||||
{
|
||||
expected := int64(42)
|
||||
var actual int64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := float64(1.23)
|
||||
var actual float64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::float8",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := true
|
||||
var actual bool
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
|
||||
var actual []byte
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::bytea",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Compare(actual, expected) != 0 {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := "test"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::text",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test high-level type
|
||||
|
||||
{
|
||||
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
|
||||
actual := expected
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::circle",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
&expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple args in single query
|
||||
|
||||
{
|
||||
expectedInt64 := int64(234423)
|
||||
expectedFloat64 := float64(-0.2312)
|
||||
expectedBool := true
|
||||
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
|
||||
expectedString := "test"
|
||||
var actualInt64 int64
|
||||
var actualFloat64 float64
|
||||
var actualBool bool
|
||||
var actualBytes []byte
|
||||
var actualString string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
|
||||
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expectedInt64 != actualInt64 {
|
||||
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
|
||||
}
|
||||
if expectedFloat64 != actualFloat64 {
|
||||
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
|
||||
}
|
||||
if expectedBool != actualBool {
|
||||
t.Errorf("expected %v got %v", expectedBool, actualBool)
|
||||
}
|
||||
if bytes.Compare(expectedBytes, actualBytes) != 0 {
|
||||
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
|
||||
}
|
||||
if expectedString != actualString {
|
||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||
}
|
||||
}
|
||||
|
||||
// Test dangerous cases
|
||||
|
||||
{
|
||||
expected := "foo';drop table users;"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
"test",
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when client_encoding not UTF8, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set standard_conforming_strings to off")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
`\'; drop table users; --`,
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user