2
0

SendBatch supports default QueryExecMode

This commit is contained in:
Jack Christensen
2022-03-12 15:06:13 -06:00
parent 1390a11fe2
commit cb721dfb5b
3 changed files with 500 additions and 469 deletions
+90 -53
View File
@@ -861,9 +861,10 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
var sb strings.Builder
if simpleProtocol {
mode := c.config.DefaultQueryExecMode
if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder
for i, bi := range b.items {
if i > 0 {
sb.WriteByte(';')
@@ -884,66 +885,102 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
}
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
batch := &pgconn.Batch{}
for _, bi := range b.items {
c.eqb.Reset()
if mode == QueryExecModeExec {
for _, bi := range b.items {
c.eqb.Reset()
anynil.NormalizeSlice(bi.arguments)
sd := c.preparedStatements[bi.query]
if sd == nil {
var err error
sd, err = stmtCache.Get(ctx, bi.query)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
sd := c.preparedStatements[bi.query]
if sd != nil {
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
}
for i := range bi.arguments {
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
} else {
err := c.appendParamsForQueryExecModeExec(bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats)
}
}
} else {
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.descriptionCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
}
for _, bi := range b.items {
c.eqb.Reset()
anynil.NormalizeSlice(bi.arguments)
for i := range bi.arguments {
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
sd := c.preparedStatements[bi.query]
if sd == nil {
var err error
sd, err = stmtCache.Get(ctx, bi.query)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
}
if sd.Name == "" {
batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
} else {
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
anynil.NormalizeSlice(bi.arguments)
for i := range bi.arguments {
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
if sd.Name == "" {
batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
} else {
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
}
}
}