From 15e5de604b1bb72f8c4cd47f67278434e0a90fca Mon Sep 17 00:00:00 2001 From: Alejandro Durante Date: Sat, 12 Mar 2022 09:33:09 -0300 Subject: [PATCH] Allow concurrent calls to StopAndWait --- pond.go | 47 ++++++++++++++++++++++++------------------- pond_blackbox_test.go | 29 ++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/pond.go b/pond.go index 9667300..70ab373 100644 --- a/pond.go +++ b/pond.go @@ -89,11 +89,12 @@ type WorkerPool struct { failedTaskCount uint64 // Private properties tasks chan func() + tasksCloseOnce sync.Once workersWaitGroup sync.WaitGroup tasksWaitGroup sync.WaitGroup - purgerDoneChan chan struct{} mutex sync.Mutex stopped bool + stoppedOnce sync.Once } // New creates a worker pool with that can scale up to the given maximum number of workers (maxWorkers). @@ -139,7 +140,7 @@ func New(maxWorkers, maxCapacity int, options ...Option) *WorkerPool { pool.tasks = make(chan func(), pool.maxCapacity) // Start purger goroutine - pool.purgerDoneChan = make(chan struct{}) + pool.workersWaitGroup.Add(1) go pool.purge() // Start minWorkers workers @@ -330,26 +331,27 @@ func (p *WorkerPool) SubmitBefore(task func(), deadline time.Duration) { }) } -// Stop causes this pool to stop accepting new tasks and signals all workers to stop processing new tasks. -// Tasks being processed by workers will continue until completion unless the process is terminated. +// Stop causes this pool to stop accepting new tasks and signals all workers to exit. +// Tasks being executed by workers will continue until completion (unless the process is terminated). +// Tasks in the queue will not be executed. func (p *WorkerPool) Stop() { - go p.stop() + go p.stop(false) } // StopAndWait causes this pool to stop accepting new tasks and then waits for all tasks in the queue // to complete before returning. func (p *WorkerPool) StopAndWait() { - p.stop() + p.stop(true) } -// StopAndWaitFor stops this pool and waits for all tasks in the queue to complete before returning -// or until the given deadline is reached, whichever comes first. +// StopAndWaitFor stops this pool and waits until either all tasks in the queue are completed +// or the given deadline is reached, whichever comes first. func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) { // Launch goroutine to detect when worker pool has stopped gracefully workersDone := make(chan struct{}) go func() { - p.stop() + p.stop(true) workersDone <- struct{}{} }() @@ -363,29 +365,32 @@ func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) { } } -func (p *WorkerPool) stop() { - // Mark pool as stopped - p.stopped = true +func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) { + // Mark pool as stopped (only once, in case multiple concurrent calls to StopAndWait are made) + p.stoppedOnce.Do(func() { + p.stopped = true + }) - // Wait for all queued tasks to complete - p.tasksWaitGroup.Wait() + if waitForQueuedTasksToComplete { + // Wait for all queued tasks to complete + p.tasksWaitGroup.Wait() + } // Terminate all workers & purger goroutine p.contextCancel() - // Wait for all workers to exit + // Wait for all workers & purger goroutine to exit p.workersWaitGroup.Wait() - // Wait purger goroutine to exit - <-p.purgerDoneChan - - // close tasks channel - close(p.tasks) + // close tasks channel (only once, in case multiple concurrent calls to StopAndWait are made) + p.tasksCloseOnce.Do(func() { + close(p.tasks) + }) } // purge represents the work done by the purger goroutine func (p *WorkerPool) purge() { - defer func() { p.purgerDoneChan <- struct{}{} }() + defer p.workersWaitGroup.Done() idleTicker := time.NewTicker(p.idleTimeout) defer idleTicker.Stop() diff --git a/pond_blackbox_test.go b/pond_blackbox_test.go index 36967b1..093af9c 100644 --- a/pond_blackbox_test.go +++ b/pond_blackbox_test.go @@ -3,6 +3,7 @@ package pond_test import ( "context" "fmt" + "sync" "sync/atomic" "testing" "time" @@ -568,3 +569,31 @@ func TestSubmitWithContext(t *testing.T) { assertEqual(t, int32(1), atomic.LoadInt32(&taskCount)) assertEqual(t, int32(0), atomic.LoadInt32(&doneCount)) } + +func TestConcurrentStopAndWait(t *testing.T) { + + pool := pond.New(1, 5) + + // Submit tasks + var doneCount int32 + for i := 0; i < 10; i++ { + pool.Submit(func() { + time.Sleep(1 * time.Millisecond) + atomic.AddInt32(&doneCount, 1) + }) + } + + wg := sync.WaitGroup{} + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + pool.StopAndWait() + assertEqual(t, int32(10), atomic.LoadInt32(&doneCount)) + }() + } + + wg.Wait() +}