diff --git a/benchmark/go.mod b/benchmark/go.mod
index ac02c36..f1e7f41 100644
--- a/benchmark/go.mod
+++ b/benchmark/go.mod
@@ -3,7 +3,7 @@ module github.com/alitto/pond/benchmark
 go 1.17
 
 require (
-	github.com/alitto/pond v1.6.1
+	github.com/alitto/pond v1.7.0
 	github.com/gammazero/workerpool v1.1.2
 	github.com/panjf2000/ants/v2 v2.4.7
 )
diff --git a/examples/dynamic_size/go.mod b/examples/dynamic_size/go.mod
index 8a91fb9..9563074 100644
--- a/examples/dynamic_size/go.mod
+++ b/examples/dynamic_size/go.mod
@@ -3,7 +3,7 @@ module github.com/alitto/pond/examples/dynamic_size
 go 1.17
 
 require (
-	github.com/alitto/pond v1.6.1
+	github.com/alitto/pond v1.7.0
 )
 
 replace github.com/alitto/pond => ../../
diff --git a/examples/fixed_size/go.mod b/examples/fixed_size/go.mod
index 6ac8aa1..3461f0e 100644
--- a/examples/fixed_size/go.mod
+++ b/examples/fixed_size/go.mod
@@ -3,7 +3,7 @@ module github.com/alitto/pond/examples/fixed_size
 go 1.17
 
 require (
-	github.com/alitto/pond v1.6.1
+	github.com/alitto/pond v1.7.0
 )
 
 replace github.com/alitto/pond => ../../
diff --git a/examples/pool_context/go.mod b/examples/pool_context/go.mod
index 449b708..9fe755d 100644
--- a/examples/pool_context/go.mod
+++ b/examples/pool_context/go.mod
@@ -2,6 +2,6 @@ module github.com/alitto/pond/examples/pool_context
 
 go 1.17
 
-require github.com/alitto/pond v1.6.1
+require github.com/alitto/pond v1.7.0
 
 replace github.com/alitto/pond => ../../
diff --git a/examples/prometheus/go.mod b/examples/prometheus/go.mod
index 51a7975..40d0388 100644
--- a/examples/prometheus/go.mod
+++ b/examples/prometheus/go.mod
@@ -3,7 +3,7 @@ module github.com/alitto/pond/examples/fixed_size
 go 1.17
 
 require (
-	github.com/alitto/pond v1.6.1
+	github.com/alitto/pond v1.7.0
 	github.com/prometheus/client_golang v1.9.0
 )
 
diff --git a/examples/task_group/go.mod b/examples/task_group/go.mod
index 83c53ca..b4fd505 100644
--- a/examples/task_group/go.mod
+++ b/examples/task_group/go.mod
@@ -3,7 +3,7 @@ module github.com/alitto/pond/examples/task_group
 go 1.17
 
 require (
-	github.com/alitto/pond v1.6.1
+	github.com/alitto/pond v1.7.0
 )
 
 replace github.com/alitto/pond => ../../
diff --git a/pond.go b/pond.go
index 9667300..865f3e1 100644
--- a/pond.go
+++ b/pond.go
@@ -89,11 +89,11 @@ type WorkerPool struct {
 	failedTaskCount     uint64
 	// Private properties
 	tasks            chan func()
+	tasksCloseOnce   sync.Once
 	workersWaitGroup sync.WaitGroup
 	tasksWaitGroup   sync.WaitGroup
-	purgerDoneChan   chan struct{}
 	mutex            sync.Mutex
-	stopped          bool
+	stopped          int32
 }
 
 // New creates a worker pool with that can scale up to the given maximum number of workers (maxWorkers).
@@ -139,7 +139,7 @@ func New(maxWorkers, maxCapacity int, options ...Option) *WorkerPool {
 	pool.tasks = make(chan func(), pool.maxCapacity)
 
 	// Start purger goroutine
-	pool.purgerDoneChan = make(chan struct{})
+	pool.workersWaitGroup.Add(1)
 	go pool.purge()
 
 	// Start minWorkers workers
@@ -212,7 +212,7 @@ func (p *WorkerPool) CompletedTasks() uint64 {
 
 // Stopped returns true if the pool has been stopped and is no longer accepting tasks, and false otherwise.
 func (p *WorkerPool) Stopped() bool {
-	return p.stopped
+	return atomic.LoadInt32(&p.stopped) == 1
 }
 
 // Submit sends a task to this worker pool for execution. If the queue is full,
@@ -230,7 +230,7 @@ func (p *WorkerPool) TrySubmit(task func()) bool {
 
 func (p *WorkerPool) submit(task func(), mustSubmit bool) (submitted bool) {
 	if task == nil {
-		return false
+		return
 	}
 
 	if p.Stopped() {
@@ -238,7 +238,7 @@ func (p *WorkerPool) submit(task func(), mustSubmit bool) (submitted bool) {
 		if mustSubmit {
 			panic(ErrSubmitOnStoppedPool)
 		}
-		return false
+		return
 	}
 
 	// Increment submitted and waiting task counters as soon as we receive a task
@@ -255,35 +255,19 @@ func (p *WorkerPool) submit(task func(), mustSubmit bool) (submitted bool) {
 		}
 	}()
 
-	runningWorkerCount := p.RunningWorkers()
-
-	// Attempt to dispatch to an idle worker without blocking
-	if runningWorkerCount > 0 && p.IdleWorkers() > 0 {
-		select {
-		case p.tasks <- task:
-			submitted = true
-			return
-		default:
-			// No idle worker available, continue
-		}
-	}
-
 	// Start a worker as long as we haven't reached the limit
-	if runningWorkerCount < p.maxWorkers {
-		if ok := p.maybeStartWorker(task); ok {
-			submitted = true
-			return
-		}
+	if submitted = p.maybeStartWorker(task); submitted {
+		return
 	}
 
 	if !mustSubmit {
+		// Attempt to dispatch to an idle worker without blocking
 		select {
 		case p.tasks <- task:
 			submitted = true
 			return
 		default:
 			// Channel is full and can't wait for an idle worker, so need to exit
-			submitted = false
 			return
 		}
 	}
@@ -330,26 +314,27 @@ func (p *WorkerPool) SubmitBefore(task func(), deadline time.Duration) {
 	})
 }
 
-// Stop causes this pool to stop accepting new tasks and signals all workers to stop processing new tasks.
-// Tasks being processed by workers will continue until completion unless the process is terminated.
+// Stop causes this pool to stop accepting new tasks and signals all workers to exit.
+// Tasks being executed by workers will continue until completion (unless the process is terminated).
+// Tasks in the queue will not be executed.
 func (p *WorkerPool) Stop() {
-	go p.stop()
+	go p.stop(false)
 }
 
 // StopAndWait causes this pool to stop accepting new tasks and then waits for all tasks in the queue
 // to complete before returning.
 func (p *WorkerPool) StopAndWait() {
-	p.stop()
+	p.stop(true)
 }
 
-// StopAndWaitFor stops this pool and waits for all tasks in the queue to complete before returning
-// or until the given deadline is reached, whichever comes first.
+// StopAndWaitFor stops this pool and waits until either all tasks in the queue are completed
+// or the given deadline is reached, whichever comes first.
 func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) {
 
 	// Launch goroutine to detect when worker pool has stopped gracefully
 	workersDone := make(chan struct{})
 	go func() {
-		p.stop()
+		p.stop(true)
 		workersDone <- struct{}{}
 	}()
 
@@ -363,29 +348,30 @@ func (p *WorkerPool) StopAndWaitFor(deadline time.Duration) {
 	}
 }
 
-func (p *WorkerPool) stop() {
+func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) {
 	// Mark pool as stopped
-	p.stopped = true
+	atomic.StoreInt32(&p.stopped, 1)
 
-	// Wait for all queued tasks to complete
-	p.tasksWaitGroup.Wait()
+	if waitForQueuedTasksToComplete {
+		// Wait for all queued tasks to complete
+		p.tasksWaitGroup.Wait()
+	}
 
 	// Terminate all workers & purger goroutine
 	p.contextCancel()
 
-	// Wait for all workers to exit
+	// Wait for all workers & purger goroutine to exit
 	p.workersWaitGroup.Wait()
 
-	// Wait purger goroutine to exit
-	<-p.purgerDoneChan
-
-	// close tasks channel
-	close(p.tasks)
+	// close tasks channel (only once, in case multiple concurrent calls to StopAndWait are made)
+	p.tasksCloseOnce.Do(func() {
+		close(p.tasks)
+	})
 }
 
 // purge represents the work done by the purger goroutine
 func (p *WorkerPool) purge() {
-	defer func() { p.purgerDoneChan <- struct{}{} }()
+	defer p.workersWaitGroup.Done()
 
 	idleTicker := time.NewTicker(p.idleTimeout)
 	defer idleTicker.Stop()
@@ -409,15 +395,20 @@ func (p *WorkerPool) stopIdleWorker() {
 	}
 }
 
-// startWorkers creates new worker goroutines to run the given tasks
+// maybeStartWorker attempts to create a new worker goroutine to run the given task.
+// If the worker pool has reached the maximum number of workers or there are idle workers,
+// it will not create a new worker.
 func (p *WorkerPool) maybeStartWorker(firstTask func()) bool {
 
-	// Attempt to increment worker count
 	if ok := p.incrementWorkerCount(); !ok {
 		return false
 	}
 
-	// Launch worker
+	if firstTask == nil {
+		atomic.AddInt32(&p.idleWorkerCount, 1)
+	}
+
+	// Launch worker goroutine
 	go worker(p.context, firstTask, p.tasks, &p.idleWorkerCount, p.decrementWorkerCount, p.executeTask)
 
 	return true
@@ -446,33 +437,54 @@ func (p *WorkerPool) executeTask(task func()) {
 	atomic.AddUint64(&p.successfulTaskCount, 1)
 }
 
-func (p *WorkerPool) incrementWorkerCount() bool {
+func (p *WorkerPool) incrementWorkerCount() (incremented bool) {
 
-	// Attempt to increment worker count
-	p.mutex.Lock()
 	runningWorkerCount := p.RunningWorkers()
-	// Execute the resizing strategy to determine if we can create more workers
-	if !p.strategy.Resize(runningWorkerCount, p.minWorkers, p.maxWorkers) || runningWorkerCount >= p.maxWorkers {
-		p.mutex.Unlock()
-		return false
+
+	// Reached max workers, do not create a new one
+	if runningWorkerCount >= p.maxWorkers {
+		return
 	}
-	atomic.AddInt32(&p.workerCount, 1)
-	p.mutex.Unlock()
 
-	// Increment waiting group semaphore
-	p.workersWaitGroup.Add(1)
+	// Idle workers available, do not create a new one
+	if runningWorkerCount >= p.minWorkers && runningWorkerCount > 0 && p.IdleWorkers() > 0 {
+		return
+	}
 
-	return true
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
+
+	// Re-check the limit now that we hold the lock. The count read above was
+	// taken without synchronization, so concurrent callers could all pass the
+	// first check and push the pool past maxWorkers.
+	runningWorkerCount = p.RunningWorkers()
+	if runningWorkerCount >= p.maxWorkers {
+		return
+	}
+
+	// Execute the resizing strategy to determine if we can create more workers
+	incremented = p.strategy.Resize(runningWorkerCount, p.minWorkers, p.maxWorkers)
+
+	if incremented {
+		// Increment worker count
+		atomic.AddInt32(&p.workerCount, 1)
+
+		// Increment wait group
+		p.workersWaitGroup.Add(1)
+	}
+
+	return
 }
 
 func (p *WorkerPool) decrementWorkerCount() {
 
-	// Decrement worker count
 	p.mutex.Lock()
-	atomic.AddInt32(&p.workerCount, -1)
-	p.mutex.Unlock()
+	defer p.mutex.Unlock()
 
-	// Decrement waiting group semaphore
+	// Decrement worker count
+	atomic.AddInt32(&p.workerCount, -1)
+
+	// Decrement wait group
 	p.workersWaitGroup.Done()
 }
 
@@ -497,9 +509,10 @@ func worker(context context.Context, firstTask func(), tasks <-chan func(), idle
 	// We have received a task, execute it
 	if firstTask != nil {
 		taskExecutor(firstTask)
+
+		// Increment idle count
+		atomic.AddInt32(idleWorkerCount, 1)
 	}
-	// Increment idle count
-	atomic.AddInt32(idleWorkerCount, 1)
 
 	for {
 		select {
diff --git a/pond_blackbox_test.go b/pond_blackbox_test.go
index 36967b1..f81280e 100644
--- a/pond_blackbox_test.go
+++ b/pond_blackbox_test.go
@@ -3,6 +3,7 @@ package pond_test
 import (
 	"context"
 	"fmt"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -568,3 +569,55 @@ func TestSubmitWithContext(t *testing.T) {
 	assertEqual(t, int32(1), atomic.LoadInt32(&taskCount))
 	assertEqual(t, int32(0), atomic.LoadInt32(&doneCount))
 }
+
+func TestConcurrentStopAndWait(t *testing.T) {
+
+	pool := pond.New(1, 5)
+
+	// Submit tasks
+	var doneCount int32
+	for i := 0; i < 10; i++ {
+		pool.Submit(func() {
+			time.Sleep(1 * time.Millisecond)
+			atomic.AddInt32(&doneCount, 1)
+		})
+	}
+
+	wg := sync.WaitGroup{}
+
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+
+			pool.StopAndWait()
+			assertEqual(t, int32(10), atomic.LoadInt32(&doneCount))
+		}()
+	}
+
+	wg.Wait()
+}
+
+func TestSubmitToIdleWorker(t *testing.T) {
+
+	pool := pond.New(6, 0, pond.MinWorkers(3))
+
+	assertEqual(t, 3, pool.RunningWorkers())
+
+	// Submit task
+	var doneCount int32
+	for i := 0; i < 3; i++ {
+		pool.Submit(func() {
+			time.Sleep(1 * time.Millisecond)
+			atomic.AddInt32(&doneCount, 1)
+		})
+	}
+
+	// Verify no new workers were started
+	assertEqual(t, 3, pool.RunningWorkers())
+
+	// Wait until all submitted tasks complete
+	pool.StopAndWait()
+
+	assertEqual(t, int32(3), atomic.LoadInt32(&doneCount))
+}
diff --git a/pond_test.go b/pond_test.go
index 9785543..9324167 100644
--- a/pond_test.go
+++ b/pond_test.go
@@ -43,6 +43,6 @@ func TestPurgeAfterPoolStopped(t *testing.T) {
 	assertEqual(t, 1, pool.RunningWorkers())
 
 	// Simulate purger goroutine attempting to stop a worker after tasks channel is closed
-	pool.stopped = true
+	atomic.StoreInt32(&pool.stopped, 1)
 	pool.stopIdleWorker()
 }
diff --git a/resizer.go b/resizer.go
index 4bee95d..fb72e4c 100644
--- a/resizer.go
+++ b/resizer.go
@@ -2,7 +2,6 @@ package pond
 
 import (
 	"runtime"
-	"sync/atomic"
 )
 
 var maxProcs = runtime.GOMAXPROCS(0)
@@ -25,8 +24,8 @@ var (
 
 // ratedResizer implements a rated resizing strategy
 type ratedResizer struct {
-	rate int
-	hits int32
+	rate uint64
+	hits uint64
 }
 
 // RatedResizer creates a resizing strategy which can be configured
@@ -40,7 +39,7 @@ func RatedResizer(rate int) ResizingStrategy {
 	}
 
 	return &ratedResizer{
-		rate: rate,
+		rate: uint64(rate),
 	}
 }
 
@@ -50,7 +49,7 @@ func (r *ratedResizer) Resize(runningWorkers, minWorkers, maxWorkers int) bool {
 		return true
 	}
 
-	hits := int(atomic.AddInt32(&r.hits, 1))
+	r.hits++
 
-	return hits%r.rate == 1
+	return r.hits%r.rate == 1
 }