From 78f498fc43f957b2eccdac1d002798ee3c277a5c Mon Sep 17 00:00:00 2001 From: Kale Blankenship Date: Sat, 31 Aug 2019 10:27:19 -0700 Subject: [PATCH] Add ConnPool.AcquireEx --- conn_pool.go | 19 ++++++ conn_pool_test.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) diff --git a/conn_pool.go b/conn_pool.go index 344f00d7..95e1b015 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -110,6 +110,25 @@ func (p *ConnPool) Acquire() (*Conn, error) { return c, err } +func (p *ConnPool) AcquireEx(ctx context.Context) (*Conn, error) { + var deadline *time.Time + + if p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + ctxDeadline, ok := ctx.Deadline() + if ok && (deadline == nil || ctxDeadline.Before(*deadline)) { + deadline = &ctxDeadline + } + + p.cond.L.Lock() + c, err := p.acquire(deadline) + p.cond.L.Unlock() + return c, err +} + // deadlinePassed returns true if the given deadline has passed. func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { return deadline != nil && time.Now().After(*deadline) diff --git a/conn_pool_test.go b/conn_pool_test.go index 84a74aed..83bdf1fd 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -45,6 +45,12 @@ func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) return c, time.Since(startTime), err } +func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.AcquireEx(ctx) + return c, time.Since(startTime), err +} + func TestNewConnPool(t *testing.T) { t.Parallel() @@ -315,6 +321,144 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { } } +func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } + if timeTaken > ctxTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) { + t.Parallel() + + connAllocTimeout := 5 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } + if timeTaken > connAllocTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background()) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) + } +} + func TestPoolErrClosedPool(t *testing.T) { t.Parallel()