Add close resource
This commit is contained in:
@@ -2,6 +2,7 @@ package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -14,7 +15,11 @@ const (
|
||||
const maxUint = ^uint(0)
|
||||
const maxInt = int(maxUint >> 1)
|
||||
|
||||
// ErrClosedPool occurs on an attempt to get a connection from a closed pool.
|
||||
var ErrClosedPool = errors.New("cannot get from closed pool")
|
||||
|
||||
type CreateFunc func() (res interface{}, err error)
|
||||
type CloseFunc func(res interface{}) (err error)
|
||||
|
||||
type resourceWrapper struct {
|
||||
resource interface{}
|
||||
@@ -28,19 +33,36 @@ type Pool struct {
|
||||
allResources map[interface{}]*resourceWrapper
|
||||
availableResources []*resourceWrapper
|
||||
maxSize int
|
||||
closed bool
|
||||
|
||||
create CreateFunc
|
||||
create CreateFunc
|
||||
closeRes CloseFunc
|
||||
}
|
||||
|
||||
func New(create CreateFunc) *Pool {
|
||||
func New(create CreateFunc, closeRes CloseFunc) *Pool {
|
||||
return &Pool{
|
||||
cond: sync.NewCond(new(sync.Mutex)),
|
||||
allResources: make(map[interface{}]*resourceWrapper),
|
||||
maxSize: maxInt,
|
||||
create: create,
|
||||
closeRes: closeRes,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all resources in the pool and rejects future Get calls.
|
||||
// Unavailable resources will be closes when they are returned to the pool.
|
||||
func (p *Pool) Close() {
|
||||
p.cond.L.Lock()
|
||||
p.closed = true
|
||||
|
||||
for _, rw := range p.availableResources {
|
||||
p.closeRes(rw.resource)
|
||||
// TODO - something with error
|
||||
delete(p.allResources, rw.resource)
|
||||
}
|
||||
p.cond.L.Unlock()
|
||||
}
|
||||
|
||||
// Size returns the current size of the pool.
|
||||
func (p *Pool) Size() int {
|
||||
p.cond.L.Lock()
|
||||
@@ -82,6 +104,11 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) {
|
||||
|
||||
p.cond.L.Lock()
|
||||
|
||||
if p.closed {
|
||||
p.cond.L.Unlock()
|
||||
return nil, ErrClosedPool
|
||||
}
|
||||
|
||||
// If a resource is available now
|
||||
if len(p.availableResources) > 0 {
|
||||
res := p.lockedAvailableGet()
|
||||
@@ -181,6 +208,14 @@ func (p *Pool) Return(res interface{}) {
|
||||
panic("Return called on resource that does not belong to pool")
|
||||
}
|
||||
|
||||
if p.closed {
|
||||
p.closeRes(rw.resource)
|
||||
// TODO - something with error
|
||||
delete(p.allResources, rw.resource)
|
||||
p.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
rw.status = resourceStatusAvailable
|
||||
p.availableResources = append(p.availableResources, rw)
|
||||
|
||||
|
||||
+92
-16
@@ -34,12 +34,14 @@ func (c *Counter) Value() int {
|
||||
return n
|
||||
}
|
||||
|
||||
func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
|
||||
func stubCloseRes(interface{}) error { return nil }
|
||||
|
||||
func TestPoolGetCreatesResourceWhenNoneAvailable(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
res, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -48,12 +50,12 @@ func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
|
||||
pool.Return(res)
|
||||
}
|
||||
|
||||
func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
|
||||
func TestPoolGetDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
pool.SetMaxSize(1)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -77,24 +79,24 @@ func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
|
||||
assert.Equal(t, 1, pool.Size())
|
||||
}
|
||||
|
||||
func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) {
|
||||
func TestPoolGetReturnsErrorFromFailedResourceCreate(t *testing.T) {
|
||||
errCreateFailed := errors.New("create failed")
|
||||
createFunc := func() (interface{}, error) {
|
||||
return nil, errCreateFailed
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
res, err := pool.Get(context.Background())
|
||||
assert.Equal(t, errCreateFailed, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestPoolGet_ReusesResources(t *testing.T) {
|
||||
func TestPoolGetReusesResources(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
res, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -111,11 +113,11 @@ func TestPoolGet_ReusesResources(t *testing.T) {
|
||||
assert.Equal(t, 1, createCalls.Value())
|
||||
}
|
||||
|
||||
func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
|
||||
func TestPoolGetContextAlreadyCanceled(t *testing.T) {
|
||||
createFunc := func() (interface{}, error) {
|
||||
panic("should never be called")
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
@@ -124,7 +126,7 @@ func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
|
||||
func TestPoolGetContextCanceledDuringCreate(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var createCalls Counter
|
||||
@@ -133,29 +135,103 @@ func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
|
||||
time.Sleep(1 * time.Second)
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
res, err := pool.Get(ctx)
|
||||
assert.Equal(t, context.Canceled, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestPoolReturn_PanicsIfResourceNotPartOfPool(t *testing.T) {
|
||||
func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
assert.Panics(t, func() { pool.Return(42) })
|
||||
}
|
||||
|
||||
func TestPoolCloseClosesAllAvailableResources(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
|
||||
var closeCalls Counter
|
||||
closeFunc := func(interface{}) error {
|
||||
closeCalls.Next()
|
||||
return nil
|
||||
}
|
||||
|
||||
p := pool.New(createFunc, closeFunc)
|
||||
|
||||
resources := make([]interface{}, 4)
|
||||
for i := range resources {
|
||||
var err error
|
||||
resources[i], err = p.Get(context.Background())
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
for _, res := range resources {
|
||||
p.Return(res)
|
||||
}
|
||||
|
||||
p.Close()
|
||||
|
||||
assert.Equal(t, len(resources), closeCalls.Value())
|
||||
}
|
||||
|
||||
func TestPoolReturnClosesResourcePoolIsAlreadyClosed(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
|
||||
var closeCalls Counter
|
||||
closeFunc := func(interface{}) error {
|
||||
closeCalls.Next()
|
||||
return nil
|
||||
}
|
||||
|
||||
p := pool.New(createFunc, closeFunc)
|
||||
|
||||
resources := make([]interface{}, 4)
|
||||
for i := range resources {
|
||||
var err error
|
||||
resources[i], err = p.Get(context.Background())
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
p.Close()
|
||||
assert.Equal(t, 0, closeCalls.Value())
|
||||
|
||||
for _, res := range resources {
|
||||
p.Return(res)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(resources), closeCalls.Value())
|
||||
}
|
||||
|
||||
func TestPoolGetReturnsErrorWhenPoolIsClosed(t *testing.T) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
p := pool.New(createFunc, stubCloseRes)
|
||||
p.Close()
|
||||
|
||||
res, err := p.Get(context.Background())
|
||||
assert.Equal(t, pool.ErrClosedPool, err)
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func BenchmarkPoolGetAndReturnNoContention(b *testing.B) {
|
||||
var createCalls Counter
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, err := pool.Get(context.Background())
|
||||
@@ -174,7 +250,7 @@ func BenchmarkPoolGetAndReturnHeavyContention(b *testing.B) {
|
||||
createFunc := func() (interface{}, error) {
|
||||
return createCalls.Next(), nil
|
||||
}
|
||||
pool := pool.New(createFunc)
|
||||
pool := pool.New(createFunc, stubCloseRes)
|
||||
pool.SetMaxSize(poolSize)
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
Reference in New Issue
Block a user