unlock connection when context is pre-canceled
This commit is contained in:
@@ -504,6 +504,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
pgConn.unlock()
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -626,6 +627,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
pgConn.unlock()
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -668,6 +670,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = ctx.Err()
|
multiResult.err = ctx.Err()
|
||||||
|
pgConn.unlock()
|
||||||
return multiResult
|
return multiResult
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -828,6 +831,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
pgConn.unlock()
|
||||||
return "", ctx.Err()
|
return "", ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -1278,6 +1282,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = ctx.Err()
|
multiResult.err = ctx.Err()
|
||||||
|
pgConn.unlock()
|
||||||
return multiResult
|
return multiResult
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|||||||
+166
@@ -255,6 +255,23 @@ func TestConnPrepareSyntaxError(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnPrepareContextPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
||||||
|
require.Nil(t, psd)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExec(t *testing.T) {
|
func TestConnExec(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -360,6 +377,22 @@ func TestConnExecContextCanceled(t *testing.T) {
|
|||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecContextPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecParams(t *testing.T) {
|
func TestConnExecParams(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -449,6 +482,22 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
|||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecParamsPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
||||||
|
require.Error(t, result.Err)
|
||||||
|
require.Equal(t, context.Canceled, result.Err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecPrepared(t *testing.T) {
|
func TestConnExecPrepared(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -558,6 +607,25 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
|||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecPreparedPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
||||||
|
require.Error(t, result.Err)
|
||||||
|
require.Equal(t, context.Canceled, result.Err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecBatch(t *testing.T) {
|
func TestConnExecBatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -590,6 +658,31 @@ func TestConnExecBatch(t *testing.T) {
|
|||||||
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecBatchPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
||||||
|
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
|
||||||
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnLocking(t *testing.T) {
|
func TestConnLocking(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -726,6 +819,24 @@ func TestConnWaitForNotification(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnWaitForNotificationPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
err = pgConn.WaitForNotification(ctx)
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnWaitForNotificationTimeout(t *testing.T) {
|
func TestConnWaitForNotificationTimeout(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -855,6 +966,25 @@ func TestConnCopyToCanceled(t *testing.T) {
|
|||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyToPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
outputWriter := &bytes.Buffer{}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
assert.Equal(t, pgconn.CommandTag(""), res)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnCopyFrom(t *testing.T) {
|
func TestConnCopyFrom(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -926,6 +1056,42 @@ func TestConnCopyFromCanceled(t *testing.T) {
|
|||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromPrecanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, w := io.Pipe()
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 1000000; i++ {
|
||||||
|
a := strconv.Itoa(i)
|
||||||
|
b := "foo " + a + " bar"
|
||||||
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(time.Microsecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
assert.Equal(t, pgconn.CommandTag(""), ct)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnCopyFromGzipReader(t *testing.T) {
|
func TestConnCopyFromGzipReader(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user