unlock connection when context is pre-canceled
This commit is contained in:
@@ -504,6 +504,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pgConn.unlock()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
@@ -626,6 +627,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pgConn.unlock()
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
@@ -668,6 +670,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
}
|
||||
@@ -828,6 +831,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pgConn.unlock()
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
@@ -1278,6 +1282,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
}
|
||||
|
||||
+166
@@ -255,6 +255,23 @@ func TestConnPrepareSyntaxError(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnPrepareContextPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
||||
require.Nil(t, psd)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -360,6 +377,22 @@ func TestConnExecContextCanceled(t *testing.T) {
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
}
|
||||
|
||||
func TestConnExecContextPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -449,6 +482,22 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
}
|
||||
|
||||
func TestConnExecParamsPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
require.Equal(t, context.Canceled, result.Err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecPrepared(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -558,6 +607,25 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
}
|
||||
|
||||
func TestConnExecPreparedPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
require.Equal(t, context.Canceled, result.Err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -590,6 +658,31 @@ func TestConnExecBatch(t *testing.T) {
|
||||
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
||||
}
|
||||
|
||||
func TestConnExecBatchPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
||||
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
|
||||
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnLocking(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -726,6 +819,24 @@ func TestConnWaitForNotification(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnWaitForNotificationPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
err = pgConn.WaitForNotification(ctx)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnWaitForNotificationTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -855,6 +966,25 @@ func TestConnCopyToCanceled(t *testing.T) {
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
}
|
||||
|
||||
func TestConnCopyToPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
outputWriter := &bytes.Buffer{}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.Equal(t, pgconn.CommandTag(""), res)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnCopyFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -926,6 +1056,42 @@ func TestConnCopyFromCanceled(t *testing.T) {
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
}
|
||||
|
||||
func TestConnCopyFromPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||
a int4,
|
||||
b varchar
|
||||
)`).ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, w := io.Pipe()
|
||||
go func() {
|
||||
for i := 0; i < 1000000; i++ {
|
||||
a := strconv.Itoa(i)
|
||||
b := "foo " + a + " bar"
|
||||
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Microsecond)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.Equal(t, pgconn.CommandTag(""), ct)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromGzipReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user