diff --git a/pool.go b/pool.go index 6063294..374a2b4 100644 --- a/pool.go +++ b/pool.go @@ -11,6 +11,9 @@ const ( resourceStatusBorrowed = iota ) +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) + type CreateFunc func() (res interface{}, err error) type resourceWrapper struct { @@ -24,6 +27,7 @@ type Pool struct { allResources map[interface{}]*resourceWrapper availableResources []*resourceWrapper + maxSize int create CreateFunc } @@ -32,10 +36,37 @@ func New(create CreateFunc) *Pool { return &Pool{ cond: sync.NewCond(new(sync.Mutex)), allResources: make(map[interface{}]*resourceWrapper), + maxSize: maxInt, create: create, } } +// Size returns the current size of the pool. +func (p *Pool) Size() int { + p.cond.L.Lock() + n := len(p.allResources) + p.cond.L.Unlock() + return n +} + +// MaxSize returns the current maximum size of the pool. +func (p *Pool) MaxSize() int { + p.cond.L.Lock() + n := p.maxSize + p.cond.L.Unlock() + return n +} + +// SetMaxSize sets the maximum size of the pool. It panics if n < 1. +func (p *Pool) SetMaxSize(n int) { + if n < 1 { + panic("pool MaxSize cannot be < 1") + } + p.cond.L.Lock() + p.maxSize = n + p.cond.L.Unlock() +} + // Get gets a resource from the pool. If no resources are available and the pool // is not at maximum capacity it will create a new resource. If the pool is at // maximum capacity it will block until a resource is available. ctx can be used @@ -51,51 +82,92 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) { p.cond.L.Lock() + // If a resource is available now if len(p.availableResources) > 0 { - rw := p.availableResources[len(p.availableResources)-1] - p.availableResources = p.availableResources[:len(p.availableResources)-1] - if rw.status != resourceStatusAvailable { - panic("BUG: unavailable resource gotten from availableResources") - } - rw.status = resourceStatusBorrowed - p.cond.L.Unlock() - return rw.resource, nil - } - - // if can create resource - - var localVal int - placeholder := &localVal - p.allResources[placeholder] = &resourceWrapper{resource: placeholder, status: resourceStatusCreating} - p.cond.L.Unlock() - - resChan := make(chan interface{}) - errChan := make(chan error) - - go func() { - res, err := p.create() - if err != nil { - errChan <- err - } - resChan <- res - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case err := <-errChan: - p.cond.L.Lock() - delete(p.allResources, placeholder) - p.cond.L.Unlock() - return nil, err - case res := <-resChan: - p.cond.L.Lock() - delete(p.allResources, placeholder) - p.allResources[res] = &resourceWrapper{resource: res, status: resourceStatusBorrowed} + res := p.lockedAvailableGet() p.cond.L.Unlock() return res, nil } + // If there is room to create a resource start the process asynchronously + var errChan chan error + if len(p.allResources) < p.maxSize { + errChan = p.startCreate() + } + p.cond.L.Unlock() + + // Whether or not we started creating a resource all we can do now is wait. + resChan := make(chan interface{}) + abortChan := make(chan struct{}) + + go func() { + p.cond.L.Lock() + for len(p.availableResources) == 0 { + p.cond.Wait() + } + res := p.lockedAvailableGet() + p.cond.L.Unlock() + + select { + case <-abortChan: + p.Return(res) + case resChan <- res: + } + }() + + select { + case <-ctx.Done(): + close(abortChan) + return nil, ctx.Err() + case err := <-errChan: + close(abortChan) + return nil, err + case res := <-resChan: + return res, nil + } +} + +// lockedAvailableGet gets the top resource from p.availableResources. p.cond.L +// must already be locked. len(p.availableResources) must be > 0. +func (p *Pool) lockedAvailableGet() interface{} { + rw := p.availableResources[len(p.availableResources)-1] + p.availableResources = p.availableResources[:len(p.availableResources)-1] + if rw.status != resourceStatusAvailable { + panic("BUG: unavailable resource gotten from availableResources") + } + rw.status = resourceStatusBorrowed + return rw.resource +} + +// startCreate starts creating a new resource. p.cond.L must already be +// locked. The returned error channel will receive any error returned by create. +func (p *Pool) startCreate() chan error { + // Use a buffered errChan to receive the error so the goroutine doesn't leak if + // the error channel is never read. + errChan := make(chan error, 1) + + var localVal int + placeholder := &localVal + p.allResources[placeholder] = &resourceWrapper{resource: placeholder, status: resourceStatusCreating} + + go func() { + res, err := p.create() + p.cond.L.Lock() + delete(p.allResources, placeholder) + if err != nil { + p.cond.L.Unlock() + errChan <- err + return + } + + rw := &resourceWrapper{resource: res, status: resourceStatusAvailable} + p.allResources[res] = rw + p.availableResources = append(p.availableResources, rw) + p.cond.L.Unlock() + p.cond.Signal() + }() + + return errChan } // Return returns res to the the pool. If res is not part of the pool Return @@ -113,4 +185,5 @@ func (p *Pool) Return(res interface{}) { p.availableResources = append(p.availableResources, rw) p.cond.L.Unlock() + p.cond.Signal() } diff --git a/pool_test.go b/pool_test.go index 2365ad8..0aa0cfe 100644 --- a/pool_test.go +++ b/pool_test.go @@ -3,6 +3,7 @@ package pool_test import ( "context" "errors" + "sync" "testing" "time" @@ -11,11 +12,32 @@ import ( "github.com/stretchr/testify/require" ) +type Counter struct { + mutex sync.Mutex + n int +} + +// Next increments the counter and returns the value +func (c *Counter) Next() int { + c.mutex.Lock() + c.n += 1 + n := c.n + c.mutex.Unlock() + return n +} + +// Value returns the counter +func (c *Counter) Value() int { + c.mutex.Lock() + n := c.n + c.mutex.Unlock() + return n +} + func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) { - createCalls := 0 + var createCalls Counter createFunc := func() (interface{}, error) { - createCalls += 1 - return createCalls, nil + return createCalls.Next(), nil } pool := pool.New(createFunc) @@ -26,6 +48,35 @@ func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) { pool.Return(res) } +func TestPoolGet_DoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { + var createCalls Counter + createFunc := func() (interface{}, error) { + return createCalls.Next(), nil + } + pool := pool.New(createFunc) + pool.SetMaxSize(1) + + wg := &sync.WaitGroup{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + for j := 0; j < 100; j++ { + res, err := pool.Get(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, res) + pool.Return(res) + } + wg.Done() + }() + } + + wg.Wait() + + assert.Equal(t, 1, createCalls.Value()) + assert.Equal(t, 1, pool.Size()) +} + func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) { errCreateFailed := errors.New("create failed") createFunc := func() (interface{}, error) { @@ -39,10 +90,9 @@ func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) { } func TestPoolGet_ReusesResources(t *testing.T) { - createCalls := 0 + var createCalls Counter createFunc := func() (interface{}, error) { - createCalls += 1 - return createCalls, nil + return createCalls.Next(), nil } pool := pool.New(createFunc) @@ -58,7 +108,7 @@ func TestPoolGet_ReusesResources(t *testing.T) { pool.Return(res) - assert.Equal(t, 1, createCalls) + assert.Equal(t, 1, createCalls.Value()) } func TestPoolGet_ContextAlreadyCanceled(t *testing.T) { @@ -77,13 +127,11 @@ func TestPoolGet_ContextAlreadyCanceled(t *testing.T) { func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - createCalls := 0 - + var createCalls Counter createFunc := func() (interface{}, error) { cancel() time.Sleep(1 * time.Second) - createCalls += 1 - return createCalls, nil + return createCalls.Next(), nil } pool := pool.New(createFunc) @@ -92,13 +140,68 @@ func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) { assert.Nil(t, res) } -func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) { - createCalls := 0 +func TestPoolReturn_PanicsIfResourceNotPartOfPool(t *testing.T) { + var createCalls Counter createFunc := func() (interface{}, error) { - createCalls += 1 - return createCalls, nil + return createCalls.Next(), nil } pool := pool.New(createFunc) assert.Panics(t, func() { pool.Return(42) }) } + +func BenchmarkPoolGetAndReturnNoContention(b *testing.B) { + var createCalls Counter + createFunc := func() (interface{}, error) { + return createCalls.Next(), nil + } + pool := pool.New(createFunc) + + for i := 0; i < b.N; i++ { + res, err := pool.Get(context.Background()) + if err != nil { + b.Fatal(err) + } + pool.Return(res) + } +} + +func BenchmarkPoolGetAndReturnHeavyContention(b *testing.B) { + poolSize := 8 + contentionClients := 15 + + var createCalls Counter + createFunc := func() (interface{}, error) { + return createCalls.Next(), nil + } + pool := pool.New(createFunc) + pool.SetMaxSize(poolSize) + + doneChan := make(chan struct{}) + defer close(doneChan) + for i := 0; i < contentionClients; i++ { + go func() { + for { + select { + case <-doneChan: + return + default: + } + + res, err := pool.Get(context.Background()) + if err != nil { + b.Fatal(err) + } + pool.Return(res) + } + }() + } + + for i := 0; i < b.N; i++ { + res, err := pool.Get(context.Background()) + if err != nil { + b.Fatal(err) + } + pool.Return(res) + } +}