diff --git a/pool.go b/pool.go index 374a2b4..9333b6c 100644 --- a/pool.go +++ b/pool.go @@ -2,6 +2,7 @@ package pool import ( "context" + "errors" "sync" ) @@ -14,7 +15,11 @@ const ( const maxUint = ^uint(0) const maxInt = int(maxUint >> 1) +// ErrClosedPool occurs on an attempt to get a connection from a closed pool. +var ErrClosedPool = errors.New("cannot get from closed pool") + type CreateFunc func() (res interface{}, err error) +type CloseFunc func(res interface{}) (err error) type resourceWrapper struct { resource interface{} @@ -28,19 +33,36 @@ type Pool struct { allResources map[interface{}]*resourceWrapper availableResources []*resourceWrapper maxSize int + closed bool - create CreateFunc + create CreateFunc + closeRes CloseFunc } -func New(create CreateFunc) *Pool { +func New(create CreateFunc, closeRes CloseFunc) *Pool { return &Pool{ cond: sync.NewCond(new(sync.Mutex)), allResources: make(map[interface{}]*resourceWrapper), maxSize: maxInt, create: create, + closeRes: closeRes, } } +// Close closes all resources in the pool and rejects future Get calls. +// Unavailable resources will be closes when they are returned to the pool. +func (p *Pool) Close() { + p.cond.L.Lock() + p.closed = true + + for _, rw := range p.availableResources { + p.closeRes(rw.resource) + // TODO - something with error + delete(p.allResources, rw.resource) + } + p.cond.L.Unlock() +} + // Size returns the current size of the pool. func (p *Pool) Size() int { p.cond.L.Lock() @@ -82,6 +104,11 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) { p.cond.L.Lock() + if p.closed { + p.cond.L.Unlock() + return nil, ErrClosedPool + } + // If a resource is available now if len(p.availableResources) > 0 { res := p.lockedAvailableGet() @@ -181,6 +208,14 @@ func (p *Pool) Return(res interface{}) { panic("Return called on resource that does not belong to pool") } + if p.closed { + p.closeRes(rw.resource) + // TODO - something with error + delete(p.allResources, rw.resource) + p.cond.L.Unlock() + return + } + rw.status = resourceStatusAvailable p.availableResources = append(p.availableResources, rw) diff --git a/pool_test.go b/pool_test.go index 0aa0cfe..c5df2b8 100644 --- a/pool_test.go +++ b/pool_test.go @@ -34,12 +34,14 @@ func (c *Counter) Value() int { return n } -func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) { +func stubCloseRes(interface{}) error { return nil } + +func TestPoolGetCreatesResourceWhenNoneAvailable(t *testing.T) { var createCalls Counter createFunc := func() (interface{}, error) { return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) res, err := pool.Get(context.Background()) require.NoError(t, err) @@ -48,12 +50,12 @@ func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) { pool.Return(res) } -func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { +func TestPoolGetDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { var createCalls Counter createFunc := func() (interface{}, error) { return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) pool.SetMaxSize(1) wg := &sync.WaitGroup{} @@ -77,24 +79,24 @@ func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { assert.Equal(t, 1, pool.Size()) } -func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) { +func TestPoolGetReturnsErrorFromFailedResourceCreate(t *testing.T) { errCreateFailed := errors.New("create failed") createFunc := func() (interface{}, error) { return nil, errCreateFailed } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) res, err := pool.Get(context.Background()) assert.Equal(t, errCreateFailed, err) assert.Nil(t, res) } -func TestPoolGet_ReusesResources(t *testing.T) { +func TestPoolGetReusesResources(t *testing.T) { var createCalls Counter createFunc := func() (interface{}, error) { return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) res, err := pool.Get(context.Background()) require.NoError(t, err) @@ -111,11 +113,11 @@ func TestPoolGet_ReusesResources(t *testing.T) { assert.Equal(t, 1, createCalls.Value()) } -func TestPoolGet_ContextAlreadyCanceled(t *testing.T) { +func TestPoolGetContextAlreadyCanceled(t *testing.T) { createFunc := func() (interface{}, error) { panic("should never be called") } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -124,7 +126,7 @@ func TestPoolGet_ContextAlreadyCanceled(t *testing.T) { assert.Nil(t, res) } -func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) { +func TestPoolGetContextCanceledDuringCreate(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var createCalls Counter @@ -133,29 +135,103 @@ func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) { time.Sleep(1 * time.Second) return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) res, err := pool.Get(ctx) assert.Equal(t, context.Canceled, err) assert.Nil(t, res) } -func TestPoolReturn_PanicsIfResourceNotPartOfPool(t *testing.T) { +func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) { var createCalls Counter createFunc := func() (interface{}, error) { return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) assert.Panics(t, func() { pool.Return(42) }) } +func TestPoolCloseClosesAllAvailableResources(t *testing.T) { + var createCalls Counter + createFunc := func() (interface{}, error) { + return createCalls.Next(), nil + } + + var closeCalls Counter + closeFunc := func(interface{}) error { + closeCalls.Next() + return nil + } + + p := pool.New(createFunc, closeFunc) + + resources := make([]interface{}, 4) + for i := range resources { + var err error + resources[i], err = p.Get(context.Background()) + require.Nil(t, err) + } + + for _, res := range resources { + p.Return(res) + } + + p.Close() + + assert.Equal(t, len(resources), closeCalls.Value()) +} + +func TestPoolReturnClosesResourcePoolIsAlreadyClosed(t *testing.T) { + var createCalls Counter + createFunc := func() (interface{}, error) { + return createCalls.Next(), nil + } + + var closeCalls Counter + closeFunc := func(interface{}) error { + closeCalls.Next() + return nil + } + + p := pool.New(createFunc, closeFunc) + + resources := make([]interface{}, 4) + for i := range resources { + var err error + resources[i], err = p.Get(context.Background()) + require.Nil(t, err) + } + + p.Close() + assert.Equal(t, 0, closeCalls.Value()) + + for _, res := range resources { + p.Return(res) + } + + assert.Equal(t, len(resources), closeCalls.Value()) +} + +func TestPoolGetReturnsErrorWhenPoolIsClosed(t *testing.T) { + var createCalls Counter + createFunc := func() (interface{}, error) { + return createCalls.Next(), nil + } + p := pool.New(createFunc, stubCloseRes) + p.Close() + + res, err := p.Get(context.Background()) + assert.Equal(t, pool.ErrClosedPool, err) + assert.Nil(t, res) +} + func BenchmarkPoolGetAndReturnNoContention(b *testing.B) { var createCalls Counter createFunc := func() (interface{}, error) { return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) for i := 0; i < b.N; i++ { res, err := pool.Get(context.Background()) @@ -174,7 +250,7 @@ func BenchmarkPoolGetAndReturnHeavyContention(b *testing.B) { createFunc := func() (interface{}, error) { return createCalls.Next(), nil } - pool := pool.New(createFunc) + pool := pool.New(createFunc, stubCloseRes) pool.SetMaxSize(poolSize) doneChan := make(chan struct{})