diff --git a/pgconn.go b/pgconn.go index db741d47..7ddc50e6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -712,53 +712,16 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { - result := &ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - if err := pgConn.lock(); err != nil { - result.concludeCommand("", err) - result.closed = true + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { return result } - if len(paramValues) > math.MaxUint16 { - result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) - result.closed = true - pgConn.unlock() - return result - } - - select { - case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) - result.closed = true - pgConn.unlock() - return result - default: - } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte - - // TODO - refactor ExecParams and ExecPrepared - these lines only difference buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - result.concludeCommand("", err) - result.cleanupContextDeadline() - result.closed = true - pgConn.unlock() - } + pgConn.execExtendedSuffix(buf, result) return result } @@ -776,6 +739,20 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + var buf []byte + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { result := &ResultReader{ pgConn: pgConn, ctx: ctx, @@ -805,8 +782,10 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + return result +} + +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -819,8 +798,6 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa result.closed = true pgConn.unlock() } - - return result } // CopyTo executes the copy command sql and copies the results to w.