From 31705e586af14e5865c1aa7a7695a78090044152 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 14:29:05 -0500 Subject: [PATCH] Use pgconn.PreparedStatementDescription directly Instead of having similar pgx.PreparedStatement --- conn.go | 77 ++++++++++++++++----------------------------------- copy_from.go | 5 ++-- pgxpool/tx.go | 2 +- stdlib/sql.go | 14 +++++----- tx.go | 6 ++-- 5 files changed, 38 insertions(+), 66 deletions(-) diff --git a/conn.go b/conn.go index 81a4c6e2..ad21ccdb 100644 --- a/conn.go +++ b/conn.go @@ -43,7 +43,7 @@ type ConnConfig struct { type Conn struct { pgConn *pgconn.PgConn config *ConnConfig // config used when establishing this connection - preparedStatements map[string]*PreparedStatement + preparedStatements map[string]*pgconn.PreparedStatementDescription logger Logger logLevel LogLevel @@ -61,14 +61,6 @@ type Conn struct { eqb extendedQueryBuilder } -// PreparedStatement is a description of a prepared statement -type PreparedStatement struct { - Name string - SQL string - FieldDescriptions []pgproto3.FieldDescription - ParameterOIDs []uint32 -} - // Identifier a PostgreSQL identifier or name. Identifiers can be composed of // multiple parts such as ["schema", "table"] or ["table", "column"]. type Identifier []string @@ -172,7 +164,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { return nil, err } - c.preparedStatements = make(map[string]*PreparedStatement) + c.preparedStatements = make(map[string]*pgconn.PreparedStatementDescription) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) @@ -207,10 +199,11 @@ func (c *Conn) Close(ctx context.Context) error { // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. -func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedStatement, err error) { +func (c *Conn) Prepare(ctx context.Context, name, sql string) (psd *pgconn.PreparedStatementDescription, err error) { if name != "" { - if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { - return ps, nil + var ok bool + if psd, ok = c.preparedStatements[name]; ok && psd.SQL == sql { + return psd, nil } } @@ -222,27 +215,16 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState }() } - psd, err := c.pgConn.Prepare(ctx, name, sql, nil) + psd, err = c.pgConn.Prepare(ctx, name, sql, nil) if err != nil { return nil, err } - ps = &PreparedStatement{ - Name: psd.Name, - SQL: psd.SQL, - ParameterOIDs: make([]uint32, len(psd.ParamOIDs)), - FieldDescriptions: psd.Fields, - } - - for i := range ps.ParameterOIDs { - ps.ParameterOIDs[i] = uint32(psd.ParamOIDs[i]) - } - if name != "" { - c.preparedStatements[name] = ps + c.preparedStatements[name] = psd } - return ps, nil + return psd, nil } // Deallocate released a prepared statement @@ -464,7 +446,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i return commandTag, err } -func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { +func (c *Conn) execPrepared(ctx context.Context, ps *pgconn.PreparedStatementDescription, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { c.eqb.Reset() args, err := convertDriverValuers(arguments) @@ -473,14 +455,14 @@ func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, argument } for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) + err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i]) if err != nil { return nil, err } } - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok { + for i := range ps.Fields { + if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { @@ -571,27 +553,16 @@ optionLoop: ps, ok := c.preparedStatements[sql] if !ok { - psd, err := c.pgConn.Prepare(ctx, "", sql, nil) + ps, err = c.pgConn.Prepare(ctx, "", sql, nil) if err != nil { rows.fatal(err) return rows, rows.err } - if len(psd.ParamOIDs) != len(args) { - rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args))) + if len(ps.ParamOIDs) != len(args) { + rows.fatal(errors.Errorf("expected %d arguments, got %d", len(ps.ParamOIDs), len(args))) return rows, rows.err } - - ps = &PreparedStatement{ - Name: psd.Name, - SQL: psd.SQL, - ParameterOIDs: make([]uint32, len(psd.ParamOIDs)), - FieldDescriptions: psd.Fields, - } - - for i := range ps.ParameterOIDs { - ps.ParameterOIDs[i] = uint32(psd.ParamOIDs[i]) - } } rows.sql = ps.SQL @@ -602,7 +573,7 @@ optionLoop: } for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) + err = c.eqb.AppendParam(c.ConnInfo, ps.ParamOIDs[i], args[i]) if err != nil { rows.fatal(err) return rows, rows.err @@ -610,15 +581,15 @@ optionLoop: } if resultFormatsByOID != nil { - resultFormats = make([]int16, len(ps.FieldDescriptions)) + resultFormats = make([]int16, len(ps.Fields)) for i := range resultFormats { - resultFormats[i] = resultFormatsByOID[uint32(ps.FieldDescriptions[i].DataTypeOID)] + resultFormats[i] = resultFormatsByOID[uint32(ps.Fields[i].DataTypeOID)] } } if resultFormats == nil { - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok { + for i := range ps.Fields { + if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { @@ -655,7 +626,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { ps := c.preparedStatements[bi.query] if ps != nil { - parameterOIDs = ps.ParameterOIDs + parameterOIDs = ps.ParamOIDs } else { parameterOIDs = bi.parameterOIDs } @@ -677,8 +648,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { resultFormats := bi.resultFormatCodes if resultFormats == nil { - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.FieldDescriptions[i].DataTypeOID)); ok { + for i := range ps.Fields { + if dt, ok := c.ConnInfo.DataTypeForOID(uint32(ps.Fields[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { diff --git a/copy_from.go b/copy_from.go index 44bb9f3e..11c8acc1 100644 --- a/copy_from.go +++ b/copy_from.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/jackc/pgconn" "github.com/jackc/pgio" errors "golang.org/x/xerrors" ) @@ -116,7 +117,7 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { return commandTag.RowsAffected(), err } -func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) { +func (ct *copyFrom) buildCopyBuf(buf []byte, ps *pgconn.PreparedStatementDescription) (bool, []byte, error) { for ct.rowSrc.Next() { values, err := ct.rowSrc.Values() @@ -129,7 +130,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, uint32(ps.FieldDescriptions[i].DataTypeOID), val) + buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } diff --git a/pgxpool/tx.go b/pgxpool/tx.go index c9f00290..1ddeb91b 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -46,7 +46,7 @@ func (tx *Tx) LargeObjects() pgx.LargeObjects { return tx.t.LargeObjects() } -func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgx.PreparedStatement, error) { +func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) { return tx.t.Prepare(ctx, name, sql) } diff --git a/stdlib/sql.go b/stdlib/sql.go index fb2732cf..8c96d466 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -164,12 +164,12 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ - ps, err := c.conn.Prepare(ctx, name, query) + psd, err := c.conn.Prepare(ctx, name, query) if err != nil { return nil, err } - return &Stmt{ps: ps, conn: c}, nil + return &Stmt{psd: psd, conn: c}, nil } func (c *Conn) Close() error { @@ -265,16 +265,16 @@ func (c *Conn) Ping(ctx context.Context) error { } type Stmt struct { - ps *pgx.PreparedStatement + psd *pgconn.PreparedStatementDescription conn *Conn } func (s *Stmt) Close() error { - return s.conn.conn.Deallocate(context.Background(), s.ps.Name) + return s.conn.conn.Deallocate(context.Background(), s.psd.Name) } func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOIDs) + return len(s.psd.ParamOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { @@ -282,7 +282,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { } func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { - return s.conn.ExecContext(ctx, s.ps.Name, argsV) + return s.conn.ExecContext(ctx, s.psd.Name, argsV) } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { @@ -290,7 +290,7 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { } func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { - return s.conn.QueryContext(ctx, s.ps.Name, argsV) + return s.conn.QueryContext(ctx, s.psd.Name, argsV) } type Rows struct { diff --git a/tx.go b/tx.go index cecd8f2c..1b42905f 100644 --- a/tx.go +++ b/tx.go @@ -95,7 +95,7 @@ type Tx interface { SendBatch(ctx context.Context, b *Batch) BatchResults LargeObjects() LargeObjects - Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) + Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) @@ -174,7 +174,7 @@ func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) } // Prepare delegates to the underlying *Conn -func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) { +func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) { if tx.closed { return nil, ErrTxClosed } @@ -264,7 +264,7 @@ func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interf } // Prepare delegates to the underlying Tx -func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) { +func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*pgconn.PreparedStatementDescription, error) { if sp.closed { return nil, ErrTxClosed }