From 25099e6f897c44d798c6d4a7ea1d2dc351417977 Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Mon, 25 May 2020 01:37:48 -0400 Subject: [PATCH 1/3] Permit SendBatch with Simple Protocol This commit adds support for sending batches of queries via the Simple protocol with SendBatch. The result appears identically to how it would if it were created with the extended protocol. --- conn.go | 23 +++++++++++++++++++++++ conn_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/conn.go b/conn.go index b3bfd06d..0c4cfbe2 100644 --- a/conn.go +++ b/conn.go @@ -708,6 +708,29 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + simpleProtocol := c.config.PreferSimpleProtocol + var sb strings.Builder + if simpleProtocol { + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } + } + distinctUnpreparedQueries := map[string]struct{}{} for _, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 72022b21..e10ff978 100644 --- a/conn_test.go +++ b/conn_test.go @@ -360,6 +360,48 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { } +func TestSendBatchSimpleProtocol(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.PreferSimpleProtocol = true + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + + var batch pgx.Batch + batch.Queue("SELECT 1::int") + batch.Queue("SELECT 2::int; SELECT $1::int", 3) + results := conn.SendBatch(ctx, &batch) + rows, err := results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err := rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(1), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(2), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(3), values[0]) + assert.False(t, rows.Next()) +} + func TestPrepare(t *testing.T) { t.Parallel() From 8bad186207c1b0863d5d8e19d5525e958f517f8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 25 May 2020 11:35:20 -0500 Subject: [PATCH 2/3] Avoid race between close conn and cancel ctx --- conn_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/conn_test.go b/conn_test.go index e10ff978..da0520df 100644 --- a/conn_test.go +++ b/conn_test.go @@ -366,12 +366,11 @@ func TestSendBatchSimpleProtocol(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.PreferSimpleProtocol = true - conn := mustConnect(t, config) - defer closeConn(t, conn) - ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() + conn := mustConnect(t, config) + defer closeConn(t, conn) var batch pgx.Batch batch.Queue("SELECT 1::int") From 72bba7fb4278bf4126f6abf66952815edb1476f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 25 May 2020 11:36:18 -0500 Subject: [PATCH 3/3] Move batch simple protocol test to batch_test.go --- batch_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ conn_test.go | 41 ----------------------------------------- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/batch_test.go b/batch_test.go index f487c52a..113ce3cf 100644 --- a/batch_test.go +++ b/batch_test.go @@ -725,3 +725,44 @@ func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"]) } } + +func TestSendBatchSimpleProtocol(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.PreferSimpleProtocol = true + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + var batch pgx.Batch + batch.Queue("SELECT 1::int") + batch.Queue("SELECT 2::int; SELECT $1::int", 3) + results := conn.SendBatch(ctx, &batch) + rows, err := results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err := rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(1), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(2), values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.Equal(t, int32(3), values[0]) + assert.False(t, rows.Next()) +} diff --git a/conn_test.go b/conn_test.go index da0520df..72022b21 100644 --- a/conn_test.go +++ b/conn_test.go @@ -360,47 +360,6 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { } -func TestSendBatchSimpleProtocol(t *testing.T) { - t.Parallel() - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.PreferSimpleProtocol = true - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - var batch pgx.Batch - batch.Queue("SELECT 1::int") - batch.Queue("SELECT 2::int; SELECT $1::int", 3) - results := conn.SendBatch(ctx, &batch) - rows, err := results.Query() - assert.NoError(t, err) - assert.True(t, rows.Next()) - values, err := rows.Values() - assert.NoError(t, err) - assert.Equal(t, int32(1), values[0]) - assert.False(t, rows.Next()) - - rows, err = results.Query() - assert.NoError(t, err) - assert.True(t, rows.Next()) - values, err = rows.Values() - assert.NoError(t, err) - assert.Equal(t, int32(2), values[0]) - assert.False(t, rows.Next()) - - rows, err = results.Query() - assert.NoError(t, err) - assert.True(t, rows.Next()) - values, err = rows.Values() - assert.NoError(t, err) - assert.Equal(t, int32(3), values[0]) - assert.False(t, rows.Next()) -} - func TestPrepare(t *testing.T) { t.Parallel()