Remove simple protocol and one round trip query options
It is impossible to guarantee that the a query executed with the simple protocol will behave the same as with the extended protocol. This is because the normal pgx path relies on knowing the OID of query parameters. Without this encoding a value can only be determined by the value instead of the combination of value and PostgreSQL type. For example, how should a []int32 be encoded? It might be encoded into a PostgreSQL int4[] or json. Removal also simplifies the core query path. The primary reason for the simple protocol is for servers like PgBouncer that may not be able to support normal prepared statements. After further research it appears that issuing a "flush" instead "sync" after preparing the unnamed statement would allow PgBouncer to work. The one round trip mode can be better handled with prepared statements. As a last resort, all original server functionality can still be accessed by dropping down to PgConn.
This commit is contained in:
-305
@@ -908,44 +908,6 @@ func TestQueryRowErrors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `
|
||||
with t as (
|
||||
select 1::int8 as some_int, 'foo'::text as some_text
|
||||
)
|
||||
select some_int from t where some_text = $1`
|
||||
paramOIDs := []pgtype.OID{pgtype.TextArrayOID}
|
||||
queryArgs := []interface{}{"bar"}
|
||||
queryOptions := &pgx.QueryExOptions{
|
||||
ParameterOIDs: paramOIDs,
|
||||
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
|
||||
}
|
||||
optionsAndArgs := append([]interface{}{queryOptions}, queryArgs...)
|
||||
expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)"
|
||||
var result int64
|
||||
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
sql,
|
||||
optionsAndArgs...,
|
||||
).Scan(&result)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), expectedErr) {
|
||||
t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)",
|
||||
expectedErr, err, sql, paramOIDs, queryArgs)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowNoResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1302,273 +1264,6 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryRowSingleRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var result int32
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1 + $2",
|
||||
&pgx.QueryExOptions{
|
||||
ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
|
||||
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
|
||||
},
|
||||
1, 2,
|
||||
).Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != 3 {
|
||||
t.Fatalf("result => %d, want %d", result, 3)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Test all supported low-level types
|
||||
|
||||
{
|
||||
expected := int64(42)
|
||||
var actual int64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := float64(1.23)
|
||||
var actual float64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::float8",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := true
|
||||
var actual bool
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
|
||||
var actual []byte
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::bytea",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Compare(actual, expected) != 0 {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := "test"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::text",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test high-level type
|
||||
|
||||
{
|
||||
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
|
||||
actual := expected
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::circle",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
&expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple args in single query
|
||||
|
||||
{
|
||||
expectedInt64 := int64(234423)
|
||||
expectedFloat64 := float64(-0.2312)
|
||||
expectedBool := true
|
||||
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
|
||||
expectedString := "test"
|
||||
var actualInt64 int64
|
||||
var actualFloat64 float64
|
||||
var actualBool bool
|
||||
var actualBytes []byte
|
||||
var actualString string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
|
||||
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expectedInt64 != actualInt64 {
|
||||
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
|
||||
}
|
||||
if expectedFloat64 != actualFloat64 {
|
||||
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
|
||||
}
|
||||
if expectedBool != actualBool {
|
||||
t.Errorf("expected %v got %v", expectedBool, actualBool)
|
||||
}
|
||||
if bytes.Compare(expectedBytes, actualBytes) != 0 {
|
||||
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
|
||||
}
|
||||
if expectedString != actualString {
|
||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test dangerous cases
|
||||
|
||||
{
|
||||
expected := "foo';drop table users;"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
"test",
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when client_encoding not UTF8, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set standard_conforming_strings to off")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
`\'; drop table users; --`,
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryCloseBefore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user