2
0

Add simple protocol suuport with (Query|Exec)Ex

This commit is contained in:
Jack Christensen
2017-04-10 08:58:51 -05:00
parent 54d9cbc743
commit 7b1f461ec3
16 changed files with 999 additions and 326 deletions
+61 -4
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/internal/sanitize"
"github.com/jackc/pgx/pgtype"
)
@@ -123,6 +124,17 @@ func (rows *Rows) Next() bool {
}
switch t {
case rowDescription:
rows.fields = rows.conn.rxRowDescription(r)
for i := range rows.fields {
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
rows.fields[i].DataTypeName = dt.Name
rows.fields[i].FormatCode = TextFormatCode
} else {
rows.Fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType))
return false
}
}
case dataRow:
fieldCount := r.readInt16()
if int(fieldCount) != len(rows.fields) {
@@ -341,7 +353,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
// be returned in an error state. So it is allowed to ignore the error returned
// from Query and handle it in *Rows.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
return c.QueryContext(context.Background(), sql, args...)
return c.QueryEx(context.Background(), sql, nil, args...)
}
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
@@ -368,7 +380,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows)
}
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
type QueryExOptions struct {
SimpleProtocol bool
}
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
@@ -384,6 +400,22 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
}
rows.unlockConn = true
if options != nil && options.SimpleProtocol {
err = c.initContext(ctx)
if err != nil {
rows.Fatal(err)
return rows, err
}
err = c.sanitizeAndSendSimpleQuery(sql, args...)
if err != nil {
rows.Fatal(err)
return rows, err
}
return rows, nil
}
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
@@ -411,7 +443,32 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
return rows, err
}
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := c.QueryContext(ctx, sql, args...)
func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {
if c.RuntimeParams["standard_conforming_strings"] != "on" {
return errors.New("simple protocol queries must be run with standard_conforming_strings=on")
}
if c.RuntimeParams["client_encoding"] != "UTF8" {
return errors.New("simple protocol queries must be run with client_encoding=UTF8")
}
valueArgs := make([]interface{}, len(args))
for i, a := range args {
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
if err != nil {
return err
}
}
sql, err = sanitize.SanitizeSQL(sql, valueArgs...)
if err != nil {
return err
}
return c.sendSimpleQuery(sql)
}
func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
rows, _ := c.QueryEx(ctx, sql, options, args...)
return (*Row)(rows)
}