diff --git a/pool.go b/pool.go index acbcc33..a413e4d 100644 --- a/pool.go +++ b/pool.go @@ -311,32 +311,79 @@ func (p *Pool[T]) Acquire(ctx context.Context) (*Resource[T], error) { p.destructWG.Add(1) p.cond.L.Unlock() - value, err := p.constructResourceValue(ctx) - p.cond.L.Lock() - if err != nil { - p.allResources = removeResource(p.allResources, res) - p.destructWG.Done() + // we create the resource in the background because the constructor might + // outlive the context and we want to continue constructing it as long as + // necessary but the acquire should be cancelled when the context is cancelled + // see: https://github.com/jackc/pgx/issues/1287 and https://github.com/jackc/pgx/issues/1259 + constructErrCh := make(chan error) + go func() { + value, err := p.constructResourceValue(ctx) + p.cond.L.Lock() + if err != nil { + p.allResources = removeResource(p.allResources, res) + p.destructWG.Done() + select { + case <-ctx.Done(): + if err == ctx.Err() { + p.canceledAcquireCount += 1 + } + default: + } + + p.cond.L.Unlock() + p.cond.Signal() + + // try to notify the caller that we failed + select { + case constructErrCh <- err: + default: + } + return + } + res.value = value + + // check the context now so we don't increment the metrics when the caller + // has already been cancelled select { case <-ctx.Done(): - if err == ctx.Err() { - p.canceledAcquireCount += 1 - } + p.cond.L.Unlock() default: + // assume that we will acquire it + res.status = resourceStatusAcquired + // we have to increment these BEFORE the next select because otherwise + // they could run after Acquire has returned and that will mess up + // tests but this also means that these could be incremented even if + // the acquire times out, but we just checked that so the chances are + // slim + p.emptyAcquireCount += 1 + p.acquireCount += 1 + p.acquireDuration += time.Duration(nanotime() - startNano) + p.cond.L.Unlock() + // we don't call Signal here we didn't change any of the resopurce pools } - p.cond.L.Unlock() - p.cond.Signal() - return nil, err - } + select { + case constructErrCh <- nil: + default: + // since we couldn't send the constructed resource to the acquire + // function that means the caller has stopped waiting and we should + // just put this resource back in the pool + p.releaseAcquiredResource(res, res.lastUsedNano) + } + }() - res.value = value - res.status = resourceStatusAcquired - p.emptyAcquireCount += 1 - p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) - p.cond.L.Unlock() - return res, nil + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-constructErrCh: + if err != nil { + return nil, err + } + // we don't call signal here because we didn't change the resource pools + // at all so waking anything else up won't help + return res, nil + } } if ctx.Done() == nil { @@ -355,8 +402,8 @@ func (p *Pool[T]) Acquire(ctx context.Context) (*Resource[T], error) { // do anything with it. Another goroutine might be waiting. go func() { <-waitChan - p.cond.Signal() p.cond.L.Unlock() + p.cond.Signal() }() p.cond.L.Lock() diff --git a/pool_test.go b/pool_test.go index 5a0f645..d7e82b7 100644 --- a/pool_test.go +++ b/pool_test.go @@ -133,6 +133,32 @@ func TestPoolAcquireReturnsErrorFromFailedResourceCreate(t *testing.T) { assert.Nil(t, res) } +func TestPoolAcquireCreatesResourceRespectingContext(t *testing.T) { + var cancel func() + constructor := func(ctx context.Context) (int, error) { + cancel() + // sleep to give a chance for the acquire to recognize it's cancelled + time.Sleep(10 * time.Millisecond) + return 1, nil + } + pool := puddle.NewPool(constructor, stubDestructor, 1) + defer pool.Close() + + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + _, err := pool.Acquire(ctx) + assert.ErrorIs(t, err, context.Canceled) + + // wait for the constructor to sleep and then for the resource to be added back + // to the idle pool + time.Sleep(100 * time.Millisecond) + + stat := pool.Stat() + assert.EqualValues(t, 1, stat.IdleResources()) + assert.EqualValues(t, 1, stat.TotalResources()) +} + func TestPoolAcquireReusesResources(t *testing.T) { constructor, createCounter := createConstructor() pool := puddle.NewPool(constructor, stubDestructor, 10) @@ -544,7 +570,12 @@ func TestPoolStatResources(t *testing.T) { func TestPoolStatSuccessfulAcquireCounters(t *testing.T) { constructor, _ := createConstructor() - pool := puddle.NewPool(constructor, stubDestructor, 1) + sleepConstructor := func(ctx context.Context) (int, error) { + // sleep to make sure we don't fail the AcquireDuration test + time.Sleep(time.Nanosecond) + return constructor(ctx) + } + pool := puddle.NewPool(sleepConstructor, stubDestructor, 1) defer pool.Close() res, err := pool.Acquire(context.Background())