2
0

unlock connection when context is pre-canceled

This commit is contained in:
Jack Christensen
2019-04-05 12:06:59 -05:00
parent 408837dcb1
commit 7ad3625edd
2 changed files with 171 additions and 0 deletions
+5
View File
@@ -504,6 +504,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
select {
case <-ctx.Done():
pgConn.unlock()
return nil, ctx.Err()
default:
}
@@ -626,6 +627,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
select {
case <-ctx.Done():
pgConn.unlock()
return ctx.Err()
default:
}
@@ -668,6 +670,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = ctx.Err()
pgConn.unlock()
return multiResult
default:
}
@@ -828,6 +831,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
select {
case <-ctx.Done():
pgConn.unlock()
return "", ctx.Err()
default:
}
@@ -1278,6 +1282,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
case <-ctx.Done():
multiResult.closed = true
multiResult.err = ctx.Err()
pgConn.unlock()
return multiResult
default:
}
+166
View File
@@ -255,6 +255,23 @@ func TestConnPrepareSyntaxError(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnPrepareContextPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithCancel(context.Background())
cancel()
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
require.Nil(t, psd)
require.Error(t, err)
require.Equal(t, context.Canceled, err)
ensureConnValid(t, pgConn)
}
func TestConnExec(t *testing.T) {
t.Parallel()
@@ -360,6 +377,22 @@ func TestConnExecContextCanceled(t *testing.T) {
assert.False(t, pgConn.IsAlive())
}
func TestConnExecContextPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
require.Error(t, err)
require.Equal(t, context.Canceled, err)
ensureConnValid(t, pgConn)
}
func TestConnExecParams(t *testing.T) {
t.Parallel()
@@ -449,6 +482,22 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.False(t, pgConn.IsAlive())
}
func TestConnExecParamsPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
require.Error(t, result.Err)
require.Equal(t, context.Canceled, result.Err)
ensureConnValid(t, pgConn)
}
func TestConnExecPrepared(t *testing.T) {
t.Parallel()
@@ -558,6 +607,25 @@ func TestConnExecPreparedCanceled(t *testing.T) {
assert.False(t, pgConn.IsAlive())
}
func TestConnExecPreparedPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
require.Error(t, result.Err)
require.Equal(t, context.Canceled, result.Err)
ensureConnValid(t, pgConn)
}
func TestConnExecBatch(t *testing.T) {
t.Parallel()
@@ -590,6 +658,31 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
}
func TestConnExecBatchPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
require.NoError(t, err)
batch := &pgconn.Batch{}
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
require.Error(t, err)
require.Equal(t, context.Canceled, err)
ensureConnValid(t, pgConn)
}
func TestConnLocking(t *testing.T) {
t.Parallel()
@@ -726,6 +819,24 @@ func TestConnWaitForNotification(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnWaitForNotificationPrecanceled(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithCancel(context.Background())
cancel()
err = pgConn.WaitForNotification(ctx)
require.Equal(t, context.Canceled, err)
ensureConnValid(t, pgConn)
}
func TestConnWaitForNotificationTimeout(t *testing.T) {
t.Parallel()
@@ -855,6 +966,25 @@ func TestConnCopyToCanceled(t *testing.T) {
assert.False(t, pgConn.IsAlive())
}
func TestConnCopyToPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
outputWriter := &bytes.Buffer{}
ctx, cancel := context.WithCancel(context.Background())
cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.Equal(t, pgconn.CommandTag(""), res)
ensureConnValid(t, pgConn)
}
func TestConnCopyFrom(t *testing.T) {
t.Parallel()
@@ -926,6 +1056,42 @@ func TestConnCopyFromCanceled(t *testing.T) {
assert.False(t, pgConn.IsAlive())
}
func TestConnCopyFromPrecanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int4,
b varchar
)`).ReadAll()
require.NoError(t, err)
r, w := io.Pipe()
go func() {
for i := 0; i < 1000000; i++ {
a := strconv.Itoa(i)
b := "foo " + a + " bar"
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
if err != nil {
return
}
time.Sleep(time.Microsecond)
}
}()
ctx, cancel := context.WithCancel(context.Background())
cancel()
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.Equal(t, pgconn.CommandTag(""), ct)
ensureConnValid(t, pgConn)
}
func TestConnCopyFromGzipReader(t *testing.T) {
t.Parallel()