SendBatch now uses pipeline mode to prepare and describe statements
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements in a single network round trip. So it would only take 2 round trips.
This commit is contained in:
@@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
||||
c.wbuf = make([]byte, 0, 1024)
|
||||
|
||||
if c.config.StatementCacheCapacity > 0 {
|
||||
c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity)
|
||||
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
|
||||
}
|
||||
|
||||
if c.config.DescriptionCacheCapacity > 0 {
|
||||
c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity)
|
||||
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
@@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
|
||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
||||
// positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := c.exec(ctx, sql, arguments...)
|
||||
@@ -437,9 +441,13 @@ optionLoop:
|
||||
if c.statementCache == nil {
|
||||
return pgconn.CommandTag{}, errDisabledStatementCache
|
||||
}
|
||||
sd, err := c.statementCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
sd := c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
|
||||
return c.execPrepared(ctx, sd, arguments)
|
||||
@@ -447,9 +455,12 @@ optionLoop:
|
||||
if c.descriptionCache == nil {
|
||||
return pgconn.CommandTag{}, errDisabledDescriptionCache
|
||||
}
|
||||
sd, err := c.descriptionCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
sd := c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return c.execParams(ctx, sd, arguments)
|
||||
@@ -620,6 +631,10 @@ type QueryRewriter interface {
|
||||
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
|
||||
// needed. See the documentation for those types for details.
|
||||
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return &baseRows{err: err, closed: true}, err
|
||||
}
|
||||
|
||||
var resultFormats QueryResultFormats
|
||||
var resultFormatsByOID QueryResultFormatsByOID
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
@@ -649,6 +664,11 @@ optionLoop:
|
||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||
}
|
||||
|
||||
// Bypass any statement caching.
|
||||
if sql == "" {
|
||||
mode = QueryExecModeSimpleProtocol
|
||||
}
|
||||
|
||||
c.eqb.reset()
|
||||
anynil.NormalizeSlice(args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
@@ -664,10 +684,14 @@ optionLoop:
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd, err = c.statementCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
sd = c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
@@ -675,10 +699,14 @@ optionLoop:
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd, err = c.descriptionCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
sd = c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeDescribeExec:
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
@@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
|
||||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||
// is used again.
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
for _, bi := range b.items {
|
||||
@@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
}
|
||||
|
||||
if mode == QueryExecModeSimpleProtocol {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
sb.WriteString(sql)
|
||||
}
|
||||
mrr := c.pgConn.Exec(ctx, sb.String())
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
ix: 0,
|
||||
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
|
||||
}
|
||||
|
||||
// All other modes use extended protocol and thus can use prepared statements.
|
||||
for _, bi := range b.items {
|
||||
if sd, ok := c.preparedStatements[bi.query]; ok {
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case QueryExecModeExec:
|
||||
return c.sendBatchQueryExecModeExec(ctx, b)
|
||||
case QueryExecModeCacheStatement:
|
||||
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
|
||||
case QueryExecModeCacheDescribe:
|
||||
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
|
||||
case QueryExecModeDescribeExec:
|
||||
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
|
||||
default:
|
||||
panic("unknown QueryExecMode")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
sb.WriteString(sql)
|
||||
}
|
||||
mrr := c.pgConn.Exec(ctx, sb.String())
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
ix: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
if mode == QueryExecModeExec {
|
||||
for _, bi := range b.items {
|
||||
c.eqb.reset()
|
||||
anynil.NormalizeSlice(bi.arguments)
|
||||
|
||||
sd := c.preparedStatements[bi.query]
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
distinctUnpreparedQueries := map[string]struct{}{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
if _, ok := c.preparedStatements[bi.query]; ok {
|
||||
continue
|
||||
}
|
||||
distinctUnpreparedQueries[bi.query] = struct{}{}
|
||||
}
|
||||
|
||||
var stmtCache stmtcache.Cache
|
||||
if len(distinctUnpreparedQueries) > 0 {
|
||||
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
stmtCache = c.statementCache
|
||||
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
stmtCache = c.descriptionCache
|
||||
} else {
|
||||
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
|
||||
}
|
||||
|
||||
for sql, _ := range distinctUnpreparedQueries {
|
||||
_, err := stmtCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, bi := range b.items {
|
||||
c.eqb.reset()
|
||||
|
||||
sd := c.preparedStatements[bi.query]
|
||||
if sd == nil {
|
||||
var err error
|
||||
sd, err = stmtCache.Get(ctx, bi.query)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(bi.arguments) {
|
||||
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
||||
}
|
||||
|
||||
for _, bi := range b.items {
|
||||
sd := bi.sd
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
if sd.Name == "" {
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.statementCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
if bi.sd == nil {
|
||||
sd := c.statementCache.Get(bi.query)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
Name: stmtcache.NextStatementName(),
|
||||
SQL: bi.query,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.descriptionCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
if bi.sd == nil {
|
||||
sd := c.descriptionCache.Get(bi.query)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
if bi.sd == nil {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd := &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
|
||||
pipeline := c.pgConn.StartPipeline(context.Background())
|
||||
defer func() {
|
||||
if pbr.err != nil {
|
||||
pipeline.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Prepare any needed queries
|
||||
if len(distinctNewQueries) > 0 {
|
||||
for _, sd := range distinctNewQueries {
|
||||
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
for _, sd := range distinctNewQueries {
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
|
||||
}
|
||||
|
||||
// Fill in the previously empty / pending statement descriptions.
|
||||
sd.ParamOIDs = resultSD.ParamOIDs
|
||||
sd.Fields = resultSD.Fields
|
||||
}
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
_, ok := results.(*pgconn.PipelineSync)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
|
||||
}
|
||||
}
|
||||
|
||||
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
|
||||
if sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
}
|
||||
|
||||
// Queue the queries.
|
||||
for _, bi := range b.items {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
if bi.sd.Name == "" {
|
||||
pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
return &pipelineBatchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
pipeline: pipeline,
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
||||
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
||||
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
@@ -1015,3 +1177,37 @@ order by attnum`,
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
||||
if c.descriptionCache != nil {
|
||||
c.descriptionCache.HandleInvalidated()
|
||||
}
|
||||
|
||||
var invalidatedStatements []*pgconn.StatementDescription
|
||||
if c.statementCache != nil {
|
||||
invalidatedStatements = c.statementCache.HandleInvalidated()
|
||||
}
|
||||
|
||||
if len(invalidatedStatements) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pipeline := c.pgConn.StartPipeline(ctx)
|
||||
defer pipeline.Close()
|
||||
|
||||
for _, sd := range invalidatedStatements {
|
||||
pipeline.SendDeallocate(sd.Name)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||
}
|
||||
|
||||
err = pipeline.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user