Remove extra prepare in stdlib
This commit is contained in:
@@ -580,14 +580,19 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
|
||||||
type QueryResultFormats []int16
|
type QueryResultFormats []int16
|
||||||
|
|
||||||
|
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
|
||||||
|
type QueryResultFormatsByOID map[pgtype.OID]int16
|
||||||
|
|
||||||
// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is
|
// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is
|
||||||
// allowed to ignore the error returned from Query and handle it in Rows.
|
// allowed to ignore the error returned from Query and handle it in Rows.
|
||||||
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
|
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
|
||||||
// rows = c.getRows(sql, args)
|
// rows = c.getRows(sql, args)
|
||||||
|
|
||||||
var resultFormats QueryResultFormats
|
var resultFormats QueryResultFormats
|
||||||
|
var resultFormatsByOID QueryResultFormatsByOID
|
||||||
|
|
||||||
optionLoop:
|
optionLoop:
|
||||||
for len(args) > 0 {
|
for len(args) > 0 {
|
||||||
@@ -595,6 +600,9 @@ optionLoop:
|
|||||||
case QueryResultFormats:
|
case QueryResultFormats:
|
||||||
resultFormats = arg
|
resultFormats = arg
|
||||||
args = args[1:]
|
args = args[1:]
|
||||||
|
case QueryResultFormatsByOID:
|
||||||
|
resultFormatsByOID = arg
|
||||||
|
args = args[1:]
|
||||||
default:
|
default:
|
||||||
break optionLoop
|
break optionLoop
|
||||||
}
|
}
|
||||||
@@ -655,6 +663,13 @@ optionLoop:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resultFormatsByOID != nil {
|
||||||
|
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||||
|
for i := range resultFormats {
|
||||||
|
resultFormats[i] = resultFormatsByOID[ps.FieldDescriptions[i].DataType]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resultFormats == nil {
|
if resultFormats == nil {
|
||||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||||
for i := range resultFormats {
|
for i := range resultFormats {
|
||||||
|
|||||||
+22
-62
@@ -87,9 +87,8 @@ import (
|
|||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
)
|
)
|
||||||
|
|
||||||
// oids that map to intrinsic database/sql types. These will be allowed to be
|
// Only intrinsic types should be binary format with database/sql.
|
||||||
// binary, anything else will be forced to text format
|
var databaseSQLResultFormats pgx.QueryResultFormatsByOID
|
||||||
var databaseSqlOIDs map[pgtype.OID]bool
|
|
||||||
|
|
||||||
var pgxDriver *Driver
|
var pgxDriver *Driver
|
||||||
|
|
||||||
@@ -104,20 +103,21 @@ func init() {
|
|||||||
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
||||||
sql.Register("pgx", pgxDriver)
|
sql.Register("pgx", pgxDriver)
|
||||||
|
|
||||||
databaseSqlOIDs = make(map[pgtype.OID]bool)
|
databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
|
||||||
databaseSqlOIDs[pgtype.BoolOID] = true
|
pgtype.BoolOID: 1,
|
||||||
databaseSqlOIDs[pgtype.ByteaOID] = true
|
pgtype.ByteaOID: 1,
|
||||||
databaseSqlOIDs[pgtype.CIDOID] = true
|
pgtype.CIDOID: 1,
|
||||||
databaseSqlOIDs[pgtype.DateOID] = true
|
pgtype.DateOID: 1,
|
||||||
databaseSqlOIDs[pgtype.Float4OID] = true
|
pgtype.Float4OID: 1,
|
||||||
databaseSqlOIDs[pgtype.Float8OID] = true
|
pgtype.Float8OID: 1,
|
||||||
databaseSqlOIDs[pgtype.Int2OID] = true
|
pgtype.Int2OID: 1,
|
||||||
databaseSqlOIDs[pgtype.Int4OID] = true
|
pgtype.Int4OID: 1,
|
||||||
databaseSqlOIDs[pgtype.Int8OID] = true
|
pgtype.Int8OID: 1,
|
||||||
databaseSqlOIDs[pgtype.OIDOID] = true
|
pgtype.OIDOID: 1,
|
||||||
databaseSqlOIDs[pgtype.TimestampOID] = true
|
pgtype.TimestampOID: 1,
|
||||||
databaseSqlOIDs[pgtype.TimestamptzOID] = true
|
pgtype.TimestamptzOID: 1,
|
||||||
databaseSqlOIDs[pgtype.XIDOID] = true
|
pgtype.XIDOID: 1,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -168,8 +168,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
restrictBinaryToDatabaseSqlTypes(ps)
|
|
||||||
|
|
||||||
return &Stmt{ps: ps, conn: c}, nil
|
return &Stmt{ps: ps, conn: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,48 +239,22 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
|||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name
|
return c.queryPreparedContext(ctx, query, argsV)
|
||||||
psname := fmt.Sprintf("stdlibpx%v", &argsV)
|
|
||||||
|
|
||||||
ps, err := c.conn.Prepare(ctx, psname, query)
|
|
||||||
if err != nil {
|
|
||||||
// since PrepareEx failed, we didn't actually get to send the values, so
|
|
||||||
// we can safely retry
|
|
||||||
if _, is := err.(net.Error); is {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
restrictBinaryToDatabaseSqlTypes(ps)
|
|
||||||
return c.queryPreparedContext(ctx, psname, argsV)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (c *Conn) execParams(ctx context.Context, sql string, argsV []driver.NamedValue) (*pgconn.ResultReader, error) {
|
|
||||||
// if !c.conn.IsAlive() {
|
|
||||||
// return nil, driver.ErrBadConn
|
|
||||||
// }
|
|
||||||
|
|
||||||
// paramValues := make([][]byte, len(argsV))
|
|
||||||
// for i := 0;i< len(paramValues); i++ {
|
|
||||||
// v := argsV[i].Value
|
|
||||||
// paramValues
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return c.conn.PgConn().ExecParams(ctx, sql,paramValues, nil, nil, nil)
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
|
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||||
if !c.conn.IsAlive() {
|
if !c.conn.IsAlive() {
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - don't always use text
|
args := []interface{}{databaseSQLResultFormats}
|
||||||
args := []interface{}{pgx.QueryResultFormats{0}}
|
|
||||||
args = append(args, namedValueToInterface(argsV)...)
|
args = append(args, namedValueToInterface(argsV)...)
|
||||||
|
|
||||||
rows, err := c.conn.Query(ctx, name, args...)
|
rows, err := c.conn.Query(ctx, name, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, pgconn.ErrNoBytesSent) {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,18 +271,6 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||||||
return c.conn.Ping(ctx)
|
return c.conn.Ping(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Anything that isn't a database/sql compatible type needs to be forced to
|
|
||||||
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
|
||||||
// (e.g. []int32)
|
|
||||||
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
|
|
||||||
for i := range ps.FieldDescriptions {
|
|
||||||
intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType]
|
|
||||||
if !intrinsic {
|
|
||||||
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Stmt struct {
|
type Stmt struct {
|
||||||
ps *pgx.PreparedStatement
|
ps *pgx.PreparedStatement
|
||||||
conn *Conn
|
conn *Conn
|
||||||
|
|||||||
Reference in New Issue
Block a user