diff --git a/pool.go b/pool.go index bae1e15..4308e76 100644 --- a/pool.go +++ b/pool.go @@ -177,31 +177,30 @@ func (p *Pool) Acquire(ctx context.Context) (*Resource, error) { return res, nil } - // if ctx.Done() == nil { - // p.cond.Wait() - // } else { - - // Convert p.cond.Wait into a channel - waitChan := make(chan struct{}, 1) - go func() { + if ctx.Done() == nil { p.cond.Wait() - waitChan <- struct{}{} - }() - - select { - case <-ctx.Done(): - // Allow goroutine waiting for signal to exit. Re-signal since we couldn't - // do anything with it. Another goroutine might be waiting. + } else { + // Convert p.cond.Wait into a channel + waitChan := make(chan struct{}, 1) go func() { - <-waitChan - p.cond.Signal() - p.cond.L.Unlock() + p.cond.Wait() + waitChan <- struct{}{} }() - return nil, ctx.Err() - case <-waitChan: + select { + case <-ctx.Done(): + // Allow goroutine waiting for signal to exit. Re-signal since we couldn't + // do anything with it. Another goroutine might be waiting. + go func() { + <-waitChan + p.cond.Signal() + p.cond.L.Unlock() + }() + + return nil, ctx.Err() + case <-waitChan: + } } - // } } } diff --git a/pool_test.go b/pool_test.go index 7fe042e..e7b07c4 100644 --- a/pool_test.go +++ b/pool_test.go @@ -117,6 +117,34 @@ func TestPoolAcquireDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) assert.Equal(t, 1, pool.Size()) } +func TestPoolAcquireWithCancellableContext(t *testing.T) { + createFunc, createCounter := createCreateResourceFunc() + pool := puddle.NewPool(createFunc, stubCloseRes) + pool.SetMaxSize(1) + + wg := &sync.WaitGroup{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + for j := 0; j < 100; j++ { + ctx, cancel := context.WithCancel(context.Background()) + res, err := pool.Acquire(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, res.Value()) + res.Release() + cancel() + } + wg.Done() + }() + } + + wg.Wait() + + assert.Equal(t, 1, createCounter.Value()) + assert.Equal(t, 1, pool.Size()) +} + func TestPoolAcquireReturnsErrorFromFailedResourceCreate(t *testing.T) { errCreateFailed := errors.New("create failed") createFunc := func(ctx context.Context) (interface{}, error) { @@ -262,36 +290,67 @@ func BenchmarkPoolAcquireAndRelease(b *testing.B) { benchmarks := []struct { poolSize int clientCount int + cancellable bool }{ - {8, 2}, - {8, 8}, - {8, 32}, - {8, 128}, - {8, 512}, - {8, 2048}, - {8, 8192}, + {8, 2, false}, + {8, 8, false}, + {8, 32, false}, + {8, 128, false}, + {8, 512, false}, + {8, 2048, false}, + {8, 8192, false}, - {64, 2}, - {64, 8}, - {64, 32}, - {64, 128}, - {64, 512}, - {64, 2048}, - {64, 8192}, + {64, 2, false}, + {64, 8, false}, + {64, 32, false}, + {64, 128, false}, + {64, 512, false}, + {64, 2048, false}, + {64, 8192, false}, - {512, 2}, - {512, 8}, - {512, 32}, - {512, 128}, - {512, 512}, - {512, 2048}, - {512, 8192}, + {512, 2, false}, + {512, 8, false}, + {512, 32, false}, + {512, 128, false}, + {512, 512, false}, + {512, 2048, false}, + {512, 8192, false}, + + {8, 2, true}, + {8, 8, true}, + {8, 32, true}, + {8, 128, true}, + {8, 512, true}, + {8, 2048, true}, + {8, 8192, true}, + + {64, 2, true}, + {64, 8, true}, + {64, 32, true}, + {64, 128, true}, + {64, 512, true}, + {64, 2048, true}, + {64, 8192, true}, + + {512, 2, true}, + {512, 8, true}, + {512, 32, true}, + {512, 128, true}, + {512, 512, true}, + {512, 2048, true}, + {512, 8192, true}, } for _, bm := range benchmarks { - name := fmt.Sprintf("PoolSize=%d/ClientCount=%d", bm.poolSize, bm.clientCount) + name := fmt.Sprintf("PoolSize=%d/ClientCount=%d/Cancellable=%v", bm.poolSize, bm.clientCount, bm.cancellable) b.Run(name, func(b *testing.B) { + ctx := context.Background() + cancel := func() {} + if bm.cancellable { + ctx, cancel = context.WithCancel(ctx) + } + wg := &sync.WaitGroup{} createFunc, _ := createCreateResourceFunc() @@ -304,7 +363,7 @@ func BenchmarkPoolAcquireAndRelease(b *testing.B) { defer wg.Done() for j := 0; j < b.N; j++ { - res, err := pool.Acquire(context.Background()) + res, err := pool.Acquire(ctx) if err != nil { b.Fatal(err) } @@ -314,6 +373,7 @@ func BenchmarkPoolAcquireAndRelease(b *testing.B) { } wg.Wait() + cancel() }) } }