diff --git a/pond.go b/pond.go index 10a80ea..58d82e5 100644 --- a/pond.go +++ b/pond.go @@ -1,6 +1,7 @@ package pond import ( + "errors" "fmt" "runtime/debug" "sync" @@ -14,6 +15,11 @@ const ( defaultIdleTimeout = 5 * time.Second ) +var ( + // SubmitOnStoppedPoolError is thrown when attempting to submit a task to a pool that has been stopped + SubmitOnStoppedPoolError = errors.New("worker pool has been stopped and is no longer accepting tasks") +) + // defaultPanicHandler is the default panic handler func defaultPanicHandler(panic interface{}) { fmt.Printf("Worker exits from a panic: %v\nStack trace: %s\n", panic, string(debug.Stack())) @@ -77,6 +83,7 @@ type WorkerPool struct { stopOnce sync.Once waitGroup sync.WaitGroup mutex sync.Mutex + stopped bool } // New creates a worker pool with that can scale up to the given maximum number of workers (maxWorkers). @@ -193,6 +200,11 @@ func (p *WorkerPool) CompletedTasks() uint64 { return p.SuccessfulTasks() + p.FailedTasks() } +// Stopped returns true if the pool has been stopped and is no longer accepting tasks, and false otherwise. +func (p *WorkerPool) Stopped() bool { + return p.stopped +} + // Submit sends a task to this worker pool for execution. If the queue is full, // it will wait until the task is dispatched to a worker goroutine. func (p *WorkerPool) Submit(task func()) { @@ -206,11 +218,19 @@ func (p *WorkerPool) TrySubmit(task func()) bool { return p.submit(task, false) } -func (p *WorkerPool) submit(task func(), canWaitForIdleWorker bool) (submitted bool) { +func (p *WorkerPool) submit(task func(), mustSubmit bool) (submitted bool) { if task == nil { return false } + if p.Stopped() { + // Pool is stopped and caller must submit the task + if mustSubmit { + panic(SubmitOnStoppedPoolError) + } + return false + } + // Increment submitted and waiting task counters as soon as we receive a task atomic.AddUint64(&p.submittedTaskCount, 1) atomic.AddUint64(&p.waitingTaskCount, 1) @@ -244,7 +264,7 @@ func (p *WorkerPool) submit(task func(), canWaitForIdleWorker bool) (submitted b } } - if !canWaitForIdleWorker { + if !mustSubmit { select { case p.tasks <- task: submitted = true @@ -301,6 +321,9 @@ func (p *WorkerPool) SubmitBefore(task func(), deadline time.Duration) { // Stop causes this pool to stop accepting tasks, without waiting for goroutines to exit func (p *WorkerPool) Stop() { p.stopOnce.Do(func() { + // Mark pool as stopped + p.stopped = true + // Send the signal to stop the purger goroutine close(p.purgerQuit) }) @@ -412,7 +435,7 @@ func (p *WorkerPool) Group() *TaskGroup { } // worker launches a worker goroutine -func worker(firstTask func(), tasks chan func(), idleWorkerCount *int32, exitHandler func(), taskExecutor func(func())) { +func worker(firstTask func(), tasks <-chan func(), idleWorkerCount *int32, exitHandler func(), taskExecutor func(func())) { defer func() { // Decrement idle count diff --git a/pond_blackbox_test.go b/pond_blackbox_test.go index de52402..b9f90f2 100644 --- a/pond_blackbox_test.go +++ b/pond_blackbox_test.go @@ -197,6 +197,20 @@ func TestTrySubmit(t *testing.T) { assertEqual(t, int32(1), atomic.LoadInt32(&doneCount)) } +func TestTrySubmitOnStoppedPool(t *testing.T) { + + // Create a pool and stop it immediately + pool := pond.New(1, 0) + assertEqual(t, false, pool.Stopped()) + pool.StopAndWait() + assertEqual(t, true, pool.Stopped()) + + submitted := pool.TrySubmit(func() {}) + + // Task should not be accepted by the pool + assertEqual(t, false, submitted) +} + func TestSubmitToIdle(t *testing.T) { pool := pond.New(1, 5) @@ -224,6 +238,27 @@ func TestSubmitToIdle(t *testing.T) { assertEqual(t, int(0), pool.IdleWorkers()) } +func TestSubmitOnStoppedPool(t *testing.T) { + + // Create a pool and stop it immediately + pool := pond.New(1, 0) + assertEqual(t, false, pool.Stopped()) + pool.StopAndWait() + assertEqual(t, true, pool.Stopped()) + + // Attempt to submit a task on a stopped pool + var err interface{} = nil + func() { + defer func() { + err = recover() + }() + pool.Submit(func() {}) + }() + + // Call to Submit should have failed with SubmitOnStoppedPoolError error + assertEqual(t, pond.SubmitOnStoppedPoolError, err) +} + func TestRunning(t *testing.T) { workerCount := 5