diff --git a/pgconn.go b/pgconn.go index 223b8e3d..db741d47 100644 --- a/pgconn.go +++ b/pgconn.go @@ -504,6 +504,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ select { case <-ctx.Done(): + pgConn.unlock() return nil, ctx.Err() default: } @@ -626,6 +627,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { select { case <-ctx.Done(): + pgConn.unlock() return ctx.Err() default: } @@ -668,6 +670,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() + pgConn.unlock() return multiResult default: } @@ -828,6 +831,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): + pgConn.unlock() return "", ctx.Err() default: } @@ -1278,6 +1282,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() + pgConn.unlock() return multiResult default: } diff --git a/pgconn_test.go b/pgconn_test.go index b2514e48..66a4337b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -255,6 +255,23 @@ func TestConnPrepareSyntaxError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnPrepareContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.Nil(t, psd) + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel() @@ -360,6 +377,22 @@ func TestConnExecContextCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnExecParams(t *testing.T) { t.Parallel() @@ -449,6 +482,22 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecParamsPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, context.Canceled, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() @@ -558,6 +607,25 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecPreparedPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, context.Canceled, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel() @@ -590,6 +658,31 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnLocking(t *testing.T) { t.Parallel() @@ -726,6 +819,24 @@ func TestConnWaitForNotification(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnWaitForNotificationPrecanceled(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = pgConn.WaitForNotification(ctx) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() @@ -855,6 +966,25 @@ func TestConnCopyToCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnCopyToPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") + require.Error(t, err) + require.Equal(t, context.Canceled, err) + assert.Equal(t, pgconn.CommandTag(""), res) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFrom(t *testing.T) { t.Parallel() @@ -926,6 +1056,42 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnCopyFromPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + require.Equal(t, context.Canceled, err) + assert.Equal(t, pgconn.CommandTag(""), ct) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel()