Add simple protocol suuport with (Query|Exec)Ex
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
@@ -123,6 +124,17 @@ func (rows *Rows) Next() bool {
|
||||
}
|
||||
|
||||
switch t {
|
||||
case rowDescription:
|
||||
rows.fields = rows.conn.rxRowDescription(r)
|
||||
for i := range rows.fields {
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
|
||||
rows.fields[i].DataTypeName = dt.Name
|
||||
rows.fields[i].FormatCode = TextFormatCode
|
||||
} else {
|
||||
rows.Fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType))
|
||||
return false
|
||||
}
|
||||
}
|
||||
case dataRow:
|
||||
fieldCount := r.readInt16()
|
||||
if int(fieldCount) != len(rows.fields) {
|
||||
@@ -341,7 +353,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
|
||||
// be returned in an error state. So it is allowed to ignore the error returned
|
||||
// from Query and handle it in *Rows.
|
||||
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
return c.QueryContext(context.Background(), sql, args...)
|
||||
return c.QueryEx(context.Background(), sql, nil, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
|
||||
@@ -368,7 +380,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
||||
type QueryExOptions struct {
|
||||
SimpleProtocol bool
|
||||
}
|
||||
|
||||
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -384,6 +400,22 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
|
||||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
if options != nil && options.SimpleProtocol {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
@@ -411,7 +443,32 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||
rows, _ := c.QueryContext(ctx, sql, args...)
|
||||
func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {
|
||||
if c.RuntimeParams["standard_conforming_strings"] != "on" {
|
||||
return errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
}
|
||||
|
||||
if c.RuntimeParams["client_encoding"] != "UTF8" {
|
||||
return errors.New("simple protocol queries must be run with client_encoding=UTF8")
|
||||
}
|
||||
|
||||
valueArgs := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sql, err = sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.sendSimpleQuery(sql)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
|
||||
rows, _ := c.QueryEx(ctx, sql, options, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user