diff --git a/pond.go b/pond.go index 9f29c5b..9667300 100644 --- a/pond.go +++ b/pond.go @@ -88,11 +88,12 @@ type WorkerPool struct { successfulTaskCount uint64 failedTaskCount uint64 // Private properties - tasks chan func() - stopOnce sync.Once - waitGroup sync.WaitGroup - mutex sync.Mutex - stopped bool + tasks chan func() + workersWaitGroup sync.WaitGroup + tasksWaitGroup sync.WaitGroup + purgerDoneChan chan struct{} + mutex sync.Mutex + stopped bool } // New creates a worker pool with that can scale up to the given maximum number of workers (maxWorkers). @@ -138,6 +139,7 @@ func New(maxWorkers, maxCapacity int, options ...Option) *WorkerPool { pool.tasks = make(chan func(), pool.maxCapacity) // Start purger goroutine + pool.purgerDoneChan = make(chan struct{}) go pool.purge() // Start minWorkers workers @@ -242,12 +244,14 @@ func (p *WorkerPool) submit(task func(), mustSubmit bool) (submitted bool) { // Increment submitted and waiting task counters as soon as we receive a task atomic.AddUint64(&p.submittedTaskCount, 1) atomic.AddUint64(&p.waitingTaskCount, 1) + p.tasksWaitGroup.Add(1) defer func() { if !submitted { // Task was not sumitted to the pool, decrement submitted and waiting task counters atomic.AddUint64(&p.submittedTaskCount, ^uint64(0)) atomic.AddUint64(&p.waitingTaskCount, ^uint64(0)) + p.tasksWaitGroup.Done() } }() @@ -328,46 +332,24 @@ func (p *WorkerPool) SubmitBefore(task func(), deadline time.Duration) { // Stop causes this pool to stop accepting new tasks and signals all workers to stop processing new tasks. // Tasks being processed by workers will continue until completion unless the process is terminated. -// This method can only be called once. 
func (p *WorkerPool) Stop() { - p.stopOnce.Do(func() { - // Mark pool as stopped - p.stopped = true - - // Stop accepting new tasks - close(p.tasks) - - // Terminate all workers & purger goroutine - p.contextCancel() - }) + go p.stop() } // StopAndWait causes this pool to stop accepting new tasks and then waits for all tasks in the queue -// to complete before returning. This method can only be called once. +// to complete before returning. func (p *WorkerPool) StopAndWait() { - p.stopOnce.Do(func() { - // Mark pool as stopped - p.stopped = true - - // Stop accepting new tasks - close(p.tasks) - - // Wait for all workers to exit - p.waitGroup.Wait() - - // Terminate all workers & purger goroutine - p.contextCancel() - }) + p.stop() } // StopAndWaitFor stops this pool and waits for all tasks in the queue to complete before returning -// or until the given deadline is reached, whichever comes first. This method can only be called once. +// or until the given deadline is reached, whichever comes first. 
func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) { - // Detect if worker pool is already stopped + // Launch goroutine to detect when worker pool has stopped gracefully workersDone := make(chan struct{}) go func() { - p.StopAndWait() + p.stop() workersDone <- struct{}{} }() @@ -381,19 +363,38 @@ func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) { } } +func (p *WorkerPool) stop() { + // Mark pool as stopped + p.stopped = true + + // Wait for all queued tasks to complete + p.tasksWaitGroup.Wait() + + // Terminate all workers & purger goroutine + p.contextCancel() + + // Wait for all workers to exit + p.workersWaitGroup.Wait() + + // Wait for purger goroutine to exit + <-p.purgerDoneChan + + // Close tasks channel + close(p.tasks) +} + // purge represents the work done by the purger goroutine func (p *WorkerPool) purge() { + defer func() { p.purgerDoneChan <- struct{}{} }() idleTicker := time.NewTicker(p.idleTimeout) defer idleTicker.Stop() for { select { - // Timed out waiting for any activity to happen, attempt to kill an idle worker + // Timed out waiting for any activity to happen, attempt to stop an idle worker case <-idleTicker.C: - if p.IdleWorkers() > 0 && p.RunningWorkers() > p.minWorkers { - p.tasks <- nil - } + p.stopIdleWorker() // Pool context was cancelled, exit case <-p.context.Done(): return @@ -401,6 +402,13 @@ func (p *WorkerPool) purge() { } } +// stopIdleWorker attempts to stop an idle worker by sending it a nil task +func (p *WorkerPool) stopIdleWorker() { + if p.IdleWorkers() > 0 && p.RunningWorkers() > p.minWorkers && !p.Stopped() { + p.tasks <- nil + } +} + // startWorkers creates new worker goroutines to run the given tasks func (p *WorkerPool) maybeStartWorker(firstTask func()) bool { @@ -426,6 +434,7 @@ func (p *WorkerPool) executeTask(task func()) { // Invoke panic handler p.panicHandler(panic) } + p.tasksWaitGroup.Done() }() // Decrement waiting task count @@ -451,7 +460,7 @@ func (p *WorkerPool) incrementWorkerCount() 
bool { p.mutex.Unlock() // Increment waiting group semaphore - p.waitGroup.Add(1) + p.workersWaitGroup.Add(1) return true } @@ -464,7 +473,7 @@ func (p *WorkerPool) decrementWorkerCount() { p.mutex.Unlock() // Decrement waiting group semaphore - p.waitGroup.Done() + p.workersWaitGroup.Done() } // Group creates a new task group diff --git a/pond_test.go b/pond_test.go index e56aca2..9785543 100644 --- a/pond_test.go +++ b/pond_test.go @@ -1,6 +1,7 @@ package pond import ( + "sync/atomic" "testing" "time" ) @@ -29,3 +30,19 @@ func TestNewWithInconsistentOptions(t *testing.T) { assertEqual(t, 1, pool.minWorkers) assertEqual(t, defaultIdleTimeout, pool.idleTimeout) } + +func TestPurgeAfterPoolStopped(t *testing.T) { + + pool := New(1, 1) + + var doneCount int32 + pool.SubmitAndWait(func() { + atomic.AddInt32(&doneCount, 1) + }) + assertEqual(t, int32(1), atomic.LoadInt32(&doneCount)) + assertEqual(t, 1, pool.RunningWorkers()) + + // Simulate purger goroutine attempting to stop a worker after tasks channel is closed + pool.stopped = true + pool.stopIdleWorker() +}