From deac6564eeb81e6ad3996b9e29f03854a8017f2d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 19:16:13 -0600 Subject: [PATCH] Implement Query in terms of QueryContext - Merge Rows.close into Rows.Close - Merge Rows.abort into Rows.Fatal --- query.go | 91 ++++++++++++++++++++------------------------------ replication.go | 6 ++-- 2 files changed, 39 insertions(+), 58 deletions(-) diff --git a/query.go b/query.go index b6470688..aa664649 100644 --- a/query.go +++ b/query.go @@ -56,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription { return rows.fields } -func (rows *Rows) close() { +// Close closes the rows, making the connection ready for use again. It is safe +// to call Close after rows is already closed. +func (rows *Rows) Close() { if rows.closed { return } @@ -68,6 +70,8 @@ func (rows *Rows) close() { rows.closed = true + rows.err = rows.conn.termContext(rows.err) + if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() @@ -82,31 +86,10 @@ func (rows *Rows) close() { } } -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { - if rows.closed { - return - } - rows.err = rows.conn.termContext(rows.err) - rows.close() -} - func (rows *Rows) Err() error { return rows.err } -// abort signals that the query was not successfully sent to the server. -// This differs from Fatal in that it is not necessary to readUntilReadyForQuery -func (rows *Rows) abort(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.close() -} - // Fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. func (rows *Rows) Fatal(err error) { @@ -148,7 +131,7 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - rows.close() + rows.Close() return false default: @@ -408,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) { // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - c.lastActivityTime = time.Now() - - rows := c.getRows(sql, args) - - if err := c.lock(); err != nil { - rows.abort(err) - return rows, err - } - rows.unlockConn = true - - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.Prepare("", sql) - if err != nil { - rows.abort(err) - return rows, rows.err - } - } - rows.sql = ps.SQL - rows.fields = ps.FieldDescriptions - err := c.sendPreparedQuery(ps, args...) - if err != nil { - rows.abort(err) - } - return rows, rows.err + return c.QueryContext(context.Background(), sql, args...) } func (c *Conn) getRows(sql string, args []interface{}) *Rows { @@ -460,19 +418,42 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - err := c.initContext(ctx) +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + c.lastActivityTime = time.Now() + + rows = c.getRows(sql, args) + + if err := c.lock(); err != nil { + rows.Fatal(err) + return rows, err + } + rows.unlockConn = true + + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.PrepareExContext(ctx, "", sql, nil) + if err != nil { + rows.Fatal(err) + return rows, rows.err + } + } + rows.sql = ps.SQL + rows.fields = ps.FieldDescriptions + + err = c.initContext(ctx) if err != nil { - return nil, err + rows.Fatal(err) + return rows, err } - rows, err := c.Query(sql, args...) + err = c.sendPreparedQuery(ps, args...) if err != nil { + rows.Fatal(err) err = c.termContext(err) - return nil, err } - return rows, nil + return rows, err } func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { diff --git a/replication.go b/replication.go index 12a5c914..0acc9df9 100644 --- a/replication.go +++ b/replication.go @@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows := rc.c.getRows(sql, nil) if err := rc.c.lock(); err != nil { - rows.abort(err) + rows.Fatal(err) return rows, err } rows.unlockConn = true err := rc.c.sendSimpleQuery(sql) if err != nil { - rows.abort(err) + rows.Fatal(err) } var t byte @@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // only Oids. Not much we can do about this. default: if e := rc.c.processContextFreeMsg(t, r); e != nil { - rows.abort(e) + rows.Fatal(e) return rows, e } }