Use pgconn.PreparedStatementDescription directly
Instead of having similar pgx.PreparedStatement
This commit is contained in:
@@ -43,7 +43,7 @@ type ConnConfig struct {
|
|||||||
type Conn struct {
|
type Conn struct {
|
||||||
pgConn *pgconn.PgConn
|
pgConn *pgconn.PgConn
|
||||||
config *ConnConfig // config used when establishing this connection
|
config *ConnConfig // config used when establishing this connection
|
||||||
preparedStatements map[string]*PreparedStatement
|
preparedStatements map[string]*pgconn.PreparedStatementDescription
|
||||||
logger Logger
|
logger Logger
|
||||||
logLevel LogLevel
|
logLevel LogLevel
|
||||||
|
|
||||||
@@ -61,14 +61,6 @@ type Conn struct {
|
|||||||
eqb extendedQueryBuilder
|
eqb extendedQueryBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
// PreparedStatement is a description of a prepared statement
|
|
||||||
type PreparedStatement struct {
|
|
||||||
Name string
|
|
||||||
SQL string
|
|
||||||
FieldDescriptions []pgproto3.FieldDescription
|
|
||||||
ParameterOIDs []uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
||||||
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
||||||
type Identifier []string
|
type Identifier []string
|
||||||
@@ -172,7 +164,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
c.preparedStatements = make(map[string]*pgconn.PreparedStatementDescription)
|
||||||
c.doneChan = make(chan struct{})
|
c.doneChan = make(chan struct{})
|
||||||
c.closedChan = make(chan error)
|
c.closedChan = make(chan error)
|
||||||
c.wbuf = make([]byte, 0, 1024)
|
c.wbuf = make([]byte, 0, 1024)
|
||||||
@@ -207,10 +199,11 @@ func (c *Conn) Close(ctx context.Context) error {
|
|||||||
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
|
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
|
||||||
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
|
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
|
||||||
// concern for if the statement has already been prepared.
|
// concern for if the statement has already been prepared.
|
||||||
func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedStatement, err error) {
|
func (c *Conn) Prepare(ctx context.Context, name, sql string) (psd *pgconn.PreparedStatementDescription, err error) {
|
||||||
if name != "" {
|
if name != "" {
|
||||||
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
|
var ok bool
|
||||||
return ps, nil
|
if psd, ok = c.preparedStatements[name]; ok && psd.SQL == sql {
|
||||||
|
return psd, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,27 +215,16 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
psd, err := c.pgConn.Prepare(ctx, name, sql, nil)
|
psd, err = c.pgConn.Prepare(ctx, name, sql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ps = &PreparedStatement{
|
|
||||||
Name: psd.Name,
|
|
||||||
SQL: psd.SQL,
|
|
||||||
ParameterOIDs: make([]uint32, len(psd.ParamOIDs)),
|
|
||||||
FieldDescriptions: psd.Fields,
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range ps.ParameterOIDs {
|
|
||||||
ps.ParameterOIDs[i] = uint32(psd.ParamOIDs[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if name != "" {
|
if name != "" {
|
||||||
c.preparedStatements[name] = ps
|
c.preparedStatements[name] = psd
|
||||||
}
|
}
|
||||||
|
|
||||||
return ps, nil
|
return psd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deallocate released a prepared statement
|
// Deallocate released a prepared statement
|
||||||
@@ -464,7 +446,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
|
|||||||
return commandTag, err
|
return commandTag, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||||
c.eqb.Reset()
|
c.eqb.Reset()
|
||||||
|
|
||||||
args, err := convertDriverValuers(arguments)
|
args, err := convertDriverValuers(arguments)
|
||||||
@@ -473,14 +455,14 @@ func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, argument
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := range args {
|
for i := range args {
|
||||||
err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i])
|
err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.Fields {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||||
} else {
|
} else {
|
||||||
@@ -571,27 +553,16 @@ optionLoop:
|
|||||||
|
|
||||||
ps, ok := c.preparedStatements[sql]
|
ps, ok := c.preparedStatements[sql]
|
||||||
if !ok {
|
if !ok {
|
||||||
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
|
ps, err = c.pgConn.Prepare(ctx, "", sql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, rows.err
|
return rows, rows.err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(psd.ParamOIDs) != len(args) {
|
if len(ps.ParamOIDs) != len(args) {
|
||||||
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args)))
|
rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args)))
|
||||||
return rows, rows.err
|
return rows, rows.err
|
||||||
}
|
}
|
||||||
|
|
||||||
ps = &PreparedStatement{
|
|
||||||
Name: psd.Name,
|
|
||||||
SQL: psd.SQL,
|
|
||||||
ParameterOIDs: make([]uint32, len(psd.ParamOIDs)),
|
|
||||||
FieldDescriptions: psd.Fields,
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range ps.ParameterOIDs {
|
|
||||||
ps.ParameterOIDs[i] = uint32(psd.ParamOIDs[i])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
rows.sql = ps.SQL
|
rows.sql = ps.SQL
|
||||||
|
|
||||||
@@ -602,7 +573,7 @@ optionLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := range args {
|
for i := range args {
|
||||||
err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i])
|
err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, rows.err
|
return rows, rows.err
|
||||||
@@ -610,15 +581,15 @@ optionLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resultFormatsByOID != nil {
|
if resultFormatsByOID != nil {
|
||||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
resultFormats = make([]int16, len(ps.Fields))
|
||||||
for i := range resultFormats {
|
for i := range resultFormats {
|
||||||
resultFormats[i] = resultFormatsByOID[uint32(ps.FieldDescriptions[i].DataTypeOID)]
|
resultFormats[i] = resultFormatsByOID[uint32(ps.Fields[i].DataTypeOID)]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultFormats == nil {
|
if resultFormats == nil {
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.Fields {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||||
} else {
|
} else {
|
||||||
@@ -655,7 +626,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||||||
ps := c.preparedStatements[bi.query]
|
ps := c.preparedStatements[bi.query]
|
||||||
|
|
||||||
if ps != nil {
|
if ps != nil {
|
||||||
parameterOIDs = ps.ParameterOIDs
|
parameterOIDs = ps.ParamOIDs
|
||||||
} else {
|
} else {
|
||||||
parameterOIDs = bi.parameterOIDs
|
parameterOIDs = bi.parameterOIDs
|
||||||
}
|
}
|
||||||
@@ -677,8 +648,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||||||
resultFormats := bi.resultFormatCodes
|
resultFormats := bi.resultFormatCodes
|
||||||
if resultFormats == nil {
|
if resultFormats == nil {
|
||||||
|
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.Fields {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+3
-2
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgio"
|
"github.com/jackc/pgio"
|
||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
@@ -116,7 +117,7 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|||||||
return commandTag.RowsAffected(), err
|
return commandTag.RowsAffected(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
|
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *pgconn.PreparedStatementDescription) (bool, []byte, error) {
|
||||||
|
|
||||||
for ct.rowSrc.Next() {
|
for ct.rowSrc.Next() {
|
||||||
values, err := ct.rowSrc.Values()
|
values, err := ct.rowSrc.Values()
|
||||||
@@ -129,7 +130,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
|
|||||||
|
|
||||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||||
for i, val := range values {
|
for i, val := range values {
|
||||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, uint32(ps.FieldDescriptions[i].DataTypeOID), val)
|
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.Fields[i].DataTypeOID, val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, err
|
return false, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -46,7 +46,7 @@ func (tx *Tx) LargeObjects() pgx.LargeObjects {
|
|||||||
return tx.t.LargeObjects()
|
return tx.t.LargeObjects()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgx.PreparedStatement, error) {
|
func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) {
|
||||||
return tx.t.Prepare(ctx, name, sql)
|
return tx.t.Prepare(ctx, name, sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+7
-7
@@ -164,12 +164,12 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
|||||||
name := fmt.Sprintf("pgx_%d", c.psCount)
|
name := fmt.Sprintf("pgx_%d", c.psCount)
|
||||||
c.psCount++
|
c.psCount++
|
||||||
|
|
||||||
ps, err := c.conn.Prepare(ctx, name, query)
|
psd, err := c.conn.Prepare(ctx, name, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Stmt{ps: ps, conn: c}, nil
|
return &Stmt{psd: psd, conn: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
@@ -265,16 +265,16 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Stmt struct {
|
type Stmt struct {
|
||||||
ps *pgx.PreparedStatement
|
psd *pgconn.PreparedStatementDescription
|
||||||
conn *Conn
|
conn *Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) Close() error {
|
func (s *Stmt) Close() error {
|
||||||
return s.conn.conn.Deallocate(context.Background(), s.ps.Name)
|
return s.conn.conn.Deallocate(context.Background(), s.psd.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) NumInput() int {
|
func (s *Stmt) NumInput() int {
|
||||||
return len(s.ps.ParameterOIDs)
|
return len(s.psd.ParamOIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
||||||
@@ -282,7 +282,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
|
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
|
||||||
return s.conn.ExecContext(ctx, s.ps.Name, argsV)
|
return s.conn.ExecContext(ctx, s.psd.Name, argsV)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||||
@@ -290,7 +290,7 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
|
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||||
return s.conn.QueryContext(ctx, s.ps.Name, argsV)
|
return s.conn.QueryContext(ctx, s.psd.Name, argsV)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Rows struct {
|
type Rows struct {
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ type Tx interface {
|
|||||||
SendBatch(ctx context.Context, b *Batch) BatchResults
|
SendBatch(ctx context.Context, b *Batch) BatchResults
|
||||||
LargeObjects() LargeObjects
|
LargeObjects() LargeObjects
|
||||||
|
|
||||||
Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
|
Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error)
|
||||||
|
|
||||||
Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
|
||||||
Query(ctx context.Context, sql string, args ...interface{}) (Rows, error)
|
Query(ctx context.Context, sql string, args ...interface{}) (Rows, error)
|
||||||
@@ -174,7 +174,7 @@ func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{})
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prepare delegates to the underlying *Conn
|
// Prepare delegates to the underlying *Conn
|
||||||
func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) {
|
func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) {
|
||||||
if tx.closed {
|
if tx.closed {
|
||||||
return nil, ErrTxClosed
|
return nil, ErrTxClosed
|
||||||
}
|
}
|
||||||
@@ -264,7 +264,7 @@ func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prepare delegates to the underlying Tx
|
// Prepare delegates to the underlying Tx
|
||||||
func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) {
|
func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) {
|
||||||
if sp.closed {
|
if sp.closed {
|
||||||
return nil, ErrTxClosed
|
return nil, ErrTxClosed
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user