From dc699cefc7b3429e77825867f6c0d729970da2da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 11:38:23 -0500 Subject: [PATCH] Conn.CopyFrom takes context --- bench_test.go | 3 ++- copy_from.go | 8 ++++---- copy_from_test.go | 14 +++++++------- tx.go | 4 ++-- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/bench_test.go b/bench_test.go index 1eb492a9..a47d6b51 100644 --- a/bench_test.go +++ b/bench_test.go @@ -492,7 +492,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) - _, err := conn.CopyFrom(pgx.Identifier{"t"}, + _, err := conn.CopyFrom(context.Background(), + pgx.Identifier{"t"}, []string{"varchar_1", "varchar_2", "varchar_null_1", diff --git a/copy_from.go b/copy_from.go index 34a28dff..08c9488b 100644 --- a/copy_from.go +++ b/copy_from.go @@ -57,7 +57,7 @@ type copyFrom struct { readerErrChan chan error } -func (ct *copyFrom) run() (int, error) { +func (ct *copyFrom) run(ctx context.Context) (int, error) { quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { @@ -111,7 +111,7 @@ func (ct *copyFrom) run() (int, error) { w.Close() }() - commandTag, err := ct.conn.pgConn.CopyFrom(context.TODO(), r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) return int(commandTag.RowsAffected()), err } @@ -149,7 +149,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt // CopyFrom requires all values use the binary format. Almost all types // implemented by pgx use the binary format by default. Types implementing // Encoder can only be used if they encode to the binary format. -func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { +func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { ct := ©From{ conn: c, tableName: tableName, @@ -158,5 +158,5 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF readerErrChan: make(chan error), } - return ct.run() + return ct.run(ctx) } diff --git a/copy_from_test.go b/copy_from_test.go index 7f00df7b..b74878a8 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -35,7 +35,7 @@ func TestConnCopyFromSmall(t *testing.T) { {nil, nil, nil, nil, nil, nil, nil}, } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } @@ -93,7 +93,7 @@ func TestConnCopyFromLarge(t *testing.T) { inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } @@ -148,7 +148,7 @@ func TestConnCopyFromJSON(t *testing.T) { {nil, nil}, } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } @@ -220,7 +220,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { {int32(3), "def"}, } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } @@ -291,7 +291,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { startTime := time.Now() - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } @@ -343,7 +343,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { a bytea not null )`) - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } @@ -403,7 +403,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { a bytea not null )`) - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } diff --git a/tx.go b/tx.go index 96afadeb..00a91d5d 100644 --- a/tx.go +++ b/tx.go @@ -189,12 +189,12 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row } // CopyFrom delegates to the underlying *Conn -func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { +func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { if tx.status != TxStatusInProgress { return 0, ErrTxClosed } - return tx.conn.CopyFrom(tableName, columnNames, rowSrc) + return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc) } // Status returns the status of the transaction from the set of