From 6be77b4d641adf50e61bb6a1c530e77d0d943d43 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2013 15:48:57 -0500 Subject: [PATCH] Added ConnectionPool fixes #7 --- connection_pool.go | 47 +++++++++++++++++ connection_pool_test.go | 111 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 connection_pool.go create mode 100644 connection_pool_test.go diff --git a/connection_pool.go b/connection_pool.go new file mode 100644 index 00000000..7a668c1b --- /dev/null +++ b/connection_pool.go @@ -0,0 +1,47 @@ +package pgx + +type ConnectionPool struct { + connectionChannel chan *Connection + options map[string]string // options used when establishing connection + MaxConnections int +} + +// options: options used by Connect +// MaxConnections: max simultaneous connections to use (currently all are immediately connected) +func NewConnectionPool(options map[string]string, MaxConnections int) (p *ConnectionPool, err error) { + p = new(ConnectionPool) + p.connectionChannel = make(chan *Connection, MaxConnections) + p.MaxConnections = MaxConnections + + p.options = make(map[string]string) + for k, v := range options { + p.options[k] = v + } + + for i := 0; i < p.MaxConnections; i++ { + var c *Connection + c, err = Connect(options) + if err != nil { + return + } + p.connectionChannel <- c + } + + return +} + +func (p *ConnectionPool) Acquire() (c *Connection) { + c = <-p.connectionChannel + return +} + +func (p *ConnectionPool) Release(c *Connection) { + p.connectionChannel <- c +} + +func (p *ConnectionPool) Close() { + for i := 0; i < p.MaxConnections; i++ { + c := <-p.connectionChannel + _ = c.Close() + } +} diff --git a/connection_pool_test.go b/connection_pool_test.go new file mode 100644 index 00000000..d00a017e --- /dev/null +++ b/connection_pool_test.go @@ -0,0 +1,111 @@ +package pgx + +import ( + "fmt" + "testing" +) + +func createConnectionPool(maxConnections int) *ConnectionPool { + connectionOptions := map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"} + pool, err := NewConnectionPool(connectionOptions, maxConnections) + if err != nil { + panic("Unable to create connection pool") + } + return pool +} + +func TestNewConnectionPool(t *testing.T) { + connectionOptions := map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"} + pool, err := NewConnectionPool(connectionOptions, 5) + if err != nil { + t.Fatal("Unable to establish connection pool") + } + defer pool.Close() + + if pool.MaxConnections != 5 { + t.Error("Wrong maxConnections") + } +} + +func TestPoolAcquireAndReleaseCycle(t *testing.T) { + maxConnections := 2 + incrementCount := int32(100) + completeSync := make(chan int) + pool := createConnectionPool(maxConnections) + defer pool.Close() + + acquireAll := func() (connections []*Connection) { + connections = make([]*Connection, maxConnections) + for i := 0; i < maxConnections; i++ { + connections[i] = pool.Acquire() + } + return + } + + allConnections := acquireAll() + + for _, c := range allConnections { + var err error + if _, err = c.Execute("create temporary table t(counter integer not null)"); err != nil { + t.Fatal("Unable to create temp table:" + err.Error()) + } + if _, err = c.Execute("insert into t(counter) values(0);"); err != nil { + t.Fatal("Unable to insert initial counter row: " + err.Error()) + } + } + + for _, c := range allConnections { + pool.Release(c) + } + + f := func() { + var err error + conn := pool.Acquire() + if err != nil { + t.Fatal("Unable to acquire connection") + } + defer pool.Release(conn) + + // Increment counter... + _, err = conn.Execute("update t set counter = counter + 1") + if err != nil { + t.Fatal("Unable to update counter: " + err.Error()) + } + completeSync <- 0 + } + + for i := int32(0); i < incrementCount; i++ { + go f() + } + + // Wait for all f() to complete + for i := int32(0); i < incrementCount; i++ { + <-completeSync + } + + // Check that temp table in each connection has been incremented some number of times + actualCount := int32(0) + allConnections = acquireAll() + + for _, c := range allConnections { + n, err := c.SelectInt32("select counter from t") + if err != nil { + t.Fatal("Unable to read back execution counter: " + err.Error()) + } + + if n == 0 { + t.Error("A connection was never used") + } + + actualCount += n + } + + if actualCount != incrementCount { + fmt.Println(actualCount) + t.Error("Wrong number of increments") + } + + for _, c := range allConnections { + pool.Release(c) + } +}