2
0
Files
pgx/conn.go
T
Jack Christensen c53c9e6eb5 Remove simple protocol and one round trip query options
It is impossible to guarantee that a query executed with the simple
protocol will behave the same as with the extended protocol. This is
because the normal pgx path relies on knowing the OID of query
parameters. Without this encoding a value can only be determined by the
value instead of the combination of value and PostgreSQL type. For
example, how should a []int32 be encoded? It might be encoded into a
PostgreSQL int4[] or json.

Removal also simplifies the core query path.

The primary reason for the simple protocol is for servers like PgBouncer
that may not be able to support normal prepared statements. After
further research it appears that issuing a "flush" instead of "sync" after
preparing the unnamed statement would allow PgBouncer to work.

The one round trip mode can be better handled with prepared statements.

As a last resort, all original server functionality can still be accessed by
dropping down to PgConn.
2019-04-13 11:39:01 -05:00

922 lines
26 KiB
Go

package pgx
import (
"context"
"database/sql/driver"
"fmt"
"net"
"reflect"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3"
"github.com/jackc/pgx/pgtype"
)
// Connection status values stored in Conn.status. The declaration order is
// significant: Close and IsAlive compare the status against connStatusIdle
// with < and >=, so the "not usable" states must sort before connStatusIdle.
const (
// connStatusUninitialized is the zero value of Conn.status.
connStatusUninitialized = iota
// connStatusClosed is set by Close and die.
connStatusClosed
// connStatusIdle is set by connect once the connection is established, and by unlock.
connStatusIdle
// connStatusBusy is set by lock.
connStatusBusy
)
// minimalConnInfo carries only the static type information needed to
// establish the connection and retrieve the full type data.
var minimalConnInfo *pgtype.ConnInfo

// init builds minimalConnInfo from a fixed table of type names and OIDs.
func init() {
	bootstrapOIDs := map[string]pgtype.OID{
		"int4":    pgtype.Int4OID,
		"name":    pgtype.NameOID,
		"oid":     pgtype.OIDOID,
		"text":    pgtype.TextOID,
		"varchar": pgtype.VarcharOID,
	}

	ci := pgtype.NewConnInfo()
	ci.InitializeDataTypes(bootstrapOIDs)
	minimalConnInfo = ci
}
// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
// Config is the embedded pgconn configuration; connect passes it to
// pgconn.ConnectConfig.
pgconn.Config
// Logger receives the connection's log messages. Nothing is logged while it
// is nil.
Logger Logger
// LogLevel is the initial log level. When it is zero, connect defaults to
// LogLevelDebug.
LogLevel LogLevel
// CustomConnInfo, when non-nil, is called by initConnInfo instead of the
// built-in type discovery and its result becomes the connection's ConnInfo.
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
type Conn struct {
// pgConn is the underlying low-level connection; it is exposed through PgConn.
pgConn *pgconn.PgConn
// wbuf is allocated by connect with a capacity of 1024 bytes.
// NOTE(review): presumably a reusable write buffer - confirm against its users.
wbuf []byte
config *ConnConfig // config used when establishing this connection
// preparedStatements caches named prepared statements, keyed by statement name.
preparedStatements map[string]*PreparedStatement
// logger and logLevel are consulted by shouldLog; a nil logger disables logging.
logger Logger
logLevel LogLevel
fp *fastpath
poolResetCount int
preallocatedRows []connRows
// mux is held by Close, die, lock, unlock, IsAlive and CauseOfDeath while
// they read or write status and causeOfDeath.
mux sync.Mutex
status byte // One of connStatus* constants
// causeOfDeath is set by Close and die and returned by CauseOfDeath.
causeOfDeath error
// lastStmtSent is reset by Exec and set once a statement is about to be sent;
// it is returned by LastStmtSent.
lastStmtSent bool
// context support
ctxInProgress bool
doneChan chan struct{}
closedChan chan error
// ConnInfo is the data type registry used to look up data types by OID when
// encoding arguments and describing result fields.
ConnInfo *pgtype.ConnInfo
}
// PreparedStatement is a description of a prepared statement
type PreparedStatement struct {
// Name is the statement name as reported by the underlying Prepare call.
Name string
// SQL is the statement text as reported by the underlying Prepare call.
SQL string
// FieldDescriptions has one entry per field in the statement description
// returned by the underlying Prepare call.
FieldDescriptions []FieldDescription
// ParameterOIDs holds the type OID of each statement parameter, in order.
ParameterOIDs []pgtype.OID
}
// PrepareExOptions is an option struct that can be passed to PrepareEx
type PrepareExOptions struct {
// ParameterOIDs are the parameter type OIDs forwarded to the underlying
// Prepare call. PrepareEx rejects more than 65535 entries.
ParameterOIDs []pgtype.OID
}
// Identifier is a PostgreSQL identifier or name. Identifiers can be composed
// of multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize wraps every part of the identifier in double quotes, doubling any
// double quote embedded in a part, and joins the parts with ".". The result is
// safe for SQL interpolation.
func (ident Identifier) Sanitize() string {
	quoted := make([]string, 0, len(ident))
	for _, part := range ident {
		escaped := strings.Replace(part, `"`, `""`, -1)
		quoted = append(quoted, `"`+escaped+`"`)
	}
	return strings.Join(quoted, ".")
}
// Sentinel errors returned by this package.
var (
	// ErrNoRows occurs when rows are expected but none are returned.
	ErrNoRows = errors.New("no rows in result set")

	// ErrDeadConn occurs on an attempt to use a dead connection.
	ErrDeadConn = errors.New("conn is dead")

	// ErrTLSRefused occurs when the connection attempt requires TLS and the
	// PostgreSQL server refuses to use TLS.
	ErrTLSRefused = pgconn.ErrTLSRefused

	// ErrConnBusy occurs when the connection is busy (for example, in the
	// middle of reading query results) and another action is attempted.
	ErrConnBusy = errors.New("conn is busy")

	// ErrInvalidLogLevel occurs on an attempt to set an invalid log level.
	ErrInvalidLogLevel = errors.New("invalid log level")
)
// ProtocolError occurs when unexpected data is received from PostgreSQL.
type ProtocolError string

// Error implements the error interface; the message is the string value
// itself.
func (pe ProtocolError) Error() string {
	msg := string(pe)
	return msg
}
// Connect parses connString and establishes a connection with the PostgreSQL
// server it describes. See pgconn.Connect for details.
func Connect(ctx context.Context, connString string) (*Conn, error) {
	cfg, parseErr := ParseConfig(connString)
	if parseErr != nil {
		return nil, parseErr
	}

	return connect(ctx, cfg, minimalConnInfo)
}
// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
return connect(ctx, connConfig, minimalConnInfo)
}
// ParseConfig builds a ConnConfig from connString. The string is parsed by
// pgconn.ParseConfig; every other ConnConfig field is left at its zero value.
func ParseConfig(connString string) (*ConnConfig, error) {
	pgCfg, err := pgconn.ParseConfig(connString)
	if err != nil {
		return nil, err
	}

	return &ConnConfig{Config: *pgCfg}, nil
}
// connect establishes the underlying pgconn connection described by config and
// wraps it in a Conn that starts out using connInfo as its type registry.
//
// A failed connection attempt is logged at LogLevelError (when a logger is
// configured) and returned. For replication connections the Conn is returned
// right after the connection is established. Otherwise, when connInfo is
// minimalConnInfo, the full type information is loaded via initConnInfo; if
// that fails the connection is closed and the error is returned.
func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
	c = new(Conn)

	c.config = config
	c.ConnInfo = connInfo

	if c.config.LogLevel != 0 {
		c.logLevel = c.config.LogLevel
	} else {
		// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
		c.logLevel = LogLevelDebug
	}
	c.logger = c.config.Logger

	if c.shouldLog(LogLevelInfo) {
		c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
	}

	c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
	if err != nil {
		// c.log tolerates a nil c.pgConn, so logging here is safe even though
		// the connection was never established.
		if c.shouldLog(LogLevelError) {
			c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
		}
		return nil, err
	}

	c.preparedStatements = make(map[string]*PreparedStatement)
	c.doneChan = make(chan struct{})
	c.closedChan = make(chan error)
	c.wbuf = make([]byte, 0, 1024)
	c.status = connStatusIdle

	// Replication connections can't execute the queries that populate
	// c.ConnInfo, so they are returned as soon as the connection is up.
	if _, ok := c.pgConn.Config.RuntimeParams["replication"]; ok {
		return c, nil
	}

	if c.ConnInfo == minimalConnInfo {
		err = c.initConnInfo()
		if err != nil {
			// Best effort: the initialization error is the one worth reporting.
			c.Close(ctx)
			return nil, err
		}
	}

	return c, nil
}
// initPostgresql builds a ConnInfo by querying pg_type over c. It loads the
// name/OID pairs selected by namedOIDQuery and then registers enum arrays and
// domains on top of them. Any query error aborts the whole initialization.
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
	// Types outside pg_catalog and public are reported with a schema-qualified name.
	const namedOIDQuery = `select t.oid,
case when nsp.nspname in ('pg_catalog', 'public') then t.typname
else nsp.nspname||'.'||t.typname
end
from pg_type t
left join pg_type base_type on t.typelem=base_type.oid
left join pg_namespace nsp on t.typnamespace=nsp.oid
where (
t.typtype in('b', 'p', 'r', 'e')
and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
)`

	oidsByName, err := connInfoFromRows(c.Query(context.TODO(), namedOIDQuery))
	if err != nil {
		return nil, err
	}

	registry := pgtype.NewConnInfo()
	registry.InitializeDataTypes(oidsByName)

	err = c.initConnInfoEnumArray(registry)
	if err != nil {
		return nil, err
	}

	err = c.initConnInfoDomains(registry)
	if err != nil {
		return nil, err
	}

	return registry, nil
}
// initConnInfo replaces c.ConnInfo with the full type registry for the
// connection. A configured CustomConnInfo callback takes precedence and its
// result is stored as-is. Otherwise the PostgreSQL catalog queries are tried
// first and, if they fail, the CrateDB-specific fallback.
func (c *Conn) initConnInfo() (err error) {
	if custom := c.config.CustomConnInfo; custom != nil {
		c.ConnInfo, err = custom(c)
		return err
	}

	registry, pgErr := initPostgresql(c)
	if pgErr == nil {
		c.ConnInfo = registry
		return nil
	}

	// Check if CrateDB specific approach might still allow us to connect.
	registry, err = c.crateDBTypesQuery(pgErr)
	if err == nil {
		c.ConnInfo = registry
	}

	return err
}
// initConnInfoEnumArray introspects for arrays of enums and registers a data
// type for them. For every pg_type row returned by the query it registers a
// pgtype.EnumArray under the row's type name and OID on cinfo.
//
// It returns the first error from running the query, scanning a row, or
// iterating the result.
func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error {
	nameOIDs := make(map[string]pgtype.OID, 16)

	rows, err := c.Query(context.TODO(), `select t.oid, t.typname
from pg_type t
join pg_type base_type on t.typelem=base_type.oid
where t.typtype = 'b'
and base_type.typtype = 'e'`)
	if err != nil {
		return err
	}
	// Close the rows on every exit path, as connInfoFromRows does, so an early
	// return on a scan error does not leave the result set open.
	defer rows.Close()

	for rows.Next() {
		var oid pgtype.OID
		var name pgtype.Text
		if err := rows.Scan(&oid, &name); err != nil {
			return err
		}

		nameOIDs[name.String] = oid
	}

	if err := rows.Err(); err != nil {
		return err
	}

	for name, oid := range nameOIDs {
		cinfo.RegisterDataType(pgtype.DataType{
			Value: &pgtype.EnumArray{},
			Name:  name,
			OID:   oid,
		})
	}

	return nil
}
// initConnInfoDomains introspects for domains and registers a data type for
// them. Each domain whose base type OID is already known to cinfo is
// registered under the domain's own name and OID, using a fresh value of the
// base data type's concrete Go type. Domains whose base type is unknown to
// cinfo are skipped.
//
// It returns the first error from running the query, scanning a row, or
// iterating the result.
func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error {
	type domain struct {
		oid     pgtype.OID
		name    pgtype.Text
		baseOID pgtype.OID
	}

	domains := make([]*domain, 0, 16)

	rows, err := c.Query(context.TODO(), `select t.oid, t.typname, t.typbasetype
from pg_type t
join pg_type base_type on t.typbasetype=base_type.oid
where t.typtype = 'd'
and base_type.typtype = 'b'`)
	if err != nil {
		return err
	}
	// Close the rows on every exit path, as connInfoFromRows does, so an early
	// return on a scan error does not leave the result set open.
	defer rows.Close()

	for rows.Next() {
		var d domain
		if err := rows.Scan(&d.oid, &d.name, &d.baseOID); err != nil {
			return err
		}

		domains = append(domains, &d)
	}

	if err := rows.Err(); err != nil {
		return err
	}

	for _, d := range domains {
		baseDataType, ok := cinfo.DataTypeForOID(d.baseOID)
		if !ok {
			continue
		}

		// Allocate a new value of the base type's concrete type so the domain
		// does not share a Value with its base data type.
		baseType := reflect.ValueOf(baseDataType.Value).Elem().Type()
		cinfo.RegisterDataType(pgtype.DataType{
			Value: reflect.New(baseType).Interface().(pgtype.Value),
			Name:  d.name.String,
			OID:   d.oid,
		})
	}

	return nil
}
// crateDBTypesQuery checks if the given err is likely to be the result of
// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB
// specific query against pg_types is executed and its results are returned. If
// not, the original error is returned.
func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) {
	// CrateDB 2.1.6 [1] is a database that implements the PostgreSQL wire
	// protocol, but not perfectly. In particular, the pg_catalog schema
	// containing the pg_type table is not visible by default and the
	// pg_type.typtype column is not implemented. Therefore the regular type
	// query currently returns the following error:
	//
	// pgx.PgError{Severity:"ERROR", Code:"XX000",
	// Message:"TableUnknownException: Table 'test.pg_type' unknown",
	// Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"",
	// Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
	// ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"}
	//
	// If CrateDB was to fix the pg_type table visibility in the future, we'd
	// still get this error until the typtype column is implemented:
	//
	// pgx.PgError{Severity:"ERROR", Code:"XX000",
	// Message:"ColumnUnknownException: Column typtype unknown", Detail:"",
	// Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"",
	// SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
	// ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132,
	//
	// Additionally CrateDB doesn't implement Postgres error codes [2], and
	// instead always returns "XX000" (internal_error). The code below uses all
	// of this knowledge as a heuristic to detect CrateDB. If CrateDB is
	// detected, a CrateDB specific pg_type query is executed instead.
	//
	// The heuristic is designed to still work even if CrateDB fixes [2] or
	// renames its internal exception names. If both are changed but pg_types
	// isn't fixed, this code will need to be changed.
	//
	// There is also a small chance the heuristic will yield a false positive
	// for non-CrateDB databases (e.g. if a real Postgres instance returns a
	// XX000 error), but hopefully there will be no harm in attempting the
	// alternative query in this case.
	//
	// CrateDB also uses the type varchar for the typname column which required
	// adding varchar to the minimalConnInfo init code.
	//
	// Also see the discussion here [3].
	//
	// [1] https://crate.io/
	// [2] https://github.com/crate/crate/issues/5027
	// [3] https://github.com/jackc/pgx/issues/320
	pgErr, isPgErr := err.(*pgconn.PgError)
	if !isPgErr {
		return nil, err
	}

	looksLikeCrateDB := pgErr.Code == "XX000" ||
		strings.Contains(pgErr.Message, "TableUnknownException") ||
		strings.Contains(pgErr.Message, "ColumnUnknownException")
	if !looksLikeCrateDB {
		return nil, err
	}

	oidsByName, queryErr := connInfoFromRows(c.Query(context.TODO(), `select oid, typname from pg_catalog.pg_type`))
	if queryErr != nil {
		return nil, queryErr
	}

	registry := pgtype.NewConnInfo()
	registry.InitializeDataTypes(oidsByName)

	return registry, nil
}
// PID returns the backend PID for this connection, as reported by the
// underlying pgconn connection.
func (c *Conn) PID() uint32 {
return c.pgConn.PID()
}
// LocalAddr returns the local address of the underlying network connection.
// It returns an error when IsAlive reports that the connection is not alive.
func (c *Conn) LocalAddr() (net.Addr, error) {
	if c.IsAlive() {
		return c.pgConn.Conn().LocalAddr(), nil
	}

	return nil, errors.New("connection not ready")
}
// Close closes the connection. Calling Close on a connection that is already
// closed, or that was never initialized, is safe and returns nil.
//
// The returned error is the one reported by closing the underlying pgconn
// connection.
func (c *Conn) Close(ctx context.Context) error {
	c.mux.Lock()
	defer c.mux.Unlock()

	switch c.status {
	case connStatusUninitialized, connStatusClosed:
		// Nothing to close.
		return nil
	}

	c.status = connStatusClosed
	closeErr := c.pgConn.Close(ctx)
	c.causeOfDeath = errors.New("Closed")

	if c.shouldLog(LogLevelInfo) {
		c.log(LogLevelInfo, "closed connection", nil)
	}

	return closeErr
}
// TxStatus returns the TxStatus field of the underlying pgconn connection.
// NOTE(review): presumably the transaction status byte last reported by the
// server - confirm against pgconn.
func (c *Conn) TxStatus() byte {
return c.pgConn.TxStatus
}
// ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters. The lookup
// is delegated to the underlying pgconn connection.
func (c *Conn) ParameterStatus(key string) string {
return c.pgConn.ParameterStatus(key)
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positionally as $1, $2, etc.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
//
// Prepare is equivalent to calling PrepareEx with context.Background() and nil options.
func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
return c.PrepareEx(context.Background(), name, sql, nil)
}
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positionally as $1, $2, etc.
// It differs from Prepare as it allows additional options (such as parameter OIDs) to be
// passed via struct, and it takes a context that is passed to the underlying
// Prepare call.
//
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared. A statement with an
// empty name is never cached and is prepared again on every call.
//
// opts may be nil. An error is returned when opts carries more than 65535
// parameter OIDs or when the underlying Prepare call fails; failures are logged
// at LogLevelError.
func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
	if name != "" {
		if cached, ok := c.preparedStatements[name]; ok && cached.SQL == sql {
			return cached, nil
		}
	}

	if c.shouldLog(LogLevelError) {
		// Reads the named result err, so it sees whatever error is returned below.
		defer func() {
			if err != nil {
				c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
			}
		}()
	}

	if opts == nil {
		opts = &PrepareExOptions{}
	}

	if len(opts.ParameterOIDs) > 65535 {
		return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))
	}

	// Left nil when no OIDs were supplied, exactly as before.
	var paramOIDs []uint32
	if n := len(opts.ParameterOIDs); n > 0 {
		paramOIDs = make([]uint32, 0, n)
		for _, oid := range opts.ParameterOIDs {
			paramOIDs = append(paramOIDs, uint32(oid))
		}
	}

	// Pass the caller's ctx down to the underlying Prepare call.
	psd, err := c.pgConn.Prepare(ctx, name, sql, paramOIDs)
	if err != nil {
		return nil, err
	}

	ps = &PreparedStatement{
		Name:              psd.Name,
		SQL:               psd.SQL,
		ParameterOIDs:     make([]pgtype.OID, len(psd.ParamOIDs)),
		FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
	}

	for i := range ps.ParameterOIDs {
		ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
	}
	for i := range ps.FieldDescriptions {
		c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
	}

	if name != "" {
		c.preparedStatements[name] = ps
	}

	return ps, nil
}
// Deallocate releases a prepared statement. It runs with context.Background().
func (c *Conn) Deallocate(name string) error {
return c.deallocateContext(context.Background(), name)
}
// deallocateContext removes name from the local prepared statement cache and
// then executes a "deallocate" command for it on the underlying connection.
// The cache entry is removed first, so it is gone even if the command fails.
//
// TODO - consider making this public
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
delete(c.preparedStatements, name)
_, err = c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
return err
}
// IsAlive reports whether the underlying pgconn connection is alive and this
// Conn's status is idle or busy (that is, neither uninitialized nor closed).
func (c *Conn) IsAlive() bool {
c.mux.Lock()
defer c.mux.Unlock()
return c.pgConn.IsAlive() && c.status >= connStatusIdle
}
// CauseOfDeath returns the error recorded when the connection was closed by
// Close or die, or nil if none has been recorded.
func (c *Conn) CauseOfDeath() error {
c.mux.Lock()
defer c.mux.Unlock()
return c.causeOfDeath
}
// processContextFreeMsg processes messages that are not exclusive to one
// context such as authentication or query response. The response to these
// messages is the same regardless of when they occur. It also ignores messages
// that are only meaningful in a given context. These messages can occur due to
// a context deadline interrupting message processing. For example, an
// interrupted query may have left DataRow messages on the wire.
//
// Only ErrorResponse messages produce an error; every other message is
// ignored and nil is returned.
func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
	if errResp, isErrResp := msg.(*pgproto3.ErrorResponse); isErrResp {
		return c.rxErrorResponse(errResp)
	}

	return nil
}
// rxErrorResponse converts an ErrorResponse message into a *pgconn.PgError by
// copying its fields. When the severity is "FATAL" the connection is killed
// via die before the error is returned.
func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) *pgconn.PgError {
	pgErr := &pgconn.PgError{
		Severity:         msg.Severity,
		Code:             msg.Code,
		Message:          msg.Message,
		Detail:           msg.Detail,
		Hint:             msg.Hint,
		Position:         msg.Position,
		InternalPosition: msg.InternalPosition,
		InternalQuery:    msg.InternalQuery,
		Where:            msg.Where,
		SchemaName:       msg.SchemaName,
		TableName:        msg.TableName,
		ColumnName:       msg.ColumnName,
		DataTypeName:     msg.DataTypeName,
		ConstraintName:   msg.ConstraintName,
		File:             msg.File,
		Line:             msg.Line,
		Routine:          msg.Routine,
	}

	if fatal := pgErr.Severity == "FATAL"; fatal {
		c.die(pgErr)
	}

	return pgErr
}
// die marks the connection as closed, records err as the cause of death and
// closes the underlying network connection. It does nothing when the
// connection is already marked closed.
func (c *Conn) die(err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status != connStatusClosed {
		c.status = connStatusClosed
		c.causeOfDeath = err
		c.pgConn.Conn().Close()
	}
}
// lock moves the connection from idle to busy. It returns ErrConnBusy when the
// connection is in any state other than idle.
func (c *Conn) lock() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status == connStatusIdle {
		c.status = connStatusBusy
		return nil
	}

	return ErrConnBusy
}
// unlock moves the connection from busy back to idle. It returns an error when
// the connection is not currently busy.
func (c *Conn) unlock() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status == connStatusBusy {
		c.status = connStatusIdle
		return nil
	}

	return errors.New("unlock conn that is not busy")
}
// shouldLog reports whether a message at lvl should be logged: a logger must
// be set and the connection's log level must be at least lvl.
func (c *Conn) shouldLog(lvl LogLevel) bool {
return c.logger != nil && c.logLevel >= lvl
}
// log forwards msg and data to the connection's logger at level lvl. A nil
// data map is replaced with an empty one, and a "pid" entry is added when the
// underlying connection exists and reports a non-zero PID.
//
// log does not check that a logger is set; callers guard it with shouldLog.
func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
	if data == nil {
		data = make(map[string]interface{})
	}

	if c.pgConn != nil {
		if pid := c.pgConn.PID(); pid != 0 {
			data["pid"] = pid
		}
	}

	c.logger.Log(lvl, msg, data)
}
// SetLogger replaces the current logger and returns the previous logger.
// Setting a nil logger disables logging, because shouldLog requires a non-nil
// logger.
func (c *Conn) SetLogger(logger Logger) Logger {
oldLogger := c.logger
c.logger = logger
return oldLogger
}
// SetLogLevel replaces the current log level and returns the previous log
// level.
//
// lvl must lie between LogLevelNone and LogLevelTrace inclusive; otherwise the
// level is left unchanged and the current level is returned together with
// ErrInvalidLogLevel.
func (c *Conn) SetLogLevel(lvl LogLevel) (LogLevel, error) {
	oldLvl := c.logLevel

	if lvl < LogLevelNone || lvl > LogLevelTrace {
		return oldLvl, ErrInvalidLogLevel
	}

	c.logLevel = lvl

	// Return the level that was replaced, as documented (and as SetLogger does
	// for the logger), not the level that was just set.
	return oldLvl, nil
}
// quoteIdentifier wraps s in double quotes and doubles every double quote it
// contains, producing a quoted SQL identifier.
func quoteIdentifier(s string) string {
	var b strings.Builder
	b.WriteByte('"')
	b.WriteString(strings.Replace(s, `"`, `""`, -1))
	b.WriteByte('"')
	return b.String()
}
// Ping executes the empty statement ";" via Exec and returns the resulting
// error, if any.
func (c *Conn) Ping(ctx context.Context) error {
_, err := c.Exec(ctx, ";")
return err
}
// connInfoFromRows reads (oid, name) rows into a map from type name to OID.
// It takes the two results of a Query call directly: a non-nil err is returned
// unchanged without touching rows. Otherwise rows is always closed before
// returning, and the first scan or iteration error is returned.
func connInfoFromRows(rows Rows, err error) (map[string]pgtype.OID, error) {
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	oidsByName := make(map[string]pgtype.OID, 256)
	for rows.Next() {
		var (
			oid  pgtype.OID
			name pgtype.Text
		)
		if scanErr := rows.Scan(&oid, &name); scanErr != nil {
			return nil, scanErr
		}

		oidsByName[name.String] = oid
	}

	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}

	return oidsByName, nil
}
// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to
// send the statement over the wire. Each call to Query(Ex)/Exec(Ex) resets
// the value to false initially until the statement has been sent. This does
// NOT mean that the statement was successful or even received; it just means
// that a write was attempted and therefore it could have been executed. Calls
// to prepare a statement are ignored; only when the prepared statement is
// attempted to be executed will this return true.
func (c *Conn) LastStmtSent() bool {
return c.lastStmtSent
}
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
// PostgreSQL connection than pgx exposes.
//
// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn
// is used, and the connection must be returned to the same state before any *pgx.Conn methods are used again.
func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
//
// The connection is marked busy for the duration of the call; ErrConnBusy is
// returned when it is not idle. Failures are logged at LogLevelError and
// successful executions at LogLevelInfo, including the elapsed time.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
	c.lastStmtSent = false

	if lockErr := c.lock(); lockErr != nil {
		return "", lockErr
	}
	defer c.unlock()

	began := time.Now()
	tag, err := c.exec(ctx, sql, arguments...)

	switch {
	case err != nil:
		if c.shouldLog(LogLevelError) {
			c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
		}
	case c.shouldLog(LogLevelInfo):
		elapsed := time.Since(began)
		c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": elapsed, "commandTag": tag})
	}

	return tag, err
}
// exec runs sql with arguments and returns the resulting command tag.
//
// Three cases are handled:
//   - sql names a cached prepared statement: that statement is executed.
//   - sql is not a cached statement name and there are no arguments: sql is
//     executed directly and the command tag of the last result is returned
//     (empty when there are no results).
//   - otherwise: sql is prepared as the unnamed statement and then executed.
//
// In the two prepared cases the number of arguments must match the number of
// statement parameters; a mismatch is reported as an error instead of indexing
// past the end of the parameter OIDs. c.lastStmtSent is set to true right
// before the statement itself is executed; preparing alone does not set it.
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
	ps, prepared := c.preparedStatements[sql]
	if !prepared {
		if len(arguments) == 0 {
			// Nothing to bind, so no prepare round trip is needed.
			c.lastStmtSent = true
			results, err := c.pgConn.Exec(ctx, sql).ReadAll()
			if err != nil {
				return "", err
			}
			if len(results) == 0 {
				return "", nil
			}
			return results[len(results)-1].CommandTag, nil
		}

		// Prepare the unnamed statement to learn the parameter and result types.
		psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
		if err != nil {
			return "", err
		}

		ps = &PreparedStatement{
			Name:              psd.Name,
			SQL:               psd.SQL,
			ParameterOIDs:     make([]pgtype.OID, len(psd.ParamOIDs)),
			FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
		}
		for i := range ps.ParameterOIDs {
			ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
		}
		for i := range ps.FieldDescriptions {
			c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i])
		}
	}

	// Guards the ps.ParameterOIDs[i] lookups below for both prepared paths.
	if len(ps.ParameterOIDs) != len(arguments) {
		return "", errors.Errorf("expected %d arguments, got %d", len(ps.ParameterOIDs), len(arguments))
	}

	args, err := convertDriverValuers(arguments)
	if err != nil {
		return "", err
	}

	paramFormats := make([]int16, len(args))
	paramValues := make([][]byte, len(args))
	for i := range args {
		paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i])
		paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i])
		if err != nil {
			return "", err
		}
	}

	// Ask for binary results where the registered data type can decode them.
	// Fields with an unknown data type keep the zero format code.
	resultFormats := make([]int16, len(ps.FieldDescriptions))
	for i := range resultFormats {
		if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
			if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
				resultFormats[i] = BinaryFormatCode
			} else {
				resultFormats[i] = TextFormatCode
			}
		}
	}

	c.lastStmtSent = true
	result := c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats).Read()
	return result.CommandTag, result.Err
}
// newencodePreparedStatementArgument encodes arg as the wire value for a
// statement parameter of type oid. A nil result with a nil error stands for a
// nil argument or a nil pointer.
//
// The encoding is chosen in this order:
//  1. pgtype.BinaryEncoder and pgtype.TextEncoder values encode themselves;
//     strings are sent as their bytes.
//  2. Non-nil pointers are dereferenced and encoded recursively.
//  3. If ci knows oid, arg is assigned to the registered data type's value and
//     encoded in binary. If the assignment fails and arg is a driver.Valuer,
//     the valuer's value is encoded instead.
//  4. Otherwise the argument is retried with its named type stripped, when
//     stripNamedType can do so.
//
// A SerializationError is returned when none of these apply, or when the
// registered data type's value cannot encode itself in binary.
func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) {
	if arg == nil {
		return nil, nil
	}

	// TODO - don't allocate a new buf for each encoded prepared statement. The empty slice is necessary because otherwise empty strings may be encoded as []byte(nil) instead of []byte{}
	buf := make([]byte, 0)

	switch arg := arg.(type) {
	case pgtype.BinaryEncoder:
		return arg.EncodeBinary(ci, buf)
	case pgtype.TextEncoder:
		return arg.EncodeText(ci, buf)
	case string:
		return []byte(arg), nil
	}

	refVal := reflect.ValueOf(arg)

	if refVal.Kind() == reflect.Ptr {
		if refVal.IsNil() {
			return nil, nil
		}
		return newencodePreparedStatementArgument(ci, oid, refVal.Elem().Interface())
	}

	if dt, ok := ci.DataTypeForOID(oid); ok {
		value := dt.Value
		if err := value.Set(arg); err != nil {
			if valuer, ok := arg.(driver.Valuer); ok {
				v, err := callValuerValue(valuer)
				if err != nil {
					return nil, err
				}
				return newencodePreparedStatementArgument(ci, oid, v)
			}

			return nil, err
		}

		// Report a data type that cannot encode in binary as an error rather
		// than panicking on a failed type assertion.
		encoder, ok := value.(pgtype.BinaryEncoder)
		if !ok {
			return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T does not implement pgtype.BinaryEncoder", arg, oid, value))
		}

		return encoder.EncodeBinary(ci, buf)
	}

	if strippedArg, ok := stripNamedType(&refVal); ok {
		return newencodePreparedStatementArgument(ci, oid, strippedArg)
	}
	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
// FieldDescription. dst is overwritten field by field; src is not modified.
func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
dst.Name = src.Name
dst.Table = pgtype.OID(src.TableOID)
dst.AttributeNumber = src.TableAttributeNumber
dst.DataType = pgtype.OID(src.DataTypeOID)
dst.DataTypeSize = src.DataTypeSize
dst.Modifier = src.TypeModifier
dst.FormatCode = src.Format
// DataTypeName is only assigned when the connection's ConnInfo knows the
// OID; otherwise dst.DataTypeName is left untouched.
if dt, ok := c.ConnInfo.DataTypeForOID(dst.DataType); ok {
dst.DataTypeName = dt.Name
}
}