2
0

SendBatch now uses pipeline mode to prepare and describe statements

Previously, a batch with 10 unique parameterized statements executed
100 times would entail 11 network round trips: 1 for each prepare /
describe and 1 for executing them all. Now pipeline mode is used to
prepare / describe all statements in a single network round trip, so
the same batch only takes 2 round trips.
This commit is contained in:
Jack Christensen
2022-07-09 09:28:11 -05:00
parent ba58e3d5d2
commit e7aa76ccf9
12 changed files with 694 additions and 612 deletions
+300 -104
View File
@@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c.wbuf = make([]byte, 0, 1024)
if c.config.StatementCacheCapacity > 0 {
c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity)
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
}
if c.config.DescriptionCacheCapacity > 0 {
c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity)
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
}
return c, nil
@@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return pgconn.CommandTag{}, err
}
startTime := time.Now()
commandTag, err := c.exec(ctx, sql, arguments...)
@@ -437,9 +441,13 @@ optionLoop:
if c.statementCache == nil {
return pgconn.CommandTag{}, errDisabledStatementCache
}
sd, err := c.statementCache.Get(ctx, sql)
if err != nil {
return pgconn.CommandTag{}, err
sd := c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
return pgconn.CommandTag{}, err
}
c.statementCache.Put(sd)
}
return c.execPrepared(ctx, sd, arguments)
@@ -447,9 +455,12 @@ optionLoop:
if c.descriptionCache == nil {
return pgconn.CommandTag{}, errDisabledDescriptionCache
}
sd, err := c.descriptionCache.Get(ctx, sql)
if err != nil {
return pgconn.CommandTag{}, err
sd := c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
return pgconn.CommandTag{}, err
}
}
return c.execParams(ctx, sd, arguments)
@@ -620,6 +631,10 @@ type QueryRewriter interface {
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &baseRows{err: err, closed: true}, err
}
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
mode := c.config.DefaultQueryExecMode
@@ -649,6 +664,11 @@ optionLoop:
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
}
// Bypass any statement caching.
if sql == "" {
mode = QueryExecModeSimpleProtocol
}
c.eqb.reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
@@ -664,10 +684,14 @@ optionLoop:
rows.fatal(err)
return rows, err
}
sd, err = c.statementCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
@@ -675,10 +699,14 @@ optionLoop:
rows.fatal(err)
return rows, err
}
sd, err = c.descriptionCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
@@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
mode := c.config.DefaultQueryExecMode
for _, bi := range b.items {
@@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder
for i, bi := range b.items {
if i > 0 {
sb.WriteByte(';')
}
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
sb.WriteString(sql)
}
mrr := c.pgConn.Exec(ctx, sb.String())
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
}
// All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.items {
if sd, ok := c.preparedStatements[bi.query]; ok {
bi.sd = sd
}
}
switch mode {
case QueryExecModeExec:
return c.sendBatchQueryExecModeExec(ctx, b)
case QueryExecModeCacheStatement:
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
case QueryExecModeCacheDescribe:
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
case QueryExecModeDescribeExec:
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
default:
panic("unknown QueryExecMode")
}
}
// sendBatchQueryExecModeSimpleProtocol sends every queued query in b as one simple protocol request.
//
// Each query is passed through sanitizeForSimpleQuery together with its arguments, and the resulting SQL strings are
// joined with ';' and handed to pgConn.Exec in a single call. The first sanitization error stops processing and is
// reported through the err field of the returned batchResults; nothing is sent to the server in that case.
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
	statements := make([]string, 0, len(b.items))
	for _, item := range b.items {
		sanitized, err := c.sanitizeForSimpleQuery(item.query, item.arguments...)
		if err != nil {
			return &batchResults{ctx: ctx, conn: c, err: err}
		}
		statements = append(statements, sanitized)
	}

	return &batchResults{
		ctx:  ctx,
		conn: c,
		mrr:  c.pgConn.Exec(ctx, strings.Join(statements, ";")),
		b:    b,
		ix:   0,
	}
}
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{}
if mode == QueryExecModeExec {
for _, bi := range b.items {
c.eqb.reset()
anynil.NormalizeSlice(bi.arguments)
sd := c.preparedStatements[bi.query]
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
} else {
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.descriptionCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
for _, bi := range b.items {
c.eqb.reset()
sd := c.preparedStatements[bi.query]
if sd == nil {
var err error
sd, err = stmtCache.Get(ctx, bi.query)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
}
for _, bi := range b.items {
sd := bi.sd
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
if sd.Name == "" {
batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
@@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
}
// sendBatchQueryExecModeCacheStatement resolves a statement description for every batch item that does not already
// have one and then delegates to sendBatchExtendedWithDescription with the statement cache.
//
// Items are resolved in this order: an sd already set on the item, a hit in c.statementCache, a pending description
// created earlier in this same call for identical SQL, and finally a new pending description with a freshly generated
// statement name. Each distinct uncached SQL string therefore yields exactly one pending description.
//
// If the statement cache is disabled the returned results carry errDisabledStatementCache.
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	if c.statementCache == nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
	}

	pending := []*pgconn.StatementDescription{}
	pendingIndexBySQL := map[string]int{}

	for _, item := range b.items {
		if item.sd != nil {
			continue
		}

		if cached := c.statementCache.Get(item.query); cached != nil {
			item.sd = cached
			continue
		}

		// Reuse the pending description when the same SQL appeared earlier in the batch.
		if pos, seen := pendingIndexBySQL[item.query]; seen {
			item.sd = pending[pos]
			continue
		}

		fresh := &pgconn.StatementDescription{
			Name: stmtcache.NextStatementName(),
			SQL:  item.query,
		}
		pendingIndexBySQL[fresh.SQL] = len(pending)
		pending = append(pending, fresh)
		item.sd = fresh
	}

	return c.sendBatchExtendedWithDescription(ctx, b, pending, c.statementCache)
}
// sendBatchQueryExecModeCacheDescribe resolves a statement description for every batch item that does not already
// have one and then delegates to sendBatchExtendedWithDescription with the description cache.
//
// Items are resolved in this order: an sd already set on the item, a hit in c.descriptionCache, a pending description
// created earlier in this same call for identical SQL, and finally a new pending description. New descriptions are
// left unnamed (only SQL is set).
//
// If the description cache is disabled the returned results carry errDisabledDescriptionCache.
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	if c.descriptionCache == nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
	}

	pending := []*pgconn.StatementDescription{}
	pendingIndexBySQL := map[string]int{}

	for _, item := range b.items {
		if item.sd != nil {
			continue
		}

		if cached := c.descriptionCache.Get(item.query); cached != nil {
			item.sd = cached
			continue
		}

		// Reuse the pending description when the same SQL appeared earlier in the batch.
		if pos, seen := pendingIndexBySQL[item.query]; seen {
			item.sd = pending[pos]
			continue
		}

		fresh := &pgconn.StatementDescription{
			SQL: item.query,
		}
		pendingIndexBySQL[fresh.SQL] = len(pending)
		pending = append(pending, fresh)
		item.sd = fresh
	}

	return c.sendBatchExtendedWithDescription(ctx, b, pending, c.descriptionCache)
}
// sendBatchQueryExecModeDescribeExec gives every batch item that lacks a statement description a pending, unnamed
// one and then delegates to sendBatchExtendedWithDescription without a cache (nil), so nothing is stored.
//
// Items that share the same SQL share a single pending description, so each distinct SQL string appears only once in
// the list handed on.
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	pending := []*pgconn.StatementDescription{}
	pendingIndexBySQL := map[string]int{}

	for _, item := range b.items {
		if item.sd != nil {
			continue
		}

		// Reuse the pending description when the same SQL appeared earlier in the batch.
		if pos, seen := pendingIndexBySQL[item.query]; seen {
			item.sd = pending[pos]
			continue
		}

		fresh := &pgconn.StatementDescription{
			SQL: item.query,
		}
		pendingIndexBySQL[fresh.SQL] = len(pending)
		pending = append(pending, fresh)
		item.sd = fresh
	}

	return c.sendBatchExtendedWithDescription(ctx, b, pending, nil)
}
// sendBatchExtendedWithDescription sends b through a pgconn pipeline using the extended protocol.
//
// The work happens in two phases:
//
//  1. Every entry of distinctNewQueries is prepared and the pipeline is synced. The ParamOIDs and Fields of each entry
//     are then filled in from the server's reply. Batch items hold pointers to these same descriptions, so they see
//     the filled-in values.
//  2. Every batch item is encoded with its (now complete) statement description and queued, followed by a second
//     sync. Items whose description has an empty Name are sent with SendQueryParams; named ones are sent with
//     SendQueryPrepared.
//
// Every item in b must already have a non-nil sd. sdCache may be nil; when it is not, the newly described statements
// are stored in it.
//
// On success the returned results own the open pipeline. If any step fails, the error is carried in the err field of
// the returned results and the pipeline is closed before returning.
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
	// Start the pipeline with the caller's context rather than context.Background() so this path is tied to the same
	// context as the simple protocol path and deallocateInvalidatedCachedStatements.
	pipeline := c.pgConn.StartPipeline(ctx)
	defer func() {
		// pbr is nil if the function is unwinding because of a panic; do not dereference it in that case.
		if pbr != nil && pbr.err != nil {
			pipeline.Close()
		}
	}()

	// Prepare any needed queries
	if len(distinctNewQueries) > 0 {
		for _, sd := range distinctNewQueries {
			pipeline.SendPrepare(sd.Name, sd.SQL, nil)
		}

		err := pipeline.Sync()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
		}

		// One result is read per prepared statement, matched to distinctNewQueries by position.
		for _, sd := range distinctNewQueries {
			results, err := pipeline.GetResults()
			if err != nil {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
			}

			resultSD, ok := results.(*pgconn.StatementDescription)
			if !ok {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
			}

			// Fill in the previously empty / pending statement descriptions.
			sd.ParamOIDs = resultSD.ParamOIDs
			sd.Fields = resultSD.Fields
		}

		// The sync sent after the prepares must be consumed before queries are queued.
		results, err := pipeline.GetResults()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
		}

		_, ok := results.(*pgconn.PipelineSync)
		if !ok {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
		}
	}

	// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
	// The cache supplied by the caller must be used: writing to c.statementCache unconditionally would put unnamed
	// descriptions into the statement cache in cache-describe mode and would fail when the statement cache is disabled.
	if sdCache != nil {
		for _, sd := range distinctNewQueries {
			sdCache.Put(sd)
		}
	}

	// Queue the queries.
	for _, bi := range b.items {
		err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
		}

		if bi.sd.Name == "" {
			pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
		} else {
			pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
		}
	}

	err := pipeline.Sync()
	if err != nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
	}

	return &pipelineBatchResults{
		ctx:      ctx,
		conn:     c,
		pipeline: pipeline,
		b:        b,
	}
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
@@ -1015,3 +1177,37 @@ order by attnum`,
return fields, nil
}
// deallocateInvalidatedCachedStatements lets both caches process their invalidated entries and deallocates, on the
// server, every statement the statement cache reports as invalidated.
//
// The description cache's HandleInvalidated result is discarded; only statements returned by the statement cache are
// deallocated. When there is nothing to deallocate no pipeline is started and nil is returned. Otherwise all
// deallocations are sent through one pipeline, and any error from syncing or closing it is returned wrapped.
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
	if c.descriptionCache != nil {
		c.descriptionCache.HandleInvalidated()
	}

	if c.statementCache == nil {
		return nil
	}

	stale := c.statementCache.HandleInvalidated()
	if len(stale) == 0 {
		return nil
	}

	pipeline := c.pgConn.StartPipeline(ctx)
	// Ensures the pipeline is closed on the early error return below; the explicit Close further down exists to
	// surface its error on the normal path.
	defer pipeline.Close()

	for i := range stale {
		pipeline.SendDeallocate(stale[i].Name)
	}

	if err := pipeline.Sync(); err != nil {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}

	if err := pipeline.Close(); err != nil {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}

	return nil
}