2
0

Merge pull request #482 from fastest963/retry

Added LastStmtSent and use it to retry on errors if statement was not sent
This commit is contained in:
Jack Christensen
2018-12-01 10:38:10 -06:00
committed by GitHub
6 changed files with 244 additions and 0 deletions
+17
View File
@@ -135,6 +135,7 @@ type Conn struct {
pendingReadyForQueryCount int // number of ReadyForQuery messages expected pendingReadyForQueryCount int // number of ReadyForQuery messages expected
cancelQueryCompleted chan struct{} cancelQueryCompleted chan struct{}
lastStmtSent bool
// context support // context support
ctxInProgress bool ctxInProgress bool
@@ -1731,6 +1732,7 @@ func (c *Conn) Ping(ctx context.Context) error {
} }
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
c.lastStmtSent = false
err := c.waitForPreviousCancelQuery(ctx) err := c.waitForPreviousCancelQuery(ctx)
if err != nil { if err != nil {
return "", err return "", err
@@ -1770,6 +1772,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
}() }()
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
c.lastStmtSent = true
err = c.sanitizeAndSendSimpleQuery(sql, arguments...) err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
if err != nil { if err != nil {
return "", err return "", err
@@ -1786,6 +1789,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
buf = appendSync(buf) buf = appendSync(buf)
c.lastStmtSent = true
n, err := c.conn.Write(buf) n, err := c.conn.Write(buf)
if err != nil && fatalWriteErr(n, err) { if err != nil && fatalWriteErr(n, err) {
c.die(err) c.die(err)
@@ -1803,11 +1807,13 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
} }
} }
c.lastStmtSent = true
err = c.sendPreparedQuery(ps, arguments...) err = c.sendPreparedQuery(ps, arguments...)
if err != nil { if err != nil {
return "", err return "", err
} }
} else { } else {
c.lastStmtSent = true
if err = c.sendQuery(sql, arguments...); err != nil { if err = c.sendQuery(sql, arguments...); err != nil {
return return
} }
@@ -1978,3 +1984,14 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
return nameOIDs, err return nameOIDs, err
} }
// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to
// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets
// the value to false initially until the statement has been sent. This does
// NOT mean that the statement was successful or even received, it just means
// that a write was attempted and therefore it could have been executed. Calls
// to prepare a statement are ignored, only when the prepared statement is
// attempted to be executed will this return true.
func (c *Conn) LastStmtSent() bool {
return c.lastStmtSent
}
+92
View File
@@ -1131,12 +1131,32 @@ func TestExecFailure(t *testing.T) {
if _, err := conn.Exec("selct;"); err == nil { if _, err := conn.Exec("selct;"); err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
rows, _ := conn.Query("select 1") rows, _ := conn.Query("select 1")
rows.Close() rows.Close()
if rows.Err() != nil { if rows.Err() != nil {
t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err())
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
func TestExecFailureWithArguments(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
if _, err := conn.Exec("selct $1;", 1); err == nil {
t.Fatal("Expected SQL syntax error")
}
if conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return false")
}
} }
func TestExecExContextWithoutCancelation(t *testing.T) { func TestExecExContextWithoutCancelation(t *testing.T) {
@@ -1155,6 +1175,9 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
if commandTag != "CREATE TABLE" { if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
func TestExecExContextFailureWithoutCancelation(t *testing.T) { func TestExecExContextFailureWithoutCancelation(t *testing.T) {
@@ -1169,12 +1192,35 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
rows, _ := conn.Query("select 1") rows, _ := conn.Query("select 1")
rows.Close() rows.Close()
if rows.Err() != nil { if rows.Err() != nil {
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil {
t.Fatal("Expected SQL syntax error")
}
if conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return false")
}
} }
func TestExecExContextCancelationCancelsQuery(t *testing.T) { func TestExecExContextCancelationCancelsQuery(t *testing.T) {
@@ -1193,10 +1239,27 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) {
if err != context.Canceled { if err != context.Canceled {
t.Fatalf("Expected context.Canceled err, got %v", err) t.Fatalf("Expected context.Canceled err, got %v", err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestExecFailureCloseBefore(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
closeConn(t, conn)
if _, err := conn.Exec("select 1"); err == nil {
t.Fatal("Expected network error")
}
if conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return false")
}
}
func TestExecExExtendedProtocol(t *testing.T) { func TestExecExExtendedProtocol(t *testing.T) {
t.Parallel() t.Parallel()
@@ -1246,6 +1309,9 @@ func TestExecExSimpleProtocol(t *testing.T) {
if commandTag != "CREATE TABLE" { if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
commandTag, err = conn.ExecEx( commandTag, err = conn.ExecEx(
ctx, ctx,
@@ -1259,6 +1325,9 @@ func TestExecExSimpleProtocol(t *testing.T) {
if commandTag != "INSERT 0 1" { if commandTag != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
@@ -1281,6 +1350,9 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
if commandTag != "INSERT 0 1" { if commandTag != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag) t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
@@ -1300,6 +1372,9 @@ func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("expected error but got none") t.Fatal("expected error but got none")
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
@@ -1328,6 +1403,23 @@ func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("expected error but got none") t.Fatal("expected error but got none")
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
func TestExecExFailureCloseBefore(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
closeConn(t, conn)
if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil {
t.Fatal("Expected network error")
}
if conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return false")
}
} }
func TestPrepare(t *testing.T) { func TestPrepare(t *testing.T) {
+4
View File
@@ -368,6 +368,7 @@ type QueryExOptions struct {
} }
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
c.lastStmtSent = false
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
rows = c.getRows(sql, args) rows = c.getRows(sql, args)
@@ -395,6 +396,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
} }
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
c.lastStmtSent = true
err = c.sanitizeAndSendSimpleQuery(sql, args...) err = c.sanitizeAndSendSimpleQuery(sql, args...)
if err != nil { if err != nil {
rows.fatal(err) rows.fatal(err)
@@ -414,6 +416,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
buf = appendSync(buf) buf = appendSync(buf)
c.lastStmtSent = true
n, err := c.conn.Write(buf) n, err := c.conn.Write(buf)
if err != nil && fatalWriteErr(n, err) { if err != nil && fatalWriteErr(n, err) {
rows.fatal(err) rows.fatal(err)
@@ -460,6 +463,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
rows.sql = ps.SQL rows.sql = ps.SQL
rows.fields = ps.FieldDescriptions rows.fields = ps.FieldDescriptions
c.lastStmtSent = true
err = c.sendPreparedQuery(ps, args...) err = c.sendPreparedQuery(ps, args...)
if err != nil { if err != nil {
rows.fatal(err) rows.fatal(err)
+56
View File
@@ -283,6 +283,9 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: %v", err) t.Fatalf("conn.Query failed: %v", err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
rows.Close() rows.Close()
ensureConnValid(t, conn) ensureConnValid(t, conn)
@@ -431,6 +434,9 @@ func TestQueryEncodeError(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("conn.Query failure: %v", err) t.Errorf("conn.Query failure: %v", err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
defer rows.Close() defer rows.Close()
rows.Next() rows.Next()
@@ -1186,6 +1192,9 @@ func TestQueryExContextSuccess(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
var result, rowCount int var result, rowCount int
for rows.Next() { for rows.Next() {
@@ -1263,6 +1272,9 @@ func TestQueryExContextCancelationCancelsQuery(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
for rows.Next() { for rows.Next() {
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
@@ -1292,6 +1304,9 @@ func TestQueryRowExContextSuccess(t *testing.T) {
if result != 42 { if result != 42 {
t.Fatalf("Expected result 42, got %d", result) t.Fatalf("Expected result 42, got %d", result)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
@@ -1331,6 +1346,9 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
if err != context.Canceled { if err != context.Canceled {
t.Fatalf("Expected context.Canceled error, got %v", err) t.Fatalf("Expected context.Canceled error, got %v", err)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
@@ -1384,6 +1402,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expected != actual { if expected != actual {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
{ {
@@ -1401,6 +1422,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expected != actual { if expected != actual {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
{ {
@@ -1418,6 +1442,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expected != actual { if expected != actual {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
{ {
@@ -1435,6 +1462,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if bytes.Compare(actual, expected) != 0 { if bytes.Compare(actual, expected) != 0 {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
{ {
@@ -1452,6 +1482,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expected != actual { if expected != actual {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
// Test high-level type // Test high-level type
@@ -1471,6 +1504,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expected != actual { if expected != actual {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
// Test multiple args in single query // Test multiple args in single query
@@ -1510,6 +1546,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expectedString != actualString { if expectedString != actualString {
t.Errorf("expected %v got %v", expectedString, actualString) t.Errorf("expected %v got %v", expectedString, actualString)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
// Test dangerous cases // Test dangerous cases
@@ -1529,6 +1568,9 @@ func TestConnSimpleProtocol(t *testing.T) {
if expected != actual { if expected != actual {
t.Errorf("expected %v got %v", expected, actual) t.Errorf("expected %v got %v", expected, actual)
} }
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)
@@ -1577,3 +1619,17 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestQueryExCloseBefore(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
closeConn(t, conn)
if _, err := conn.QueryEx(context.Background(), "select 1", nil); err == nil {
t.Fatal("Expected network error")
}
if conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return false")
}
}
+30
View File
@@ -75,6 +75,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"net"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
@@ -292,6 +293,12 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
args := valueToInterface(argsV) args := valueToInterface(argsV)
commandTag, err := c.conn.Exec(query, args...) commandTag, err := c.conn.Exec(query, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(commandTag.RowsAffected()), err return driver.RowsAffected(commandTag.RowsAffected()), err
} }
@@ -303,6 +310,12 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
args := namedValueToInterface(argsV) args := namedValueToInterface(argsV)
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(commandTag.RowsAffected()), err return driver.RowsAffected(commandTag.RowsAffected()), err
} }
@@ -323,6 +336,12 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
rows, err := c.conn.Query(query, valueToInterface(argsV)...) rows, err := c.conn.Query(query, valueToInterface(argsV)...)
if err != nil { if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return nil, err return nil, err
} }
@@ -339,6 +358,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
if !c.connConfig.PreferSimpleProtocol { if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.PrepareEx(ctx, "", query, nil) ps, err := c.conn.PrepareEx(ctx, "", query, nil)
if err != nil { if err != nil {
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
return nil, err return nil, err
} }
@@ -348,6 +372,12 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...) rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
if err != nil { if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return nil, err return nil, err
} }
+45
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
@@ -989,6 +990,28 @@ func TestConnExecContextCancel(t *testing.T) {
} }
} }
func TestConnExecContextFailureRetry(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
// we get a connection, immediately close it, and then get it back
{
conn, err := stdlib.AcquireConn(db)
if err != nil {
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
}
conn.Close()
stdlib.ReleaseConn(db, conn)
}
conn, err := db.Conn(context.Background())
if err != nil {
t.Fatalf("db.Conn unexpectedly failed: %v", err)
}
if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn {
t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err)
}
}
func TestConnQueryContextSuccess(t *testing.T) { func TestConnQueryContextSuccess(t *testing.T) {
db := openDB(t) db := openDB(t)
defer closeDB(t, db) defer closeDB(t, db)
@@ -1083,6 +1106,28 @@ func TestConnQueryContextCancel(t *testing.T) {
} }
} }
func TestConnQueryContextFailureRetry(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
// we get a connection, immediately close it, and then get it back
{
conn, err := stdlib.AcquireConn(db)
if err != nil {
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
}
conn.Close()
stdlib.ReleaseConn(db, conn)
}
conn, err := db.Conn(context.Background())
if err != nil {
t.Fatalf("db.Conn unexpectedly failed: %v", err)
}
if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn {
t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err)
}
}
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
db := openDB(t) db := openDB(t)
defer closeDB(t, db) defer closeDB(t, db)