diff --git a/conn.go b/conn.go index b3bfd06d..0c4cfbe2 100644 --- a/conn.go +++ b/conn.go @@ -708,6 +708,29 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + simpleProtocol := c.config.PreferSimpleProtocol + var sb strings.Builder + if simpleProtocol { + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } + } + distinctUnpreparedQueries := map[string]struct{}{} for _, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 72022b21..e10ff978 100644 --- a/conn_test.go +++ b/conn_test.go @@ -360,6 +360,48 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { } +func TestSendBatchSimpleProtocol(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.PreferSimpleProtocol = true + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + + var batch pgx.Batch + batch.Queue("SELECT 1::int") + batch.Queue("SELECT 2::int; SELECT $1::int", 3) + results := conn.SendBatch(ctx, &batch) + rows, err := results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err := rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(1), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(2), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(3), values[0]) + assert.False(t, rows.Next()) +} + func TestPrepare(t *testing.T) { t.Parallel()