2
0

Add close resource

This commit is contained in:
Jack Christensen
2018-12-22 18:55:53 -06:00
parent f39c666932
commit f19cb3c6c4
2 changed files with 129 additions and 18 deletions
+37 -2
View File
@@ -2,6 +2,7 @@ package pool
import (
"context"
"errors"
"sync"
)
@@ -14,7 +15,11 @@ const (
const maxUint = ^uint(0)
const maxInt = int(maxUint >> 1)
// ErrClosedPool occurs on an attempt to get a connection from a closed pool.
var ErrClosedPool = errors.New("cannot get from closed pool")
type CreateFunc func() (res interface{}, err error)
type CloseFunc func(res interface{}) (err error)
type resourceWrapper struct {
resource interface{}
@@ -28,19 +33,36 @@ type Pool struct {
allResources map[interface{}]*resourceWrapper
availableResources []*resourceWrapper
maxSize int
closed bool
create CreateFunc
create CreateFunc
closeRes CloseFunc
}
func New(create CreateFunc) *Pool {
func New(create CreateFunc, closeRes CloseFunc) *Pool {
return &Pool{
cond: sync.NewCond(new(sync.Mutex)),
allResources: make(map[interface{}]*resourceWrapper),
maxSize: maxInt,
create: create,
closeRes: closeRes,
}
}
// Close closes all resources in the pool and rejects future Get calls.
// Unavailable resources will be closes when they are returned to the pool.
func (p *Pool) Close() {
p.cond.L.Lock()
p.closed = true
for _, rw := range p.availableResources {
p.closeRes(rw.resource)
// TODO - something with error
delete(p.allResources, rw.resource)
}
p.cond.L.Unlock()
}
// Size returns the current size of the pool.
func (p *Pool) Size() int {
p.cond.L.Lock()
@@ -82,6 +104,11 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) {
p.cond.L.Lock()
if p.closed {
p.cond.L.Unlock()
return nil, ErrClosedPool
}
// If a resource is available now
if len(p.availableResources) > 0 {
res := p.lockedAvailableGet()
@@ -181,6 +208,14 @@ func (p *Pool) Return(res interface{}) {
panic("Return called on resource that does not belong to pool")
}
if p.closed {
p.closeRes(rw.resource)
// TODO - something with error
delete(p.allResources, rw.resource)
p.cond.L.Unlock()
return
}
rw.status = resourceStatusAvailable
p.availableResources = append(p.availableResources, rw)
+92 -16
View File
@@ -34,12 +34,14 @@ func (c *Counter) Value() int {
return n
}
func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
func stubCloseRes(interface{}) error { return nil }
func TestPoolGetCreatesResourceWhenNoneAvailable(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(context.Background())
require.NoError(t, err)
@@ -48,12 +50,12 @@ func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
pool.Return(res)
}
func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
func TestPoolGetDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
pool.SetMaxSize(1)
wg := &sync.WaitGroup{}
@@ -77,24 +79,24 @@ func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
assert.Equal(t, 1, pool.Size())
}
func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) {
func TestPoolGetReturnsErrorFromFailedResourceCreate(t *testing.T) {
errCreateFailed := errors.New("create failed")
createFunc := func() (interface{}, error) {
return nil, errCreateFailed
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(context.Background())
assert.Equal(t, errCreateFailed, err)
assert.Nil(t, res)
}
func TestPoolGet_ReusesResources(t *testing.T) {
func TestPoolGetReusesResources(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(context.Background())
require.NoError(t, err)
@@ -111,11 +113,11 @@ func TestPoolGet_ReusesResources(t *testing.T) {
assert.Equal(t, 1, createCalls.Value())
}
func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
func TestPoolGetContextAlreadyCanceled(t *testing.T) {
createFunc := func() (interface{}, error) {
panic("should never be called")
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -124,7 +126,7 @@ func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
assert.Nil(t, res)
}
func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
func TestPoolGetContextCanceledDuringCreate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var createCalls Counter
@@ -133,29 +135,103 @@ func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
time.Sleep(1 * time.Second)
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(ctx)
assert.Equal(t, context.Canceled, err)
assert.Nil(t, res)
}
func TestPoolReturn_PanicsIfResourceNotPartOfPool(t *testing.T) {
func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
assert.Panics(t, func() { pool.Return(42) })
}
func TestPoolCloseClosesAllAvailableResources(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
var closeCalls Counter
closeFunc := func(interface{}) error {
closeCalls.Next()
return nil
}
p := pool.New(createFunc, closeFunc)
resources := make([]interface{}, 4)
for i := range resources {
var err error
resources[i], err = p.Get(context.Background())
require.Nil(t, err)
}
for _, res := range resources {
p.Return(res)
}
p.Close()
assert.Equal(t, len(resources), closeCalls.Value())
}
func TestPoolReturnClosesResourcePoolIsAlreadyClosed(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
var closeCalls Counter
closeFunc := func(interface{}) error {
closeCalls.Next()
return nil
}
p := pool.New(createFunc, closeFunc)
resources := make([]interface{}, 4)
for i := range resources {
var err error
resources[i], err = p.Get(context.Background())
require.Nil(t, err)
}
p.Close()
assert.Equal(t, 0, closeCalls.Value())
for _, res := range resources {
p.Return(res)
}
assert.Equal(t, len(resources), closeCalls.Value())
}
func TestPoolGetReturnsErrorWhenPoolIsClosed(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
p := pool.New(createFunc, stubCloseRes)
p.Close()
res, err := p.Get(context.Background())
assert.Equal(t, pool.ErrClosedPool, err)
assert.Nil(t, res)
}
func BenchmarkPoolGetAndReturnNoContention(b *testing.B) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
for i := 0; i < b.N; i++ {
res, err := pool.Get(context.Background())
@@ -174,7 +250,7 @@ func BenchmarkPoolGetAndReturnHeavyContention(b *testing.B) {
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool := pool.New(createFunc, stubCloseRes)
pool.SetMaxSize(poolSize)
doneChan := make(chan struct{})