From dd5e6a77dc81414a78548a1b4e04c5b56dcb67f0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 11:27:44 -0500 Subject: [PATCH] Add QueryEx single round-trip mode --- query.go | 122 ++++++++++++++++++++++++++++++++++++++++++++------ query_test.go | 26 +++++++++++ 2 files changed, 134 insertions(+), 14 deletions(-) diff --git a/query.go b/query.go index 0962b352..447a55ac 100644 --- a/query.go +++ b/query.go @@ -348,7 +348,12 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } type QueryExOptions struct { - ParameterOids []pgtype.Oid + // When ParameterOids are present and the query is not a prepared statement, + // then ParameterOids and ResultFormatCodes will be used to avoid an extra + // network round-trip. + ParameterOids []pgtype.Oid + ResultFormatCodes []int16 + SimpleProtocol bool } @@ -358,6 +363,10 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return nil, err } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + c.lastActivityTime = time.Now() rows = c.getRows(sql, args) @@ -368,13 +377,13 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } rows.unlockConn = true - if options != nil && options.SimpleProtocol { - err = c.initContext(ctx) - if err != nil { - rows.fatal(err) - return rows, err - } + err = c.initContext(ctx) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + if options != nil && options.SimpleProtocol { err = c.sanitizeAndSendSimpleQuery(sql, args...) if err != nil { rows.fatal(err) @@ -384,10 +393,54 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return rows, nil } + if options != nil && len(options.ParameterOids) > 0 { + + buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) + if err != nil { + rows.fatal(err) + return rows, err + } + + buf = appendSync(buf) + + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { + rows.fatal(err) + c.die(err) + return nil, err + } + c.readyForQuery = false + + fieldDescriptions, err := c.readUntilRowDescription() + if err != nil { + rows.fatal(err) + return nil, err + } + + if len(options.ResultFormatCodes) == 0 { + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = TextFormatCode + } + } else if len(options.ResultFormatCodes) == 1 { + fc := options.ResultFormatCodes[0] + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = fc + } + } else { + for i := range options.ResultFormatCodes { + fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] + } + } + + rows.sql = sql + rows.fields = fieldDescriptions + return rows, nil + } + ps, ok := c.preparedStatements[sql] if !ok { var err error - ps, err = c.PrepareEx(ctx, "", sql, nil) + ps, err = c.prepareEx("", sql, nil) if err != nil { rows.fatal(err) return rows, rows.err @@ -396,12 +449,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, rows.sql = ps.SQL rows.fields = ps.FieldDescriptions - err = c.initContext(ctx) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - err = c.sendPreparedQuery(ps, args...) if err != nil { rows.fatal(err) @@ -410,6 +457,53 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return rows, rows.err } +func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { + if len(arguments) != len(options.ParameterOids) { + return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids)) + } + + if len(options.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) + } + + buf = appendParse(buf, "", sql, options.ParameterOids) + buf = appendDescribe(buf, 'S', "") + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, options.ResultFormatCodes) + if err != nil { + return nil, err + } + buf = appendExecute(buf, "", 0) + + return buf, nil +} + +func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { + for { + msg, err := c.rxMsg() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + case *pgproto3.RowDescription: + fieldDescriptions := c.rxRowDescription(msg) + for i := range fieldDescriptions { + if dt, ok := c.ConnInfo.DataTypeForOid(fieldDescriptions[i].DataType); ok { + fieldDescriptions[i].DataTypeName = dt.Name + } else { + return nil, fmt.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) + } + } + return fieldDescriptions, nil + default: + if err := c.processContextFreeMsg(msg); err != nil { + return nil, err + } + } + } +} + func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { if c.RuntimeParams["standard_conforming_strings"] != "on" { return errors.New("simple protocol queries must be run with standard_conforming_strings=on") diff --git a/query_test.go b/query_test.go index 801b34dd..4e128fb2 100644 --- a/query_test.go +++ b/query_test.go @@ -1182,6 +1182,32 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { ensureConnValid(t, conn) } +func TestConnQueryRowExSingleRoundTrip(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var result int32 + err := conn.QueryRowEx( + context.Background(), + "select $1 + $2", + &pgx.QueryExOptions{ + ParameterOids: []pgtype.Oid{pgtype.Int4Oid, pgtype.Int4Oid}, + ResultFormatCodes: []int16{pgx.BinaryFormatCode}, + }, + 1, 2, + ).Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 3 { + t.Fatal("result => %d, want %d", result, 3) + } + + ensureConnValid(t, conn) +} + func TestConnSimpleProtocol(t *testing.T) { t.Parallel()