Added LastStmtSent and use it to retry on errors if statement was not sent
Previously, a failed connection could be put back in a pool and when the next query was attempted it would fail immediately trying to prepare the query or reset the deadline. It wasn't clear if the Query or Exec call could safely be retried since there was no way to know where it failed. You can now call LastQuerySent and if it returns false then you're guaranteed that the last call to Query(Ex)/Exec(Ex) didn't get far enough to attempt to send the query. The call can be retried with a new connection. This is used in the stdlib to return a ErrBadConn if a network error occurred and the statement was not attempted. Fixes #427
This commit is contained in:
@@ -283,6 +283,9 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
@@ -431,6 +434,9 @@ func TestQueryEncodeError(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("conn.Query failure: %v", err)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
rows.Next()
|
||||
@@ -1186,6 +1192,9 @@ func TestQueryExContextSuccess(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
|
||||
var result, rowCount int
|
||||
for rows.Next() {
|
||||
@@ -1263,6 +1272,9 @@ func TestQueryExContextCancelationCancelsQuery(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
|
||||
@@ -1292,6 +1304,9 @@ func TestQueryRowExContextSuccess(t *testing.T) {
|
||||
if result != 42 {
|
||||
t.Fatalf("Expected result 42, got %d", result)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
@@ -1331,6 +1346,9 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
@@ -1384,6 +1402,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1401,6 +1422,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1418,6 +1442,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1435,6 +1462,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if bytes.Compare(actual, expected) != 0 {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1452,6 +1482,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test high-level type
|
||||
@@ -1471,6 +1504,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple args in single query
|
||||
@@ -1510,6 +1546,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expectedString != actualString {
|
||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test dangerous cases
|
||||
@@ -1529,6 +1568,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
@@ -1577,3 +1619,17 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryExCloseBefore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
closeConn(t, conn)
|
||||
|
||||
if _, err := conn.QueryEx(context.Background(), "select 1", nil); err == nil {
|
||||
t.Fatal("Expected network error")
|
||||
}
|
||||
if conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return false")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user