2
0

Add pool max size

This commit is contained in:
Jack Christensen
2018-12-22 18:21:18 -06:00
parent 201d8561f9
commit f39c666932
2 changed files with 231 additions and 55 deletions
+113 -40
View File
@@ -11,6 +11,9 @@ const (
resourceStatusBorrowed = iota
)
const maxUint = ^uint(0)
const maxInt = int(maxUint >> 1)
type CreateFunc func() (res interface{}, err error)
type resourceWrapper struct {
@@ -24,6 +27,7 @@ type Pool struct {
allResources map[interface{}]*resourceWrapper
availableResources []*resourceWrapper
maxSize int
create CreateFunc
}
@@ -32,10 +36,37 @@ func New(create CreateFunc) *Pool {
return &Pool{
cond: sync.NewCond(new(sync.Mutex)),
allResources: make(map[interface{}]*resourceWrapper),
maxSize: maxInt,
create: create,
}
}
// Size returns the current size of the pool.
func (p *Pool) Size() int {
p.cond.L.Lock()
n := len(p.allResources)
p.cond.L.Unlock()
return n
}
// MaxSize returns the current maximum size of the pool.
func (p *Pool) MaxSize() int {
p.cond.L.Lock()
n := p.maxSize
p.cond.L.Unlock()
return n
}
// SetMaxSize sets the maximum size of the pool. It panics if n < 1.
func (p *Pool) SetMaxSize(n int) {
if n < 1 {
panic("pool MaxSize cannot be < 1")
}
p.cond.L.Lock()
p.maxSize = n
p.cond.L.Unlock()
}
// Get gets a resource from the pool. If no resources are available and the pool
// is not at maximum capacity it will create a new resource. If the pool is at
// maximum capacity it will block until a resource is available. ctx can be used
@@ -51,51 +82,92 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) {
p.cond.L.Lock()
// If a resource is available now
if len(p.availableResources) > 0 {
rw := p.availableResources[len(p.availableResources)-1]
p.availableResources = p.availableResources[:len(p.availableResources)-1]
if rw.status != resourceStatusAvailable {
panic("BUG: unavailable resource gotten from availableResources")
}
rw.status = resourceStatusBorrowed
p.cond.L.Unlock()
return rw.resource, nil
}
// if can create resource
var localVal int
placeholder := &localVal
p.allResources[placeholder] = &resourceWrapper{resource: placeholder, status: resourceStatusCreating}
p.cond.L.Unlock()
resChan := make(chan interface{})
errChan := make(chan error)
go func() {
res, err := p.create()
if err != nil {
errChan <- err
}
resChan <- res
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errChan:
p.cond.L.Lock()
delete(p.allResources, placeholder)
p.cond.L.Unlock()
return nil, err
case res := <-resChan:
p.cond.L.Lock()
delete(p.allResources, placeholder)
p.allResources[res] = &resourceWrapper{resource: res, status: resourceStatusBorrowed}
res := p.lockedAvailableGet()
p.cond.L.Unlock()
return res, nil
}
// If there is room to create a resource start the process asynchronously
var errChan chan error
if len(p.allResources) < p.maxSize {
errChan = p.startCreate()
}
p.cond.L.Unlock()
// Whether or not we started creating a resource all we can do now is wait.
resChan := make(chan interface{})
abortChan := make(chan struct{})
go func() {
p.cond.L.Lock()
for len(p.availableResources) == 0 {
p.cond.Wait()
}
res := p.lockedAvailableGet()
p.cond.L.Unlock()
select {
case <-abortChan:
p.Return(res)
case resChan <- res:
}
}()
select {
case <-ctx.Done():
close(abortChan)
return nil, ctx.Err()
case err := <-errChan:
close(abortChan)
return nil, err
case res := <-resChan:
return res, nil
}
}
// lockedAvailableGet gets the top resource from p.availableResources. p.cond.L
// must already be locked. len(p.availableResources) must be > 0.
func (p *Pool) lockedAvailableGet() interface{} {
rw := p.availableResources[len(p.availableResources)-1]
p.availableResources = p.availableResources[:len(p.availableResources)-1]
if rw.status != resourceStatusAvailable {
panic("BUG: unavailable resource gotten from availableResources")
}
rw.status = resourceStatusBorrowed
return rw.resource
}
// startCreate starts creating a new resource. p.cond.L must already be
// locked. The returned error channel will receive any error returned by create.
func (p *Pool) startCreate() chan error {
// Use a buffered errChan to receive the error so the goroutine doesn't leak if
// the error channel is never read.
errChan := make(chan error, 1)
var localVal int
placeholder := &localVal
p.allResources[placeholder] = &resourceWrapper{resource: placeholder, status: resourceStatusCreating}
go func() {
res, err := p.create()
p.cond.L.Lock()
delete(p.allResources, placeholder)
if err != nil {
p.cond.L.Unlock()
errChan <- err
return
}
rw := &resourceWrapper{resource: res, status: resourceStatusAvailable}
p.allResources[res] = rw
p.availableResources = append(p.availableResources, rw)
p.cond.L.Unlock()
p.cond.Signal()
}()
return errChan
}
// Return returns res to the the pool. If res is not part of the pool Return
@@ -113,4 +185,5 @@ func (p *Pool) Return(res interface{}) {
p.availableResources = append(p.availableResources, rw)
p.cond.L.Unlock()
p.cond.Signal()
}
+118 -15
View File
@@ -3,6 +3,7 @@ package pool_test
import (
"context"
"errors"
"sync"
"testing"
"time"
@@ -11,11 +12,32 @@ import (
"github.com/stretchr/testify/require"
)
type Counter struct {
mutex sync.Mutex
n int
}
// Next increments the counter and returns the value
func (c *Counter) Next() int {
c.mutex.Lock()
c.n += 1
n := c.n
c.mutex.Unlock()
return n
}
// Value returns the counter
func (c *Counter) Value() int {
c.mutex.Lock()
n := c.n
c.mutex.Unlock()
return n
}
func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
createCalls := 0
var createCalls Counter
createFunc := func() (interface{}, error) {
createCalls += 1
return createCalls, nil
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
@@ -26,6 +48,35 @@ func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) {
pool.Return(res)
}
func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool.SetMaxSize(1)
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
for j := 0; j < 100; j++ {
res, err := pool.Get(context.Background())
assert.NoError(t, err)
assert.Equal(t, 1, res)
pool.Return(res)
}
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, 1, createCalls.Value())
assert.Equal(t, 1, pool.Size())
}
func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) {
errCreateFailed := errors.New("create failed")
createFunc := func() (interface{}, error) {
@@ -39,10 +90,9 @@ func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) {
}
func TestPoolGet_ReusesResources(t *testing.T) {
createCalls := 0
var createCalls Counter
createFunc := func() (interface{}, error) {
createCalls += 1
return createCalls, nil
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
@@ -58,7 +108,7 @@ func TestPoolGet_ReusesResources(t *testing.T) {
pool.Return(res)
assert.Equal(t, 1, createCalls)
assert.Equal(t, 1, createCalls.Value())
}
func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
@@ -77,13 +127,11 @@ func TestPoolGet_ContextAlreadyCanceled(t *testing.T) {
func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
createCalls := 0
var createCalls Counter
createFunc := func() (interface{}, error) {
cancel()
time.Sleep(1 * time.Second)
createCalls += 1
return createCalls, nil
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
@@ -92,13 +140,68 @@ func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) {
assert.Nil(t, res)
}
func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) {
createCalls := 0
func TestPoolReturn_PanicsIfResourceNotPartOfPool(t *testing.T) {
var createCalls Counter
createFunc := func() (interface{}, error) {
createCalls += 1
return createCalls, nil
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
assert.Panics(t, func() { pool.Return(42) })
}
func BenchmarkPoolGetAndReturnNoContention(b *testing.B) {
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
for i := 0; i < b.N; i++ {
res, err := pool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
pool.Return(res)
}
}
func BenchmarkPoolGetAndReturnHeavyContention(b *testing.B) {
poolSize := 8
contentionClients := 15
var createCalls Counter
createFunc := func() (interface{}, error) {
return createCalls.Next(), nil
}
pool := pool.New(createFunc)
pool.SetMaxSize(poolSize)
doneChan := make(chan struct{})
defer close(doneChan)
for i := 0; i < contentionClients; i++ {
go func() {
for {
select {
case <-doneChan:
return
default:
}
res, err := pool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
pool.Return(res)
}
}()
}
for i := 0; i < b.N; i++ {
res, err := pool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
pool.Return(res)
}
}