diff --git a/pool.go b/pool.go index c368540..a5f2b40 100644 --- a/pool.go +++ b/pool.go @@ -20,7 +20,7 @@ const maxInt = int(maxUint >> 1) // ErrClosedPool occurs on an attempt to get a connection from a closed pool. var ErrClosedPool = errors.New("cannot get from closed pool") -type CreateFunc func() (res interface{}, err error) +type CreateFunc func(ctx context.Context) (res interface{}, err error) type CloseFunc func(res interface{}) (err error) // BackgroundErrorHandler is the type of function that handles background @@ -109,8 +109,6 @@ func (p *Pool) SetMinSize(n int) { p.cond.L.Lock() p.minSize = n - p.ensureMinResources() - p.cond.L.Unlock() } @@ -206,11 +204,26 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) { return res, nil } - // If there is room to create a resource start the process asynchronously - var createResChan chan interface{} - var createErrChan chan error + // If there is room to create a resource do so if len(p.allResources) < p.maxSize { - createResChan, createErrChan = p.startCreate() + var localVal int + placeholder := &localVal + startTime := time.Now() + p.allResources[placeholder] = &resourceWrapper{resource: placeholder, creationTime: startTime, status: resourceStatusCreating} + p.cond.L.Unlock() + + res, err := p.createRes(ctx) + p.cond.L.Lock() + delete(p.allResources, placeholder) + if err != nil { + p.cond.L.Unlock() + return nil, err + } + + rw := &resourceWrapper{resource: res, creationTime: startTime, status: resourceStatusBorrowed, checkoutCount: 1} + p.allResources[res] = rw + p.cond.L.Unlock() + return res, nil } p.cond.L.Unlock() @@ -239,16 +252,8 @@ func (p *Pool) Get(ctx context.Context) (interface{}, error) { select { case <-ctx.Done(): close(abortWaitResChan) - p.backgroundFinishCreate(createResChan, createErrChan) return nil, ctx.Err() - case err := <-createErrChan: - close(abortWaitResChan) - return nil, err - case res := <-createResChan: - close(abortWaitResChan) - return res, nil case res := <-waitResChan: - p.backgroundFinishCreate(createResChan, createErrChan) return res, nil } } @@ -277,50 +282,6 @@ func (p *Pool) lockedAvailableGet() interface{} { return rw.resource } -// startCreate starts creating a new resource. p.cond.L must already be -// locked. The newly created resource will be sent on resChan (already checked -// out) or an error will be sent on errChan. -func (p *Pool) startCreate() (resChan chan interface{}, errChan chan error) { - resChan = make(chan interface{}) - errChan = make(chan error) - - var localVal int - placeholder := &localVal - startTime := time.Now() - p.allResources[placeholder] = &resourceWrapper{resource: placeholder, creationTime: startTime, status: resourceStatusCreating} - - go func() { - res, err := p.createRes() - p.cond.L.Lock() - delete(p.allResources, placeholder) - if err != nil { - p.cond.L.Unlock() - errChan <- err - return - } - - rw := &resourceWrapper{resource: res, creationTime: startTime, status: resourceStatusBorrowed, checkoutCount: 1} - p.allResources[res] = rw - p.cond.L.Unlock() - resChan <- res - }() - - return resChan, errChan -} - -func (p *Pool) backgroundFinishCreate(resChan chan interface{}, errChan chan error) { - go func() { - select { - case res := <-resChan: - p.Return(res) - case err := <-errChan: - p.cond.L.Lock() - p.backgroundErrorHandler(err) - p.cond.L.Unlock() - } - }() -} - func (p *Pool) backgroundClose(res interface{}) { go func() { err := p.closeRes(res) @@ -359,7 +320,6 @@ func (p *Pool) Return(res interface{}) { if closeResource { delete(p.allResources, rw.resource) - p.ensureMinResources() p.cond.L.Unlock() p.backgroundClose(rw.resource) return @@ -394,17 +354,4 @@ func (p *Pool) Remove(res interface{}) { p.cond.L.Unlock() } }() - - p.ensureMinResources() -} - -// ensureMinResources creates new resources if necessary to get pool up to min size. -// If pool is closed does nothing. p.cond.L must already be locked. -func (p *Pool) ensureMinResources() { - if !p.closed { - for len(p.allResources) < p.minSize { - createResChan, createErrChan := p.startCreate() - p.backgroundFinishCreate(createResChan, createErrChan) - } - } } diff --git a/pool_test.go b/pool_test.go index 563c54d..99ed031 100644 --- a/pool_test.go +++ b/pool_test.go @@ -37,7 +37,7 @@ func (c *Counter) Value() int { func createCreateResourceFunc() (puddle.CreateFunc, *Counter) { var c Counter - f := func() (interface{}, error) { + f := func(ctx context.Context) (interface{}, error) { return c.Next(), nil } return f, &c @@ -46,7 +46,7 @@ func createCreateResourceFunc() (puddle.CreateFunc, *Counter) { func createCreateResourceFuncWithNotifierChan() (puddle.CreateFunc, *Counter, chan int) { ch := make(chan int) var c Counter - f := func() (interface{}, error) { + f := func(ctx context.Context) (interface{}, error) { n := c.Next() // Because the tests will not read from ch until after the create function f returns. @@ -93,13 +93,6 @@ func TestPoolGetCreatesResourceWhenNoneAvailable(t *testing.T) { pool.Return(res) } -func TestPoolSetMinSizeImmediatelyCreatesNewResources(t *testing.T) { - createFunc, _ := createCreateResourceFunc() - pool := puddle.NewPool(createFunc, stubCloseRes) - pool.SetMinSize(2) - assert.Equal(t, 2, pool.Size()) -} - func TestPoolGetDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { createFunc, createCounter := createCreateResourceFunc() pool := puddle.NewPool(createFunc, stubCloseRes) @@ -128,7 +121,7 @@ func TestPoolGetDoesNotCreatesResourceWhenItWouldExceedMaxSize(t *testing.T) { func TestPoolGetReturnsErrorFromFailedResourceCreate(t *testing.T) { errCreateFailed := errors.New("create failed") - createFunc := func() (interface{}, error) { + createFunc := func(ctx context.Context) (interface{}, error) { return nil, errCreateFailed } pool := puddle.NewPool(createFunc, stubCloseRes) @@ -158,7 +151,7 @@ func TestPoolGetReusesResources(t *testing.T) { } func TestPoolGetContextAlreadyCanceled(t *testing.T) { - createFunc := func() (interface{}, error) { + createFunc := func(ctx context.Context) (interface{}, error) { panic("should never be called") } pool := puddle.NewPool(createFunc, stubCloseRes) @@ -172,11 +165,16 @@ func TestPoolGetContextAlreadyCanceled(t *testing.T) { func TestPoolGetContextCanceledDuringCreate(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(100*time.Millisecond, cancel) + timeoutChan := time.After(1 * time.Second) var createCalls Counter - createFunc := func() (interface{}, error) { - cancel() - time.Sleep(1 * time.Second) + createFunc := func(ctx context.Context) (interface{}, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timeoutChan: + } return createCalls.Next(), nil } pool := puddle.NewPool(createFunc, stubCloseRes) @@ -306,47 +304,6 @@ func TestPoolRemoveRemovesResourceFromPool(t *testing.T) { assert.Equal(t, 0, pool.Size()) } -func TestPoolRemoveRemovesResourceFromPoolAndStartsNewCreationToMaintainMinSize(t *testing.T) { - createFunc, createCounter, createCallsChan := createCreateResourceFuncWithNotifierChan() - closeFunc, closeCalls, closeCallsChan := createCloseResourceFuncWithNotifierChan() - - pool := puddle.NewPool(createFunc, closeFunc) - - // Ensure there are 2 resources available in pool - { - r1, err := pool.Get(context.Background()) - require.Nil(t, err) - r2, err := pool.Get(context.Background()) - require.Nil(t, err) - pool.Return(r1) - pool.Return(r2) - } - - assert.Equal(t, 2, pool.Size()) - pool.SetMinSize(2) - assert.Equal(t, 2, pool.Size()) - - { - r1, err := pool.Get(context.Background()) - require.Nil(t, err) - r2, err := pool.Get(context.Background()) - require.Nil(t, err) - pool.Remove(r1) - pool.Remove(r2) - } - - require.True(t, waitForRead(createCallsChan)) - require.True(t, waitForRead(createCallsChan)) - require.True(t, waitForRead(createCallsChan)) - require.True(t, waitForRead(createCallsChan)) - require.True(t, waitForRead(closeCallsChan)) - require.True(t, waitForRead(closeCallsChan)) - - assert.Equal(t, 2, pool.Size()) - assert.Equal(t, 4, createCounter.Value()) - assert.Equal(t, 2, closeCalls.Value()) -} - func TestPoolRemoveRemovesResourceFromPoolAndDoesNotStartNewCreationToMaintainMinSizeWhenPoolIsClosed(t *testing.T) { createFunc, createCounter, createCallsChan := createCreateResourceFuncWithNotifierChan() closeFunc, closeCalls, closeCallsChan := createCloseResourceFuncWithNotifierChan() @@ -399,47 +356,6 @@ func TestPoolGetReturnsErrorWhenPoolIsClosed(t *testing.T) { assert.Nil(t, res) } -func TestPoolGetLateFailedCreateErrorIsReported(t *testing.T) { - errCreateStartedChan := make(chan struct{}) - createWaitChan := make(chan struct{}) - errCreateFailed := errors.New("create failed") - var createCalls Counter - createFunc := func() (interface{}, error) { - n := createCalls.Next() - if n == 1 { - return n, nil - } - close(errCreateStartedChan) - <-createWaitChan - return nil, errCreateFailed - } - pool := puddle.NewPool(createFunc, stubCloseRes) - - asyncErrChan := make(chan error) - pool.SetBackgroundErrorHandler(func(err error) { asyncErrChan <- err }) - - res1, err := pool.Get(context.Background()) - require.NoError(t, err) - assert.Equal(t, 1, res1) - - go func() { - <-errCreateStartedChan - pool.Return(res1) - }() - - res, err := pool.Get(context.Background()) - require.NoError(t, err) - assert.Equal(t, 1, res) - close(createWaitChan) - - select { - case err = <-asyncErrChan: - assert.Equal(t, errCreateFailed, err) - case <-time.NewTimer(time.Second).C: - t.Fatal("timed out waiting for async error") - } -} - func TestPoolCloseResourceCloseErrorIsReported(t *testing.T) { createFunc, _ := createCreateResourceFunc() errCloseFailed := errors.New("close failed")