From 1baf0ef57ec8643d0417d5b2b909ba17c214d125 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 May 2019 18:05:06 -0500 Subject: [PATCH] Refactor context handling into ctxwatch package --- benchmark_test.go | 16 +++ chan_to_set_deadline.go | 51 -------- go.mod | 1 - go.sum | 1 + helper_test.go | 4 +- internal/ctxwatch/context_watcher.go | 64 ++++++++++ internal/ctxwatch/context_watcher_test.go | 139 ++++++++++++++++++++++ pgconn.go | 65 ++++++---- 8 files changed, 261 insertions(+), 80 deletions(-) delete mode 100644 chan_to_set_deadline.go create mode 100644 internal/ctxwatch/context_watcher.go create mode 100644 internal/ctxwatch/context_watcher_test.go diff --git a/benchmark_test.go b/benchmark_test.go index 000dfd1b..073281aa 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -206,3 +206,19 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { } } } + +// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +// require.Nil(b, err) +// defer closeConn(b, conn) + +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() + +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// conn.ChanToSetDeadline().Watch(ctx) +// conn.ChanToSetDeadline().Ignore() +// } +// } diff --git a/chan_to_set_deadline.go b/chan_to_set_deadline.go deleted file mode 100644 index 04bb8fde..00000000 --- a/chan_to_set_deadline.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgconn - -import ( - "time" -) - -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - -type setDeadliner interface { - SetDeadline(time.Time) error -} - -type chanToSetDeadline struct { - cleanupChan chan struct{} - conn setDeadliner - deadlineWasSet bool - cleanupComplete bool -} - -func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) { - if this.cleanupChan == nil { - this.cleanupChan = make(chan struct{}) - } - this.conn = conn - this.deadlineWasSet = false - this.cleanupComplete = false - - if doneChan != nil { - go func() { - select { - case <-doneChan: - conn.SetDeadline(deadlineTime) - this.deadlineWasSet = true - <-this.cleanupChan - case <-this.cleanupChan: - } - }() - } else { - this.cleanupComplete = true - } -} - -func (this *chanToSetDeadline) cleanup() { - if !this.cleanupComplete { - this.cleanupChan <- struct{}{} - if this.deadlineWasSet { - this.conn.SetDeadline(time.Time{}) - } - this.cleanupComplete = true - } -} diff --git a/go.mod b/go.mod index acbee593..4ad3564a 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db - github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 9160f187..9e2398cb 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/helper_test.go b/helper_test.go index 5d44f3b8..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -12,9 +12,9 @@ import ( ) func closeConn(t testing.TB, conn *pgconn.PgConn) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - require.Nil(t, conn.Close(ctx)) + require.NoError(t, conn.Close(ctx)) } // Do a simple query to ensure the connection is still usable diff --git a/internal/ctxwatch/context_watcher.go b/internal/ctxwatch/context_watcher.go new file mode 100644 index 00000000..391f0b79 --- /dev/null +++ b/internal/ctxwatch/context_watcher.go @@ -0,0 +1,64 @@ +package ctxwatch + +import ( + "context" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + onCancel func() + onUnwatchAfterCancel func() + unwatchChan chan struct{} + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { + cw := &ContextWatcher{ + onCancel: onCancel, + onUnwatchAfterCancel: onUnwatchAfterCancel, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.onCancel() + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.onUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go new file mode 100644 index 00000000..0b491bf8 --- /dev/null +++ b/internal/ctxwatch/context_watcher_test.go @@ -0,0 +1,139 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/stretchr/testify/require" +) + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(func() { + canceledChan <- struct{}{} + }, func() { + cleanupCalled = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() { + t.Error("cancel func should not have been called") + }, func() { + t.Error("cleanup func should not have been called") + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(func() { + atomic.AddInt64(&cancelFuncCalls, 1) + }, func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. + if i%3 == 0 { + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn.go b/pgconn.go index a4402a7d..aad5fafd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -13,7 +13,9 @@ import ( "strconv" "strings" "sync" + "time" + "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -21,6 +23,7 @@ import ( const ( connStatusUninitialized = iota + connStatusConnecting connStatusClosed connStatusIdle connStatusBusy @@ -71,10 +74,10 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader - multiResultReader MultiResultReader - doneChanToDeadline chanToSetDeadline + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + contextWatcher *ctxwatch.ContextWatcher } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -149,6 +152,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } + pgConn.status = connStatusConnecting + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) if err != nil { return nil, err @@ -355,8 +364,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { @@ -377,6 +386,7 @@ func (pgConn *PgConn) hardClose() error { return nil } pgConn.status = connStatusClosed + return pgConn.conn.Close() } @@ -453,8 +463,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -543,9 +553,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - var doneChanToDeadline chanToSetDeadline - doneChanToDeadline.start(ctx.Done(), cancelConn) - defer doneChanToDeadline.cleanup() + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -579,8 +592,8 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() for { msg, err := pgConn.ReceiveMessage() @@ -622,7 +635,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -630,7 +643,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Unwatch() multiResult.closed = true if n == 0 { err = linkErrors(err, ErrNoBytesSent) @@ -732,7 +745,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) return result } @@ -749,7 +762,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result err = linkErrors(err, ErrNoBytesSent) } result.concludeCommand(nil, linkErrors(ctx.Err(), err)) - pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() } @@ -767,8 +780,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -828,8 +841,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -962,7 +975,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mrr.pgConn.doneChanToDeadline.cleanup() + mrr.pgConn.contextWatcher.Unwatch() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() @@ -971,7 +984,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.pgConn.doneChanToDeadline.cleanup() + mrr.pgConn.contextWatcher.Unwatch() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: @@ -1129,7 +1142,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: - rr.pgConn.doneChanToDeadline.cleanup() + rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() return rr.commandTag, rr.err } @@ -1148,7 +1161,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { rr.concludeCommand(nil, err) - rr.pgConn.doneChanToDeadline.cleanup() + rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() @@ -1223,7 +1236,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)