From b52e83255c174ce7deea91afb065862df019d420 Mon Sep 17 00:00:00 2001 From: Alejandro Durante Date: Sat, 12 Mar 2022 17:27:58 -0300 Subject: [PATCH] Prevent race condition in stopped flag --- pond.go | 15 ++++++--------- pond_test.go | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/pond.go b/pond.go index 986bde5..865f3e1 100644 --- a/pond.go +++ b/pond.go @@ -93,8 +93,7 @@ type WorkerPool struct { workersWaitGroup sync.WaitGroup tasksWaitGroup sync.WaitGroup mutex sync.Mutex - stopped bool - stoppedOnce sync.Once + stopped int32 } // New creates a worker pool with that can scale up to the given maximum number of workers (maxWorkers). @@ -213,7 +212,7 @@ func (p *WorkerPool) CompletedTasks() uint64 { // Stopped returns true if the pool has been stopped and is no longer accepting tasks, and false otherwise. func (p *WorkerPool) Stopped() bool { - return p.stopped + return atomic.LoadInt32(&p.stopped) == 1 } // Submit sends a task to this worker pool for execution. If the queue is full, @@ -350,10 +349,8 @@ func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) { } func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) { - // Mark pool as stopped (only once, in case multiple concurrent calls to StopAndWait are made) - p.stoppedOnce.Do(func() { - p.stopped = true - }) + // Mark pool as stopped + atomic.StoreInt32(&p.stopped, 1) if waitForQueuedTasksToComplete { // Wait for all queued tasks to complete @@ -464,7 +461,7 @@ func (p *WorkerPool) incrementWorkerCount() (incremented bool) { // Increment worker count atomic.AddInt32(&p.workerCount, 1) - // Increment waiting group semaphore + // Increment wait group p.workersWaitGroup.Add(1) } @@ -479,7 +476,7 @@ func (p *WorkerPool) decrementWorkerCount() { // Decrement worker count atomic.AddInt32(&p.workerCount, -1) - // Decrement waiting group semaphore + // Decrement wait group p.workersWaitGroup.Done() } diff --git a/pond_test.go b/pond_test.go index 9785543..9324167 100644 --- a/pond_test.go +++ b/pond_test.go @@ -43,6 +43,6 @@ func TestPurgeAfterPoolStopped(t *testing.T) { assertEqual(t, 1, pool.RunningWorkers()) // Simulate purger goroutine attempting to stop a worker after tasks channel is closed - pool.stopped = true + atomic.StoreInt32(&pool.stopped, 1) pool.stopIdleWorker() }