From 89416dd80542cc62f45af214ca0722c32e6624ca Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:09:50 +0200 Subject: [PATCH] Enable passing nil context --- .gitignore | 3 +- doc.go | 3 + pgconn.go | 187 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 116 insertions(+), 77 deletions(-) diff --git a/.gitignore b/.gitignore index 6eb9d442..e980f555 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .envrc -vendor/ \ No newline at end of file +vendor/ +.vscode diff --git a/doc.go b/doc.go index cde58cd8..12ed6630 100644 --- a/doc.go +++ b/doc.go @@ -23,6 +23,9 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. +A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional +slight performance increase, if you don't need the operation to be cancellable. + The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. */ diff --git a/pgconn.go b/pgconn.go index 4c75d367..3b90b802 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + if ctx == nil { + ctx = context.Background() + } + // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -362,13 +366,15 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() n, err := pgConn.conn.Write(buf) if err != nil { @@ -392,13 +398,15 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() if err != nil { @@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + if ctx != nil { + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } // Ignore any errors sending Terminate message and waiting for server to close connection. // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully @@ -586,13 +596,15 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + _ctx := ctx + if _ctx == nil { + _ctx = context.Background() + } + cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) - contextWatcher.Watch(ctx) - defer contextWatcher.Unwatch() + if ctx != nil { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -712,14 +730,16 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } for { msg, err := pgConn.receiveMessage() @@ -752,16 +772,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader + if ctx == nil { + pgConn.resultReader.ctx = context.Background() + } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -859,20 +885,22 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - select { - case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) - result.closed = true - pgConn.unlock() - return result - default: + if ctx != nil { + select { + case <-ctx.Done(): + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) return result } -func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -893,14 +921,16 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - select { - case <-ctx.Done(): - pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -952,13 +982,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -1344,16 +1376,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)