diff --git a/replication.go b/replication.go index 52e6b915..14895ecf 100644 --- a/replication.go +++ b/replication.go @@ -163,6 +163,7 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { config.RuntimeParams = make(map[string]string) } config.RuntimeParams["replication"] = "database" + config.PreferSimpleProtocol = true c, err := Connect(config) if err != nil { diff --git a/replication_test.go b/replication_test.go index d06d73cd..54ac2b4a 100644 --- a/replication_test.go +++ b/replication_test.go @@ -343,3 +343,28 @@ func TestStandbyStatusParsing(t *testing.T) { t.Errorf("Unexpected write position %d", status.WalWritePosition) } } + +func TestSimpleProtocolEnforcement(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + query := "select count(*) from pg_replication_slots" + + // Check that the simple query protocol is used by default + rows, err := replicationConn.Query(query) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + rows.Close() + + // Check that using the extended query protocol will fail + rows, err = replicationConn.QueryEx(context.Background(), query, &pgx.QueryExOptions{SimpleProtocol: false}) + if err == nil { + t.Fatal("Query expected to fail.") + } + rows.Close() +}