2
0

Add close resource

This commit is contained in:
Jack Christensen
2018-12-22 18:55:53 -06:00
parent f39c666932
commit f19cb3c6c4
2 changed files with 129 additions and 18 deletions
+37 -2
View File
@@ -2,6 +2,7 @@ package pool
import ( import (
"context" "context"
"errors"
"sync" "sync"
) )
@@ -14,7 +15,11 @@ const (
const maxUint = ^uint(0) const maxUint = ^uint(0)
const maxInt = int(maxUint >> 1) const maxInt = int(maxUint >> 1)
// ErrClosedPool occurs on an attempt to get a connection from a closed pool.
var ErrClosedPool = errors.New("cannot get from closed pool")
type CreateFunc func() (res interface{}, err error) type CreateFunc func() (res interface{}, err error)
type CloseFunc func(res interface{}) (err error)
type resourceWrapper struct { type resourceWrapper struct {
resource interface{} resource interface{}
@@ -28,19 +33,36 @@ type Pool struct {
allResources map[interface{}]*resourceWrapper allResources map[interface{}]*resourceWrapper
availableResources []*resourceWrapper availableResources []*resourceWrapper
maxSize int maxSize int
closed bool
create CreateFunc create CreateFunc
closeRes CloseFunc
} }
func New(create CreateFunc) *Pool { func New(create CreateFunc, closeRes CloseFunc) *Pool {
return &Pool{ return &Pool{
cond: sync.NewCond(new(sync.Mutex)), cond: sync.NewCond(new(sync.Mutex)),
allResources: make(map[interface{}]*resourceWrapper), allResources: make(map[interface{}]*resourceWrapper),
maxSize: maxInt, maxSize: maxInt,
create: create, create: create,
closeRes: closeRes,
} }
} }
// Close closes all resources in the pool and rejects future Get calls.
// Unavailable resources will be closes when they are returned to the pool.
func (p *Pool) Close() {
p.cond.L.Lock()
p.closed = true
for _, rw := range p.availableResources {
p.closeRes(rw.resource)
// TODO - something with error
delete(p.allResources, rw.resource)
}
p.cond.L.Unlock()
}
// Size returns the current size of the pool. // Size returns the current size of the pool.
func (p *Pool) Size() int { func (p *Pool) Size() int {
p.cond.L.Lock() p.cond.L.Lock()
@@ -82,6 +104,11 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) {
p.cond.L.Lock() p.cond.L.Lock()
if p.closed {
p.cond.L.Unlock()
return nil, ErrClosedPool
}
// If a resource is available now // If a resource is available now
if len(p.availableResources) > 0 { if len(p.availableResources) > 0 {
res := p.lockedAvailableGet() res := p.lockedAvailableGet()
@@ -181,6 +208,14 @@ func (p *Pool) Return(res interface{}) {
panic("Return called on resource that does not belong to pool") panic("Return called on resource that does not belong to pool")
} }
if p.closed {
p.closeRes(rw.resource)
// TODO - something with error
delete(p.allResources, rw.resource)
p.cond.L.Unlock()
return
}
rw.status = resourceStatusAvailable rw.status = resourceStatusAvailable
p.availableResources = append(p.availableResources, rw) p.availableResources = append(p.availableResources, rw)
+92 -16
View File
@@ -34,12 +34,14 @@ func (c *Counter) Value() int {
return n return n
} }
func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) { func stubCloseRes(interface{}) error { return nil }
func TestPoolGetCreatesResourceWhenNoneAvailable(t *testing.T) {
var createCalls Counter var createCalls Counter
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(context.Background()) res, err := pool.Get(context.Background())
require.NoError(t, err) require.NoError(t, err)
@@ -48,12 +50,12 @@ func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
pool.Return(res) pool.Return(res)
} }
func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { func TestPoolGetDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
var createCalls Counter var createCalls Counter
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
pool.SetMaxSize(1) pool.SetMaxSize(1)
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@@ -77,24 +79,24 @@ func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
assert.Equal(t, 1, pool.Size()) assert.Equal(t, 1, pool.Size())
} }
func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) { func TestPoolGetReturnsErrorFromFailedResourceCreate(t *testing.T) {
errCreateFailed := errors.New("create failed") errCreateFailed := errors.New("create failed")
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return nil, errCreateFailed return nil, errCreateFailed
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(context.Background()) res, err := pool.Get(context.Background())
assert.Equal(t, errCreateFailed, err) assert.Equal(t, errCreateFailed, err)
assert.Nil(t, res) assert.Nil(t, res)
} }
func TestPoolGet_ReusesResources(t *testing.T) { func TestPoolGetReusesResources(t *testing.T) {
var createCalls Counter var createCalls Counter
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(context.Background()) res, err := pool.Get(context.Background())
require.NoError(t, err) require.NoError(t, err)
@@ -111,11 +113,11 @@ func TestPoolGet_ReusesResources(t *testing.T) {
assert.Equal(t, 1, createCalls.Value()) assert.Equal(t, 1, createCalls.Value())
} }
func TestPoolGet_ContextAlreadyCanceled(t *testing.T) { func TestPoolGetContextAlreadyCanceled(t *testing.T) {
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
panic("should never be called") panic("should never be called")
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
@@ -124,7 +126,7 @@ func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
assert.Nil(t, res) assert.Nil(t, res)
} }
func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) { func TestPoolGetContextCanceledDuringCreate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var createCalls Counter var createCalls Counter
@@ -133,29 +135,103 @@ func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
res, err := pool.Get(ctx) res, err := pool.Get(ctx)
assert.Equal(t, context.Canceled, err) assert.Equal(t, context.Canceled, err)
assert.Nil(t, res) assert.Nil(t, res)
} }
func TestPoolReturn_PanicsIfResourceNotPartOfPool(t *testing.T) { func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) {
var createCalls Counter var createCalls Counter
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
assert.Panics(t, func() { pool.Return(42) }) assert.Panics(t, func() { pool.Return(42) })
} }
func TestPoolCloseClosesAllAvailableResources(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
var closeCalls Counter
closeFunc := func(interface{}) error {
closeCalls.Next()
return nil
}
p := pool.New(createFunc, closeFunc)
resources := make([]interface{}, 4)
for i := range resources {
var err error
resources[i], err = p.Get(context.Background())
require.Nil(t, err)
}
for _, res := range resources {
p.Return(res)
}
p.Close()
assert.Equal(t, len(resources), closeCalls.Value())
}
func TestPoolReturnClosesResourcePoolIsAlreadyClosed(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
var closeCalls Counter
closeFunc := func(interface{}) error {
closeCalls.Next()
return nil
}
p := pool.New(createFunc, closeFunc)
resources := make([]interface{}, 4)
for i := range resources {
var err error
resources[i], err = p.Get(context.Background())
require.Nil(t, err)
}
p.Close()
assert.Equal(t, 0, closeCalls.Value())
for _, res := range resources {
p.Return(res)
}
assert.Equal(t, len(resources), closeCalls.Value())
}
func TestPoolGetReturnsErrorWhenPoolIsClosed(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
p := pool.New(createFunc, stubCloseRes)
p.Close()
res, err := p.Get(context.Background())
assert.Equal(t, pool.ErrClosedPool, err)
assert.Nil(t, res)
}
func BenchmarkPoolGetAndReturnNoContention(b *testing.B) { func BenchmarkPoolGetAndReturnNoContention(b *testing.B) {
var createCalls Counter var createCalls Counter
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
res, err := pool.Get(context.Background()) res, err := pool.Get(context.Background())
@@ -174,7 +250,7 @@ func BenchmarkPoolGetAndReturnHeavyContention(b *testing.B) {
createFunc := func() (interface{}, error) { createFunc := func() (interface{}, error) {
return createCalls.Next(), nil return createCalls.Next(), nil
} }
pool := pool.New(createFunc) pool := pool.New(createFunc, stubCloseRes)
pool.SetMaxSize(poolSize) pool.SetMaxSize(poolSize)
doneChan := make(chan struct{}) doneChan := make(chan struct{})