From 25099e6f897c44d798c6d4a7ea1d2dc351417977 Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Mon, 25 May 2020 01:37:48 -0400 Subject: [PATCH] Permit SendBatch with Simple Protocol This commit adds support for sending batches of queries via the Simple protocol with SendBatch. The result appears identically to how it would if it were created with the extended protocol. --- conn.go | 23 +++++++++++++++++++++++ conn_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/conn.go b/conn.go index b3bfd06d..0c4cfbe2 100644 --- a/conn.go +++ b/conn.go @@ -708,6 +708,29 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + simpleProtocol := c.config.PreferSimpleProtocol + var sb strings.Builder + if simpleProtocol { + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } + } + distinctUnpreparedQueries := map[string]struct{}{} for _, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 72022b21..e10ff978 100644 --- a/conn_test.go +++ b/conn_test.go @@ -360,6 +360,48 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { } +func TestSendBatchSimpleProtocol(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.PreferSimpleProtocol = true + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + + var batch pgx.Batch + batch.Queue("SELECT 1::int") + batch.Queue("SELECT 2::int; SELECT $1::int", 3) + results := conn.SendBatch(ctx, &batch) + rows, err := results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err := rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(1), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(2), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(3), values[0]) + assert.False(t, rows.Next()) +} + func TestPrepare(t *testing.T) { t.Parallel()