2
0

Initial proof-of-concept database/sql context support

This commit is contained in:
Jack Christensen
2017-02-06 19:39:34 -06:00
parent 14eedb4fca
commit 351eb8ba67
2 changed files with 87 additions and 11 deletions
+41 -11
View File
@@ -619,6 +619,41 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared. // concern for if the statement has already been prepared.
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
return c.PrepareExContext(context.Background(), name, sql, opts)
}
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
doneChan := make(chan struct{})
closedChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
c.cancelQuery()
c.Close()
closedChan <- struct{}{}
case <-doneChan:
}
}()
ps, err = c.prepareEx(name, sql, opts)
select {
case <-closedChan:
return nil, ctx.Err()
case doneChan <- struct{}{}:
return ps, err
}
}
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
if name != "" { if name != "" {
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil return ps, nil
@@ -1349,29 +1384,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa
} }
doneChan := make(chan struct{}) doneChan := make(chan struct{})
closedChan := make(chan bool) closedChan := make(chan struct{})
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.cancelQuery() c.cancelQuery()
c.Close() c.Close()
<-doneChan closedChan <- struct{}{}
closedChan <- true
case <-doneChan: case <-doneChan:
closedChan <- false
} }
}() }()
commandTag, err = c.Exec(sql, arguments...) commandTag, err = c.Exec(sql, arguments...)
// Signal cancelation goroutine that operation is done select {
doneChan <- struct{}{} case <-closedChan:
// If c was closed due to context cancelation then return context err
if <-closedChan {
return "", ctx.Err() return "", ctx.Err()
case doneChan <- struct{}{}:
return commandTag, err
} }
return commandTag, err
} }
+46
View File
@@ -44,6 +44,7 @@
package stdlib package stdlib
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
@@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
return c.queryPrepared("", argsV) return c.queryPrepared("", argsV)
} }
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
ps, err := c.conn.PrepareExContext(ctx, "", query, nil)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
}
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() { if !c.conn.IsAlive() {
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
@@ -226,6 +242,24 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
return &Rows{rows: rows}, nil return &Rows{rows: rows}, nil
} }
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
args := namedValueToInterface(argsV)
rows, err := c.conn.QueryContext(ctx, name, args...)
if err != nil {
fmt.Println(err)
return nil, err
}
fmt.Println("ere")
return &Rows{rows: rows}, nil
}
// Anything that isn't a database/sql compatible type needs to be forced to // Anything that isn't a database/sql compatible type needs to be forced to
// text format so that pgx.Rows.Values doesn't decode it into a native type // text format so that pgx.Rows.Values doesn't decode it into a native type
// (e.g. []int32) // (e.g. []int32)
@@ -318,6 +352,18 @@ func valueToInterface(argsV []driver.Value) []interface{} {
return args return args
} }
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
args := make([]interface{}, 0, len(argsV))
for _, v := range argsV {
if v.Value != nil {
args = append(args, v.Value.(interface{}))
} else {
args = append(args, nil)
}
}
return args
}
type Tx struct { type Tx struct {
conn *pgx.Conn conn *pgx.Conn
} }