From 201d8561f93bd7f4693b709441ec299e4147efe8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Dec 2018 16:40:42 -0600 Subject: [PATCH] Initial commit --- pool.go | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++ pool_test.go | 104 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 pool.go create mode 100644 pool_test.go diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..6063294 --- /dev/null +++ b/pool.go @@ -0,0 +1,116 @@ +package pool + +import ( + "context" + "sync" +) + +const ( + resourceStatusCreating = 0 + resourceStatusAvailable = iota + resourceStatusBorrowed = iota +) + +type CreateFunc func() (res interface{}, err error) + +type resourceWrapper struct { + resource interface{} + status byte +} + +// Pool is a thread-safe resource pool. +type Pool struct { + cond *sync.Cond + + allResources map[interface{}]*resourceWrapper + availableResources []*resourceWrapper + + create CreateFunc +} + +func New(create CreateFunc) *Pool { + return &Pool{ + cond: sync.NewCond(new(sync.Mutex)), + allResources: make(map[interface{}]*resourceWrapper), + create: create, + } +} + +// Get gets a resource from the pool. If no resources are available and the pool +// is not at maximum capacity it will create a new resource. If the pool is at +// maximum capacity it will block until a resource is available. ctx can be used +// to cancel the Get. +func (p *Pool) Get(ctx context.Context) (interface{}, error) { + if doneChan := ctx.Done(); doneChan != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + + p.cond.L.Lock() + + if len(p.availableResources) > 0 { + rw := p.availableResources[len(p.availableResources)-1] + p.availableResources = p.availableResources[:len(p.availableResources)-1] + if rw.status != resourceStatusAvailable { + panic("BUG: unavailable resource gotten from availableResources") + } + rw.status = resourceStatusBorrowed + p.cond.L.Unlock() + return rw.resource, nil + } + + // if can create resource + + var localVal int + placeholder := &localVal + p.allResources[placeholder] = &resourceWrapper{resource: placeholder, status: resourceStatusCreating} + p.cond.L.Unlock() + + resChan := make(chan interface{}) + errChan := make(chan error) + + go func() { + res, err := p.create() + if err != nil { + errChan <- err + } + resChan <- res + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errChan: + p.cond.L.Lock() + delete(p.allResources, placeholder) + p.cond.L.Unlock() + return nil, err + case res := <-resChan: + p.cond.L.Lock() + delete(p.allResources, placeholder) + p.allResources[res] = &resourceWrapper{resource: res, status: resourceStatusBorrowed} + p.cond.L.Unlock() + return res, nil + } + +} + +// Return returns res to the the pool. If res is not part of the pool Return +// will panic. +func (p *Pool) Return(res interface{}) { + p.cond.L.Lock() + + rw, present := p.allResources[res] + if !present { + p.cond.L.Unlock() + panic("Return called on resource that does not belong to pool") + } + + rw.status = resourceStatusAvailable + p.availableResources = append(p.availableResources, rw) + + p.cond.L.Unlock() +} diff --git a/pool_test.go b/pool_test.go new file mode 100644 index 0000000..2365ad8 --- /dev/null +++ b/pool_test.go @@ -0,0 +1,104 @@ +package pool_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/jackc/pool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPoolGet_CreatesResourceWhenNoneAvailable(t *testing.T) { + createCalls := 0 + createFunc := func() (interface{}, error) { + createCalls += 1 + return createCalls, nil + } + pool := pool.New(createFunc) + + res, err := pool.Get(context.Background()) + require.NoError(t, err) + assert.Equal(t, 1, res) + + pool.Return(res) +} + +func TestPoolGet_ReturnsErrorFromFailedResourceCreate(t *testing.T) { + errCreateFailed := errors.New("create failed") + createFunc := func() (interface{}, error) { + return nil, errCreateFailed + } + pool := pool.New(createFunc) + + res, err := pool.Get(context.Background()) + assert.Equal(t, errCreateFailed, err) + assert.Nil(t, res) +} + +func TestPoolGet_ReusesResources(t *testing.T) { + createCalls := 0 + createFunc := func() (interface{}, error) { + createCalls += 1 + return createCalls, nil + } + pool := pool.New(createFunc) + + res, err := pool.Get(context.Background()) + require.NoError(t, err) + assert.Equal(t, 1, res) + + pool.Return(res) + + res, err = pool.Get(context.Background()) + require.NoError(t, err) + assert.Equal(t, 1, res) + + pool.Return(res) + + assert.Equal(t, 1, createCalls) +} + +func TestPoolGet_ContextAlreadyCanceled(t *testing.T) { + createFunc := func() (interface{}, error) { + panic("should never be called") + } + pool := pool.New(createFunc) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := pool.Get(ctx) + assert.Equal(t, context.Canceled, err) + assert.Nil(t, res) +} + +func TestPoolGet_ContextCanceledDuringCreate(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + createCalls := 0 + + createFunc := func() (interface{}, error) { + cancel() + time.Sleep(1 * time.Second) + createCalls += 1 + return createCalls, nil + } + pool := pool.New(createFunc) + + res, err := pool.Get(ctx) + assert.Equal(t, context.Canceled, err) + assert.Nil(t, res) +} + +func TestPoolReturnPanicsIfResourceNotPartOfPool(t *testing.T) { + createCalls := 0 + createFunc := func() (interface{}, error) { + createCalls += 1 + return createCalls, nil + } + pool := pool.New(createFunc) + + assert.Panics(t, func() { pool.Return(42) }) +}