From 4ed4e0122de1ed46647934dcd7e2ea8680b82c59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Feb 2019 13:27:18 -0600 Subject: [PATCH] Restore simple protocol support --- query.go | 76 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/query.go b/query.go index 0ba075d6..c169db8d 100644 --- a/query.go +++ b/query.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" + "github.com/jackc/pgx/internal/sanitize" "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgtype" ) @@ -56,7 +57,8 @@ type Rows struct { unlockConn bool closed bool - resultReader *pgconn.ResultReader + resultReader *pgconn.ResultReader + multiResultReader *pgconn.MultiResultReader } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -84,6 +86,13 @@ func (rows *Rows) Close() { } } + if rows.multiResultReader != nil { + closeErr := rows.multiResultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() @@ -373,16 +382,25 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, // return rows, rows.err // } - // if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { - // c.lastStmtSent = true - // err = c.sanitizeAndSendSimpleQuery(sql, args...) - // if err != nil { - // rows.fatal(err) - // return rows, err - // } + if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { + sql, err = c.sanitizeForSimpleQuery(sql, args...) + if err != nil { + rows.fatal(err) + return rows, err + } - // return rows, nil - // } + c.lastStmtSent = true + rows.multiResultReader = c.pgConn.Exec(ctx, sql) + if rows.multiResultReader.NextResult() { + rows.resultReader = rows.multiResultReader.ResultReader() + } else { + err = rows.multiResultReader.Close() + rows.fatal(err) + return rows, err + } + + return rows, nil + } // if options != nil && len(options.ParameterOIDs) > 0 { @@ -513,30 +531,26 @@ func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryEx return buf, nil } -// func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { -// if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { -// return errors.New("simple protocol queries must be run with standard_conforming_strings=on") -// } +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { + if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") + } -// if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { -// return errors.New("simple protocol queries must be run with client_encoding=UTF8") -// } + if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") + } -// valueArgs := make([]interface{}, len(args)) -// for i, a := range args { -// valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) -// if err != nil { -// return err -// } -// } + var err error + valueArgs := make([]interface{}, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) + if err != nil { + return "", err + } + } -// sql, err = sanitize.SanitizeSQL(sql, valueArgs...) -// if err != nil { -// return err -// } - -// return c.sendSimpleQuery(sql) -// } + return sanitize.SanitizeSQL(sql, valueArgs...) +} func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { rows, _ := c.QueryEx(ctx, sql, options, args...)