Use pgconn.PreparedStatementDescription directly
Instead of having similar pgx.PreparedStatement
This commit is contained in:
@@ -43,7 +43,7 @@ type ConnConfig struct {
|
||||
type Conn struct {
|
||||
pgConn *pgconn.PgConn
|
||||
config *ConnConfig // config used when establishing this connection
|
||||
preparedStatements map[string]*PreparedStatement
|
||||
preparedStatements map[string]*pgconn.PreparedStatementDescription
|
||||
logger Logger
|
||||
logLevel LogLevel
|
||||
|
||||
@@ -61,14 +61,6 @@ type Conn struct {
|
||||
eqb extendedQueryBuilder
|
||||
}
|
||||
|
||||
// PreparedStatement is a description of a prepared statement
|
||||
type PreparedStatement struct {
|
||||
Name string
|
||||
SQL string
|
||||
FieldDescriptions []pgproto3.FieldDescription
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
||||
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
||||
type Identifier []string
|
||||
@@ -172,7 +164,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||
c.preparedStatements = make(map[string]*pgconn.PreparedStatementDescription)
|
||||
c.doneChan = make(chan struct{})
|
||||
c.closedChan = make(chan error)
|
||||
c.wbuf = make([]byte, 0, 1024)
|
||||
@@ -207,10 +199,11 @@ func (c *Conn) Close(ctx context.Context) error {
|
||||
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
|
||||
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
|
||||
// concern for if the statement has already been prepared.
|
||||
func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedStatement, err error) {
|
||||
func (c *Conn) Prepare(ctx context.Context, name, sql string) (psd *pgconn.PreparedStatementDescription, err error) {
|
||||
if name != "" {
|
||||
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
|
||||
return ps, nil
|
||||
var ok bool
|
||||
if psd, ok = c.preparedStatements[name]; ok && psd.SQL == sql {
|
||||
return psd, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,27 +215,16 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
|
||||
}()
|
||||
}
|
||||
|
||||
psd, err := c.pgConn.Prepare(ctx, name, sql, nil)
|
||||
psd, err = c.pgConn.Prepare(ctx, name, sql, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ps = &PreparedStatement{
|
||||
Name: psd.Name,
|
||||
SQL: psd.SQL,
|
||||
ParameterOIDs: make([]uint32, len(psd.ParamOIDs)),
|
||||
FieldDescriptions: psd.Fields,
|
||||
}
|
||||
|
||||
for i := range ps.ParameterOIDs {
|
||||
ps.ParameterOIDs[i] = uint32(psd.ParamOIDs[i])
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
c.preparedStatements[name] = ps
|
||||
c.preparedStatements[name] = psd
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
// Deallocate released a prepared statement
|
||||
@@ -464,7 +446,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
c.eqb.Reset()
|
||||
|
||||
args, err := convertDriverValuers(arguments)
|
||||
@@ -473,14 +455,14 @@ func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, argument
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i])
|
||||
err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for i := range ps.FieldDescriptions {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||
for i := range ps.Fields {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||
} else {
|
||||
@@ -571,27 +553,16 @@ optionLoop:
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
|
||||
ps, err = c.pgConn.Prepare(ctx, "", sql, nil)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
if len(psd.ParamOIDs) != len(args) {
|
||||
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args)))
|
||||
if len(ps.ParamOIDs) != len(args) {
|
||||
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args)))
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
ps = &PreparedStatement{
|
||||
Name: psd.Name,
|
||||
SQL: psd.SQL,
|
||||
ParameterOIDs: make([]uint32, len(psd.ParamOIDs)),
|
||||
FieldDescriptions: psd.Fields,
|
||||
}
|
||||
|
||||
for i := range ps.ParameterOIDs {
|
||||
ps.ParameterOIDs[i] = uint32(psd.ParamOIDs[i])
|
||||
}
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
|
||||
@@ -602,7 +573,7 @@ optionLoop:
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i])
|
||||
err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i])
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, rows.err
|
||||
@@ -610,15 +581,15 @@ optionLoop:
|
||||
}
|
||||
|
||||
if resultFormatsByOID != nil {
|
||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||
resultFormats = make([]int16, len(ps.Fields))
|
||||
for i := range resultFormats {
|
||||
resultFormats[i] = resultFormatsByOID[uint32(ps.FieldDescriptions[i].DataTypeOID)]
|
||||
resultFormats[i] = resultFormatsByOID[uint32(ps.Fields[i].DataTypeOID)]
|
||||
}
|
||||
}
|
||||
|
||||
if resultFormats == nil {
|
||||
for i := range ps.FieldDescriptions {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||
for i := range ps.Fields {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||
} else {
|
||||
@@ -655,7 +626,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
ps := c.preparedStatements[bi.query]
|
||||
|
||||
if ps != nil {
|
||||
parameterOIDs = ps.ParameterOIDs
|
||||
parameterOIDs = ps.ParamOIDs
|
||||
} else {
|
||||
parameterOIDs = bi.parameterOIDs
|
||||
}
|
||||
@@ -677,8 +648,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
resultFormats := bi.resultFormatCodes
|
||||
if resultFormats == nil {
|
||||
|
||||
for i := range ps.FieldDescriptions {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||
for i := range ps.Fields {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user