diff --git a/pool.go b/pool.go index fdbf283..2da88a7 100644 --- a/pool.go +++ b/pool.go @@ -64,8 +64,9 @@ type Pool struct { destructor Destructor maxSize int - acquireCount int64 - slowAcquireCount int64 + acquireCount int64 + slowAcquireCount int64 + canceledAcquireCount int64 closed bool } @@ -106,6 +107,7 @@ type Stat struct { maxResources int acquireCount int64 slowAcquireCount int64 + canceledAcquireCount int64 } // TotalResource returns the total number of resources in the pool. @@ -145,13 +147,20 @@ func (s *Stat) SlowAcquireCount() int64 { return s.slowAcquireCount } +// CanceledAcquireCount returns the number of acquires from the pool +// that were canceled by a context. +func (s *Stat) CanceledAcquireCount() int64 { + return s.canceledAcquireCount +} + // Stat returns the current pool statistics. func (p *Pool) Stat() *Stat { p.cond.L.Lock() s := &Stat{ - maxResources: p.maxSize, - acquireCount: p.acquireCount, - slowAcquireCount: p.slowAcquireCount, + maxResources: p.maxSize, + acquireCount: p.acquireCount, + slowAcquireCount: p.slowAcquireCount, + canceledAcquireCount: p.canceledAcquireCount, } for _, res := range p.allResources { @@ -174,16 +183,18 @@ func (p *Pool) Stat() *Stat { // maximum capacity it will block until a resource is available. ctx can be used // to cancel the Acquire. func (p *Pool) Acquire(ctx context.Context) (*Resource, error) { + p.cond.L.Lock() if doneChan := ctx.Done(); doneChan != nil { select { case <-ctx.Done(): + p.canceledAcquireCount += 1 + p.cond.L.Unlock() return nil, ctx.Err() default: } } slowAcquire := false - p.cond.L.Lock() for { if p.closed { @@ -218,6 +229,15 @@ func (p *Pool) Acquire(ctx context.Context) (*Resource, error) { if err != nil { p.allResources = removeResource(p.allResources, res) p.destructWG.Done() + + select { + case <-ctx.Done(): + if err == ctx.Err() { + p.canceledAcquireCount += 1 + } + default: + } + p.cond.L.Unlock() return nil, err } @@ -242,6 +262,8 @@ func (p *Pool) Acquire(ctx context.Context) (*Resource, error) { select { case <-ctx.Done(): + p.canceledAcquireCount += 1 + // Allow goroutine waiting for signal to exit. Re-signal since we couldn't // do anything with it. Another goroutine might be waiting. go func() { diff --git a/pool_test.go b/pool_test.go index d21cf0f..d9a7425 100644 --- a/pool_test.go +++ b/pool_test.go @@ -310,7 +310,7 @@ func TestPoolStatResources(t *testing.T) { close(endWaitChan) } -func TestPoolStatCounters(t *testing.T) { +func TestPoolStatSuccessfulAcquireCounters(t *testing.T) { createFunc, _ := createCreateResourceFunc() pool := puddle.NewPool(createFunc, stubCloseRes, 1) defer pool.Close() @@ -350,6 +350,60 @@ func TestPoolStatCounters(t *testing.T) { assert.Equal(t, int64(2), stat.SlowAcquireCount()) } +func TestPoolStatCanceledAcquireBeforeStart(t *testing.T) { + createFunc, _ := createCreateResourceFunc() + pool := puddle.NewPool(createFunc, stubCloseRes, 1) + defer pool.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := pool.Acquire(ctx) + require.Equal(t, context.Canceled, err) + + stat := pool.Stat() + assert.Equal(t, int64(0), stat.AcquireCount()) + assert.Equal(t, int64(1), stat.CanceledAcquireCount()) +} + +func TestPoolStatCanceledAcquireDuringCreate(t *testing.T) { + createFunc := func(ctx context.Context) (interface{}, error) { + <-ctx.Done() + return nil, ctx.Err() + } + + pool := puddle.NewPool(createFunc, stubCloseRes, 1) + defer pool.Close() + + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(50*time.Millisecond, cancel) + _, err := pool.Acquire(ctx) + require.Equal(t, context.Canceled, err) + + stat := pool.Stat() + assert.Equal(t, int64(0), stat.AcquireCount()) + assert.Equal(t, int64(1), stat.CanceledAcquireCount()) +} + +func TestPoolStatCanceledAcquireDuringWait(t *testing.T) { + createFunc, _ := createCreateResourceFunc() + pool := puddle.NewPool(createFunc, stubCloseRes, 1) + defer pool.Close() + + res, err := pool.Acquire(context.Background()) + require.Nil(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(50*time.Millisecond, cancel) + _, err = pool.Acquire(ctx) + require.Equal(t, context.Canceled, err) + + res.Release() + + stat := pool.Stat() + assert.Equal(t, int64(1), stat.AcquireCount()) + assert.Equal(t, int64(1), stat.CanceledAcquireCount()) +} + func TestResourceDestroyRemovesResourceFromPool(t *testing.T) { createFunc, _ := createCreateResourceFunc() var closeCalls Counter