Added LastStmtSent and use it to retry on errors if statement was not sent
Previously, a failed connection could be put back in a pool and when the next query was attempted it would fail immediately trying to prepare the query or reset the deadline. It wasn't clear if the Query or Exec call could safely be retried since there was no way to know where it failed. You can now call LastQuerySent and if it returns false then you're guaranteed that the last call to Query(Ex)/Exec(Ex) didn't get far enough to attempt to send the query. The call can be retried with a new connection. This is used in the stdlib to return a ErrBadConn if a network error occurred and the statement was not attempted. Fixes #427
This commit is contained in:
@@ -135,6 +135,7 @@ type Conn struct {
|
|||||||
|
|
||||||
pendingReadyForQueryCount int // number of ReadyForQuery messages expected
|
pendingReadyForQueryCount int // number of ReadyForQuery messages expected
|
||||||
cancelQueryCompleted chan struct{}
|
cancelQueryCompleted chan struct{}
|
||||||
|
lastStmtSent bool
|
||||||
|
|
||||||
// context support
|
// context support
|
||||||
ctxInProgress bool
|
ctxInProgress bool
|
||||||
@@ -1731,6 +1732,7 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
|
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
|
||||||
|
c.lastStmtSent = false
|
||||||
err := c.waitForPreviousCancelQuery(ctx)
|
err := c.waitForPreviousCancelQuery(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1770,6 +1772,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||||
|
c.lastStmtSent = true
|
||||||
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1786,6 +1789,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||||||
|
|
||||||
buf = appendSync(buf)
|
buf = appendSync(buf)
|
||||||
|
|
||||||
|
c.lastStmtSent = true
|
||||||
n, err := c.conn.Write(buf)
|
n, err := c.conn.Write(buf)
|
||||||
if err != nil && fatalWriteErr(n, err) {
|
if err != nil && fatalWriteErr(n, err) {
|
||||||
c.die(err)
|
c.die(err)
|
||||||
@@ -1803,11 +1807,13 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.lastStmtSent = true
|
||||||
err = c.sendPreparedQuery(ps, arguments...)
|
err = c.sendPreparedQuery(ps, arguments...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
c.lastStmtSent = true
|
||||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1978,3 +1984,14 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
|
|||||||
|
|
||||||
return nameOIDs, err
|
return nameOIDs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to
|
||||||
|
// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets
|
||||||
|
// the value to false initially until the statement has been sent. This does
|
||||||
|
// NOT mean that the statement was successful or even received, it just means
|
||||||
|
// that a write was attempted and therefore it could have been executed. Calls
|
||||||
|
// to prepare a statement are ignored, only when the prepared statement is
|
||||||
|
// attempted to be executed will this return true.
|
||||||
|
func (c *Conn) LastStmtSent() bool {
|
||||||
|
return c.lastStmtSent
|
||||||
|
}
|
||||||
|
|||||||
@@ -1131,12 +1131,32 @@ func TestExecFailure(t *testing.T) {
|
|||||||
if _, err := conn.Exec("selct;"); err == nil {
|
if _, err := conn.Exec("selct;"); err == nil {
|
||||||
t.Fatal("Expected SQL syntax error")
|
t.Fatal("Expected SQL syntax error")
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
rows, _ := conn.Query("select 1")
|
rows, _ := conn.Query("select 1")
|
||||||
rows.Close()
|
rows.Close()
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err())
|
t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err())
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecFailureWithArguments(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
if _, err := conn.Exec("selct $1;", 1); err == nil {
|
||||||
|
t.Fatal("Expected SQL syntax error")
|
||||||
|
}
|
||||||
|
if conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return false")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextWithoutCancelation(t *testing.T) {
|
func TestExecExContextWithoutCancelation(t *testing.T) {
|
||||||
@@ -1155,6 +1175,9 @@ func TestExecExContextWithoutCancelation(t *testing.T) {
|
|||||||
if commandTag != "CREATE TABLE" {
|
if commandTag != "CREATE TABLE" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
||||||
@@ -1169,12 +1192,35 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
|||||||
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
|
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
|
||||||
t.Fatal("Expected SQL syntax error")
|
t.Fatal("Expected SQL syntax error")
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
rows, _ := conn.Query("select 1")
|
rows, _ := conn.Query("select 1")
|
||||||
rows.Close()
|
rows.Close()
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
|
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil {
|
||||||
|
t.Fatal("Expected SQL syntax error")
|
||||||
|
}
|
||||||
|
if conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return false")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
||||||
@@ -1193,10 +1239,27 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
|||||||
if err != context.Canceled {
|
if err != context.Canceled {
|
||||||
t.Fatalf("Expected context.Canceled err, got %v", err)
|
t.Fatalf("Expected context.Canceled err, got %v", err)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecFailureCloseBefore(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
closeConn(t, conn)
|
||||||
|
|
||||||
|
if _, err := conn.Exec("select 1"); err == nil {
|
||||||
|
t.Fatal("Expected network error")
|
||||||
|
}
|
||||||
|
if conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecExExtendedProtocol(t *testing.T) {
|
func TestExecExExtendedProtocol(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -1246,6 +1309,9 @@ func TestExecExSimpleProtocol(t *testing.T) {
|
|||||||
if commandTag != "CREATE TABLE" {
|
if commandTag != "CREATE TABLE" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
commandTag, err = conn.ExecEx(
|
commandTag, err = conn.ExecEx(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -1259,6 +1325,9 @@ func TestExecExSimpleProtocol(t *testing.T) {
|
|||||||
if commandTag != "INSERT 0 1" {
|
if commandTag != "INSERT 0 1" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
||||||
@@ -1281,6 +1350,9 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
|||||||
if commandTag != "INSERT 0 1" {
|
if commandTag != "INSERT 0 1" {
|
||||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
||||||
@@ -1300,6 +1372,9 @@ func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error but got none")
|
t.Fatal("expected error but got none")
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
|
func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
|
||||||
@@ -1328,6 +1403,23 @@ func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error but got none")
|
t.Fatal("expected error but got none")
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecExFailureCloseBefore(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
closeConn(t, conn)
|
||||||
|
|
||||||
|
if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil {
|
||||||
|
t.Fatal("Expected network error")
|
||||||
|
}
|
||||||
|
if conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return false")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepare(t *testing.T) {
|
func TestPrepare(t *testing.T) {
|
||||||
|
|||||||
@@ -368,6 +368,7 @@ type QueryExOptions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
||||||
|
c.lastStmtSent = false
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
rows = c.getRows(sql, args)
|
rows = c.getRows(sql, args)
|
||||||
|
|
||||||
@@ -395,6 +396,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||||
|
c.lastStmtSent = true
|
||||||
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
@@ -414,6 +416,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||||||
|
|
||||||
buf = appendSync(buf)
|
buf = appendSync(buf)
|
||||||
|
|
||||||
|
c.lastStmtSent = true
|
||||||
n, err := c.conn.Write(buf)
|
n, err := c.conn.Write(buf)
|
||||||
if err != nil && fatalWriteErr(n, err) {
|
if err != nil && fatalWriteErr(n, err) {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
@@ -460,6 +463,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||||||
rows.sql = ps.SQL
|
rows.sql = ps.SQL
|
||||||
rows.fields = ps.FieldDescriptions
|
rows.fields = ps.FieldDescriptions
|
||||||
|
|
||||||
|
c.lastStmtSent = true
|
||||||
err = c.sendPreparedQuery(ps, args...)
|
err = c.sendPreparedQuery(ps, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
|
|||||||
@@ -283,6 +283,9 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("conn.Query failed: %v", err)
|
t.Fatalf("conn.Query failed: %v", err)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
@@ -431,6 +434,9 @@ func TestQueryEncodeError(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("conn.Query failure: %v", err)
|
t.Errorf("conn.Query failure: %v", err)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
rows.Next()
|
rows.Next()
|
||||||
@@ -1186,6 +1192,9 @@ func TestQueryExContextSuccess(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
var result, rowCount int
|
var result, rowCount int
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -1263,6 +1272,9 @@ func TestQueryExContextCancelationCancelsQuery(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
|
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
|
||||||
@@ -1292,6 +1304,9 @@ func TestQueryRowExContextSuccess(t *testing.T) {
|
|||||||
if result != 42 {
|
if result != 42 {
|
||||||
t.Fatalf("Expected result 42, got %d", result)
|
t.Fatalf("Expected result 42, got %d", result)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
@@ -1331,6 +1346,9 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
|
|||||||
if err != context.Canceled {
|
if err != context.Canceled {
|
||||||
t.Fatalf("Expected context.Canceled error, got %v", err)
|
t.Fatalf("Expected context.Canceled error, got %v", err)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
@@ -1384,6 +1402,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expected != actual {
|
if expected != actual {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -1401,6 +1422,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expected != actual {
|
if expected != actual {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -1418,6 +1442,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expected != actual {
|
if expected != actual {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -1435,6 +1462,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if bytes.Compare(actual, expected) != 0 {
|
if bytes.Compare(actual, expected) != 0 {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -1452,6 +1482,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expected != actual {
|
if expected != actual {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test high-level type
|
// Test high-level type
|
||||||
@@ -1471,6 +1504,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expected != actual {
|
if expected != actual {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test multiple args in single query
|
// Test multiple args in single query
|
||||||
@@ -1510,6 +1546,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expectedString != actualString {
|
if expectedString != actualString {
|
||||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test dangerous cases
|
// Test dangerous cases
|
||||||
@@ -1529,6 +1568,9 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||||||
if expected != actual {
|
if expected != actual {
|
||||||
t.Errorf("expected %v got %v", expected, actual)
|
t.Errorf("expected %v got %v", expected, actual)
|
||||||
}
|
}
|
||||||
|
if !conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return true")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
@@ -1577,3 +1619,17 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
|||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryExCloseBefore(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
closeConn(t, conn)
|
||||||
|
|
||||||
|
if _, err := conn.QueryEx(context.Background(), "select 1", nil); err == nil {
|
||||||
|
t.Fatal("Expected network error")
|
||||||
|
}
|
||||||
|
if conn.LastStmtSent() {
|
||||||
|
t.Error("Expected LastStmtSent to return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -292,6 +293,12 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
|||||||
|
|
||||||
args := valueToInterface(argsV)
|
args := valueToInterface(argsV)
|
||||||
commandTag, err := c.conn.Exec(query, args...)
|
commandTag, err := c.conn.Exec(query, args...)
|
||||||
|
// if we got a network error before we had a chance to send the query, retry
|
||||||
|
if err != nil && !c.conn.LastStmtSent() {
|
||||||
|
if _, is := err.(net.Error); is {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
}
|
||||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,6 +310,12 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
|||||||
args := namedValueToInterface(argsV)
|
args := namedValueToInterface(argsV)
|
||||||
|
|
||||||
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
|
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
|
||||||
|
// if we got a network error before we had a chance to send the query, retry
|
||||||
|
if err != nil && !c.conn.LastStmtSent() {
|
||||||
|
if _, is := err.(net.Error); is {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
}
|
||||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +336,12 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
|||||||
|
|
||||||
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
|
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// if we got a network error before we had a chance to send the query, retry
|
||||||
|
if !c.conn.LastStmtSent() {
|
||||||
|
if _, is := err.(net.Error); is {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,6 +358,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
|||||||
if !c.connConfig.PreferSimpleProtocol {
|
if !c.connConfig.PreferSimpleProtocol {
|
||||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// since PrepareEx failed, we didn't actually get to send the values, so
|
||||||
|
// we can safely retry
|
||||||
|
if _, is := err.(net.Error); is {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,6 +372,12 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
|||||||
|
|
||||||
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
|
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// if we got a network error before we had a chance to send the query, retry
|
||||||
|
if !c.conn.LastStmtSent() {
|
||||||
|
if _, is := err.(net.Error); is {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
@@ -989,6 +990,28 @@ func TestConnExecContextCancel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecContextFailureRetry(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
// we get a connection, immediately close it, and then get it back
|
||||||
|
{
|
||||||
|
conn, err := stdlib.AcquireConn(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
stdlib.ReleaseConn(db, conn)
|
||||||
|
}
|
||||||
|
conn, err := db.Conn(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Conn unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn {
|
||||||
|
t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnQueryContextSuccess(t *testing.T) {
|
func TestConnQueryContextSuccess(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
@@ -1083,6 +1106,28 @@ func TestConnQueryContextCancel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnQueryContextFailureRetry(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
// we get a connection, immediately close it, and then get it back
|
||||||
|
{
|
||||||
|
conn, err := stdlib.AcquireConn(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
stdlib.ReleaseConn(db, conn)
|
||||||
|
}
|
||||||
|
conn, err := db.Conn(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Conn unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn {
|
||||||
|
t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
|
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|||||||
Reference in New Issue
Block a user