Return Context error if it is canceled before at least 1 task failed
This commit is contained in:
@@ -50,26 +50,16 @@ func (g *TaskGroupWithContext) Submit(task func() error) {
|
|||||||
defer g.waitGroup.Done()
|
defer g.waitGroup.Done()
|
||||||
|
|
||||||
// If context has already been cancelled, skip task execution
|
// If context has already been cancelled, skip task execution
|
||||||
if g.ctx != nil {
|
select {
|
||||||
select {
|
case <-g.ctx.Done():
|
||||||
case <-g.ctx.Done():
|
return
|
||||||
return
|
default:
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't actually ignore errors
|
// don't actually ignore errors
|
||||||
err := task()
|
err := task()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.errSync.once.Do(func() {
|
g.setError(err)
|
||||||
g.errSync.guard.Lock()
|
|
||||||
g.err = err
|
|
||||||
g.errSync.guard.Unlock()
|
|
||||||
|
|
||||||
if g.cancel != nil {
|
|
||||||
g.cancel()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -91,11 +81,26 @@ func (g *TaskGroupWithContext) Wait() error {
|
|||||||
// If context was provided, cancel it to signal all running tasks to stop
|
// If context was provided, cancel it to signal all running tasks to stop
|
||||||
g.cancel()
|
g.cancel()
|
||||||
case <-g.ctx.Done():
|
case <-g.ctx.Done():
|
||||||
|
g.setError(g.ctx.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return g.getError()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *TaskGroupWithContext) getError() error {
|
||||||
g.errSync.guard.RLock()
|
g.errSync.guard.RLock()
|
||||||
err := g.err
|
err := g.err
|
||||||
g.errSync.guard.RUnlock()
|
g.errSync.guard.RUnlock()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *TaskGroupWithContext) setError(err error) {
|
||||||
|
g.errSync.once.Do(func() {
|
||||||
|
g.errSync.guard.Lock()
|
||||||
|
g.err = err
|
||||||
|
g.errSync.guard.Unlock()
|
||||||
|
|
||||||
|
// Cancel execution of any pending task in this group
|
||||||
|
g.cancel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -120,3 +120,34 @@ func TestGroupContextWithNilContext(t *testing.T) {
|
|||||||
|
|
||||||
assertEqual(t, "a non-nil context needs to be specified when using GroupContext", thrownPanic)
|
assertEqual(t, "a non-nil context needs to be specified when using GroupContext", thrownPanic)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGroupContextWithCanceledContext(t *testing.T) {
|
||||||
|
|
||||||
|
pool := pond.New(3, 100)
|
||||||
|
assertEqual(t, 0, pool.RunningWorkers())
|
||||||
|
|
||||||
|
// Submit a group of tasks
|
||||||
|
var doneCount, startedCount int32
|
||||||
|
userCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
group, ctx := pool.GroupContext(userCtx)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
group.Submit(func() error {
|
||||||
|
atomic.AddInt32(&startedCount, 1)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
atomic.AddInt32(&doneCount, 1)
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel context right after submitting tasks
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := group.Wait()
|
||||||
|
assertEqual(t, context.Canceled, err)
|
||||||
|
assertEqual(t, int32(0), atomic.LoadInt32(&doneCount))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user