diff --git a/export_test.go b/export_test.go index e977402..36e8df6 100644 --- a/export_test.go +++ b/export_test.go @@ -5,3 +5,5 @@ import "context" func (p *Pool[T]) AcquireRaw(ctx context.Context) (*Resource[T], error) { return p.acquire(ctx) } + +var AcquireSemAll = acquireSemAll diff --git a/log_test.go b/log_test.go index a3127cb..d425313 100644 --- a/log_test.go +++ b/log_test.go @@ -12,6 +12,7 @@ import ( func TestLog2Uint(t *testing.T) { r := require.New(t) + r.Equal(uint8(0), log2Int(1)) r.Equal(uint8(0), log2Int[uint64](1)) r.Equal(uint8(1), log2Int[uint32](2)) r.Equal(uint8(7), log2Int[uint8](math.MaxUint8)) diff --git a/pool.go b/pool.go index 0a1cfee..bdbca59 100644 --- a/pool.go +++ b/pool.go @@ -115,10 +115,24 @@ func (res *Resource[T]) IdleDuration() time.Duration { // Pool is a concurrency-safe resource pool. type Pool[T any] struct { - mux sync.Mutex - acquireSem *semaphore.Weighted - destructWG sync.WaitGroup + // Pool invariant is that semaphore is locked before mutex (doesn't + // apply to TryAcquire in AcquireAllIdle). Another invariant is that + // semaphore has to be released BEFORE unlock of mutex! + // mux is the pool internal lock. Any modification of shared state of + // the pool (but Acquires of acquireSem) must be performed only by + // holder of the lock. Long running operations are not allowed when mux + // is held. + mux sync.Mutex + // acquireSem provides an allowance to TRY to acquire a resource. The + // acquire of semaphore token doesn't guarantee that an attempt to + // acquire the resource will succeed. + // + // Releases are allowed only when caller holds mux. Acquires have to + // happen before mux is locked. + acquireSem *semaphore.Weighted + + destructWG sync.WaitGroup allResources resList[T] idleResources resList[T] @@ -347,8 +361,8 @@ func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { p.mux.Lock() if p.closed { - p.mux.Unlock() p.acquireSem.Release(1) + p.mux.Unlock() return nil, ErrClosedPool } @@ -365,8 +379,6 @@ func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { if len(p.allResources) > int(p.maxSize) { // Unreachable code. - p.mux.Unlock() - p.acquireSem.Release(1) panic("bug: semaphore allowed more acquires than pool allows") } @@ -422,11 +434,11 @@ func (p *Pool[T]) initResourceValue(ctx context.Context, res *Resource[T]) (*Res p.allResources.remove(res) p.destructWG.Done() - p.mux.Unlock() // The resource won't be acquired because its // construction failed. We have to allow someone else to // take that resouce. p.acquireSem.Release(1) + p.mux.Unlock() select { case constructErrChan <- err: @@ -503,12 +515,13 @@ func (p *Pool[T]) TryAcquire(ctx context.Context) (*Resource[T], error) { res := p.createNewResource() go func() { value, err := p.constructor(ctx) + + p.mux.Lock() + defer p.mux.Unlock() // We have to create the resource and only then release the // semaphore - For the time being there is no resource that // someone could acquire. defer p.acquireSem.Release(1) - p.mux.Lock() - defer p.mux.Unlock() if err != nil { p.allResources.remove(res) @@ -524,18 +537,18 @@ func (p *Pool[T]) TryAcquire(ctx context.Context) (*Resource[T], error) { return nil, ErrNotAvailable } -// acquireSemAll acquires all free tokens from sem. This function is guaranteed -// to acquire at least the lowest number of tokens that has been available in -// the semaphore during runtime of this function. +// acquireSemAll tries to acquire num free tokens from sem. This function is +// guaranteed to acquire at least the lowest number of tokens that has been +// available in the semaphore during runtime of this function. // // For the time being, semaphore doesn't allow to acquire all tokens atomically // (see https://github.com/golang/sync/pull/19). We simulate this by trying all // powers of 2 that are less or equal to num. // // For example, let's immagine we have 19 free tokens in the semaphore which in -// total has 24 tokens (i.e. the maxSize of the pool is 24 resources). Then num -// is 24, the log2Uint(24) is 4 and we try to acquire 16, 8, 4, 2 and 1 tokens. -// Out of those, the acquire of 16, 2 and 1 tokens will succeed. +// total has 24 tokens (i.e. the maxSize of the pool is 24 resources). Then if +// num is 24, the log2Uint(24) is 4 and we try to acquire 16, 8, 4, 2 and 1 +// tokens. Out of those, the acquire of 16, 2 and 1 tokens will succeed. // // Naturally, Acquires and Releases of the semaphore might take place // concurrently. For this reason, it's not guaranteed that absolutely all free @@ -550,7 +563,7 @@ func acquireSemAll(sem *semaphore.Weighted, num int) int { var acquired int for i := int(log2Int(num)); i >= 0; i-- { - val := int(1) << i + val := 1 << i if sem.TryAcquire(int64(val)) { acquired += val } @@ -579,7 +592,7 @@ func (p *Pool[T]) AcquireAllIdle() []*Resource[T] { // TryAcquire cannot block, the fact that we hold mutex locked and try // to acquire semaphore cannot result in dead-lock. // - // TODO: Replace this with acquireSem.TryAcqireAll() if it gets to + // TODO: Replace this with acquireSem.TryAcquireAll() if it gets to // upstream. https://github.com/golang/sync/pull/19 _ = acquireSemAll(p.acquireSem, numIdle) @@ -654,9 +667,9 @@ func (p *Pool[T]) Reset() { // releaseAcquiredResource returns res to the the pool. func (p *Pool[T]) releaseAcquiredResource(res *Resource[T], lastUsedNano int64) { - defer p.acquireSem.Release(1) p.mux.Lock() defer p.mux.Unlock() + defer p.acquireSem.Release(1) if p.closed || res.poolResetCount != p.resetCount { p.allResources.remove(res) @@ -672,16 +685,16 @@ func (p *Pool[T]) releaseAcquiredResource(res *Resource[T], lastUsedNano int64) // pool Remove will panic. func (p *Pool[T]) destroyAcquiredResource(res *Resource[T]) { p.destructResourceValue(res.value) - defer p.acquireSem.Release(1) p.mux.Lock() defer p.mux.Unlock() + defer p.acquireSem.Release(1) p.allResources.remove(res) } func (p *Pool[T]) hijackAcquiredResource(res *Resource[T]) { - defer p.acquireSem.Release(1) p.mux.Lock() defer p.mux.Unlock() + defer p.acquireSem.Release(1) p.allResources.remove(res) res.status = resourceStatusHijacked diff --git a/pool_test.go b/pool_test.go index 40c987a..45ee822 100644 --- a/pool_test.go +++ b/pool_test.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/puddle/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/semaphore" ) type Counter struct { @@ -894,6 +895,17 @@ func TestSignalIsSentWhenResourceFailedToCreate(t *testing.T) { wg.Wait() } +func stressTestDur(t testing.TB) time.Duration { + s := os.Getenv("STRESS_TEST_DURATION") + if s == "" { + s = "1s" + } + + dur, err := time.ParseDuration(s) + require.Nil(t, err) + return dur +} + func TestStress(t *testing.T) { constructor, _ := createConstructor() var destructorCalls Counter @@ -980,10 +992,10 @@ func TestStress(t *testing.T) { for i := 0; i < workerCount; i++ { wg.Add(1) go func() { + defer wg.Done() for { select { case <-finishChan: - wg.Done() return default: } @@ -993,18 +1005,116 @@ func TestStress(t *testing.T) { }() } - s := os.Getenv("STRESS_TEST_DURATION") - if s == "" { - s = "1s" - } - testDuration, err := time.ParseDuration(s) - require.Nil(t, err) - time.AfterFunc(testDuration, func() { close(finishChan) }) + time.AfterFunc(stressTestDur(t), func() { close(finishChan) }) wg.Wait() - pool.Close() } +func TestStress_AcquireAllIdle_TryAcquire(t *testing.T) { + r := require.New(t) + + pool := testPool[int32](t) + + var wg sync.WaitGroup + done := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + idleRes := pool.AcquireAllIdle() + r.Less(len(idleRes), 2) + for _, res := range idleRes { + res.Release() + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + res, err := pool.TryAcquire(context.Background()) + if err != nil { + r.Equal(puddle.ErrNotAvailable, err) + } else { + r.NotNil(res) + res.Release() + } + } + }() + + time.AfterFunc(stressTestDur(t), func() { close(done) }) + wg.Wait() +} + +func TestStress_AcquireAllIdle_Acquire(t *testing.T) { + r := require.New(t) + + pool := testPool[int32](t) + + var wg sync.WaitGroup + done := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + idleRes := pool.AcquireAllIdle() + r.Less(len(idleRes), 2) + for _, res := range idleRes { + r.NotNil(res) + res.Release() + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + res, err := pool.Acquire(context.Background()) + if err != nil { + r.Equal(puddle.ErrNotAvailable, err) + } else { + r.NotNil(res) + res.Release() + } + } + }() + + time.AfterFunc(stressTestDur(t), func() { close(done) }) + wg.Wait() +} + func startAcceptOnceDummyServer(laddr string) { ln, err := net.Listen("tcp", laddr) if err != nil { @@ -1170,7 +1280,22 @@ func BenchmarkPoolAcquireAndRelease(b *testing.B) { }) } } -func benchmarkPool[T any](t testing.TB) *puddle.Pool[T] { + +func TestAcquireAllSem(t *testing.T) { + r := require.New(t) + + sem := semaphore.NewWeighted(5) + r.Equal(4, puddle.AcquireSemAll(sem, 4)) + sem.Release(4) + + r.Equal(5, puddle.AcquireSemAll(sem, 5)) + sem.Release(5) + + r.Equal(5, puddle.AcquireSemAll(sem, 6)) + sem.Release(5) +} + +func testPool[T any](t testing.TB) *puddle.Pool[T] { cfg := puddle.Config[T]{ MaxSize: 1, Constructor: func(ctx context.Context) (T, error) { @@ -1210,7 +1335,7 @@ func TestReleaseAfterAcquire(t *testing.T) { r := require.New(t) ctx := context.Background() - pool := benchmarkPool[int32](t) + pool := testPool[int32](t) releaseChan := releaser[int32](t) res, err := pool.Acquire(ctx) @@ -1229,7 +1354,7 @@ func TestReleaseAfterAcquire(t *testing.T) { func BenchmarkAcquire_ReleaseAfterAcquire(b *testing.B) { r := require.New(b) ctx := context.Background() - pool := benchmarkPool[int32](b) + pool := testPool[int32](b) releaseChan := releaser[int32](b) res, err := pool.Acquire(ctx) @@ -1268,7 +1393,7 @@ func withCPULoad() { func BenchmarkAcquire_ReleaseAfterAcquireWithCPULoad(b *testing.B) { r := require.New(b) ctx := context.Background() - pool := benchmarkPool[int32](b) + pool := testPool[int32](b) releaseChan := releaser[int32](b) withCPULoad() @@ -1292,7 +1417,7 @@ func BenchmarkAcquire_MultipleCancelled(b *testing.B) { r := require.New(b) ctx := context.Background() - pool := benchmarkPool[int32](b) + pool := testPool[int32](b) releaseChan := releaser[int32](b) cancelCtx, cancel := context.WithCancel(ctx) @@ -1322,7 +1447,7 @@ func BenchmarkAcquire_MultipleCancelledWithCPULoad(b *testing.B) { r := require.New(b) ctx := context.Background() - pool := benchmarkPool[int32](b) + pool := testPool[int32](b) releaseChan := releaser[int32](b) cancelCtx, cancel := context.WithCancel(ctx) diff --git a/resource_list.go b/resource_list.go index 9938bfb..b243095 100644 --- a/resource_list.go +++ b/resource_list.go @@ -16,9 +16,10 @@ func (l *resList[T]) popBack() *Resource[T] { func (l *resList[T]) remove(val *Resource[T]) { for i, elem := range *l { if elem == val { - (*l)[i] = (*l)[len(*l)-1] - (*l)[len(*l)-1] = nil // Avoid memory leak - (*l) = (*l)[:len(*l)-1] + lastIdx := len(*l) - 1 + (*l)[i] = (*l)[lastIdx] + (*l)[lastIdx] = nil // Avoid memory leak + (*l) = (*l)[:lastIdx] return } } diff --git a/resource_list_test.go b/resource_list_test.go index f73b78d..7104189 100644 --- a/resource_list_test.go +++ b/resource_list_test.go @@ -4,8 +4,49 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func TestResList_Append(t *testing.T) { + r := require.New(t) + + arr := []*Resource[any]{ + new(Resource[any]), + new(Resource[any]), + new(Resource[any]), + } + + list := resList[any](arr) + + list.append(new(Resource[any])) + r.Len(list, 4) + list.append(new(Resource[any])) + r.Len(list, 5) + list.append(new(Resource[any])) + r.Len(list, 6) +} + +func TestResList_PopBack(t *testing.T) { + r := require.New(t) + + arr := []*Resource[any]{ + new(Resource[any]), + new(Resource[any]), + new(Resource[any]), + } + + list := resList[any](arr) + + list.popBack() + r.Len(list, 2) + list.popBack() + r.Len(list, 1) + list.popBack() + r.Len(list, 0) + + r.Panics(func() { list.popBack() }) +} + func TestResList_PanicsWithBugReportIfResourceDoesNotExist(t *testing.T) { arr := []*Resource[any]{ new(Resource[any]),