From 004c18e5a21c7837cb6dc578f22471115b29fdc8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Feb 2017 20:35:37 -0600 Subject: [PATCH] Begin extracting context handling --- conn.go | 53 +++++++++++++++++++++++------------------------------ query.go | 27 ++++++--------------------- 2 files changed, 29 insertions(+), 51 deletions(-) diff --git a/conn.go b/conn.go index b8131716..453f1a51 100644 --- a/conn.go +++ b/conn.go @@ -88,6 +88,10 @@ type Conn struct { closingLock sync.Mutex alive bool causeOfDeath error + + // context support + doneChan chan struct{} + closedChan chan struct{} } // PreparedStatement is a description of a prepared statement @@ -257,6 +261,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.channels = make(map[string]struct{}) c.alive = true c.lastActivityTime = time.Now() + c.doneChan = make(chan struct{}) + c.closedChan = make(chan struct{}) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -619,8 +625,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.PrepareExContext(context.Background(), name, sql, opts) - + return c.prepareEx(name, sql, opts) } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { @@ -630,25 +635,14 @@ func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *Pre default: } - doneChan := make(chan struct{}) - closedChan := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- struct{}{} - case <-doneChan: - } - }() + go c.contextHandler(ctx) ps, err = c.prepareEx(name, sql, opts) select { - case <-closedChan: + case <-c.closedChan: return nil, ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return ps, err } } @@ -1383,25 +1377,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa default: } - doneChan := make(chan struct{}) - closedChan := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- struct{}{} - case <-doneChan: - } - }() + go c.contextHandler(ctx) commandTag, err = c.Exec(sql, arguments...) select { - case <-closedChan: + case <-c.closedChan: return "", ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return commandTag, err } } + +func (c *Conn) contextHandler(ctx context.Context) { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + c.closedChan <- struct{}{} + case <-c.doneChan: + } +} diff --git a/query.go b/query.go index 3ded881d..daf1b354 100644 --- a/query.go +++ b/query.go @@ -51,9 +51,7 @@ type Rows struct { unlockConn bool closed bool - ctx context.Context - doneChan chan struct{} - closedChan chan bool + ctx context.Context } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -128,9 +126,9 @@ func (rows *Rows) Close() { if rows.ctx != nil { select { - case <-rows.closedChan: + case <-rows.conn.closedChan: rows.err = rows.ctx.Err() - case rows.doneChan <- struct{}{}: + case rows.conn.doneChan <- struct{}{}: } } @@ -508,33 +506,20 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - doneChan := make(chan struct{}) - closedChan := make(chan bool) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- true - case <-doneChan: - } - }() + go c.contextHandler(ctx) rows, err := c.Query(sql, args...) if err != nil { select { - case <-closedChan: + case <-c.closedChan: return rows, ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return rows, err } } rows.ctx = ctx - rows.doneChan = doneChan - rows.closedChan = closedChan return rows, nil }