diff --git a/pool.go b/pool.go index c2f1d3c..9021b54 100644 --- a/pool.go +++ b/pool.go @@ -276,10 +276,10 @@ func (p *Pool) TryAcquire(ctx context.Context) (*Resource, error) { // will return ErrNotAvailable if no resource is available. func (p *Pool) doAcquire(ctx context.Context, block bool) (*Resource, error) { startNano := nanotime() - p.cond.L.Lock() if doneChan := ctx.Done(); doneChan != nil { select { case <-ctx.Done(): + p.cond.L.Lock() p.canceledAcquireCount += 1 p.cond.L.Unlock() return nil, ctx.Err() @@ -287,6 +287,8 @@ func (p *Pool) doAcquire(ctx context.Context, block bool) (*Resource, error) { } } + p.cond.L.Lock() + emptyAcquire := false for { diff --git a/pool_test.go b/pool_test.go index 5b22086..b51afa5 100644 --- a/pool_test.go +++ b/pool_test.go @@ -192,6 +192,20 @@ func TestPoolTryAcquireDoesNotBlock(t *testing.T) { assert.Equal(t, 1, createCounter.Value()) } +func TestPoolAcquireNilContextDoesNotLeavePoolLocked(t *testing.T) { + constructor, createCounter := createConstructor() + pool := puddle.NewPool(constructor, stubDestructor, 10) + + assert.Panics(t, func() { pool.Acquire(nil) }) + + res, err := pool.Acquire(context.Background()) + require.NoError(t, err) + assert.Equal(t, 1, res.Value()) + res.Release() + + assert.Equal(t, 1, createCounter.Value()) +} + func TestPoolAcquireContextAlreadyCanceled(t *testing.T) { constructor := func(ctx context.Context) (interface{}, error) { panic("should never be called")