diff --git a/chan_to_set_deadline.go b/chan_to_set_deadline.go new file mode 100644 index 00000000..04bb8fde --- /dev/null +++ b/chan_to_set_deadline.go @@ -0,0 +1,51 @@ +package pgconn + +import ( + "time" +) + +var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + +type setDeadliner interface { + SetDeadline(time.Time) error +} + +type chanToSetDeadline struct { + cleanupChan chan struct{} + conn setDeadliner + deadlineWasSet bool + cleanupComplete bool +} + +func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) { + if this.cleanupChan == nil { + this.cleanupChan = make(chan struct{}) + } + this.conn = conn + this.deadlineWasSet = false + this.cleanupComplete = false + + if doneChan != nil { + go func() { + select { + case <-doneChan: + conn.SetDeadline(deadlineTime) + this.deadlineWasSet = true + <-this.cleanupChan + case <-this.cleanupChan: + } + }() + } else { + this.cleanupComplete = true + } +} + +func (this *chanToSetDeadline) cleanup() { + if !this.cleanupComplete { + this.cleanupChan <- struct{}{} + if this.deadlineWasSet { + this.conn.SetDeadline(time.Time{}) + } + this.cleanupComplete = true + } +} diff --git a/pgconn.go b/pgconn.go index 7bc93435..6ff0d39f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,14 +15,11 @@ import ( "strconv" "strings" "sync" - "time" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" ) -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -100,9 +97,10 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader - multiResultReader MultiResultReader + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + doneChanToDeadline chanToSetDeadline } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -382,8 +380,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { @@ -463,38 +461,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - type PreparedStatementDescription struct { Name string SQL string @@ -512,8 +478,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -599,8 +565,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) - defer cleanupContext() + var doneChanToDeadline chanToSetDeadline + doneChanToDeadline.start(ctx.Done(), cancelConn) + defer doneChanToDeadline.cleanup() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -624,16 +591,16 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { pgConn.lock() + defer pgConn.unlock() select { case <-ctx.Done(): - pgConn.unlock() return ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() - defer pgConn.unlock() + + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() for { msg, err := pgConn.ReceiveMessage() @@ -657,9 +624,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.lock() pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } multiResult := &pgConn.multiResultReader @@ -671,7 +637,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult default: } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -679,7 +645,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - multiResult.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) pgConn.unlock() @@ -753,9 +719,8 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by pgConn.lock() pgConn.resultReader = ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } result := &pgConn.resultReader @@ -774,7 +739,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result default: } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) return result } @@ -788,7 +753,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { if err != nil { pgConn.hardClose() result.concludeCommand(nil, err) - result.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() result.closed = true pgConn.unlock() } @@ -804,8 +769,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() // Send copy to command buf := pgConn.wbuf @@ -861,8 +826,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() // Send copy to command buf := pgConn.wbuf @@ -967,9 +932,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - cleanupContextDeadline func() + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -993,7 +957,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mrr.cleanupContextDeadline() + mrr.pgConn.doneChanToDeadline.cleanup() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() @@ -1002,7 +966,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.cleanupContextDeadline() + mrr.pgConn.doneChanToDeadline.cleanup() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: @@ -1023,11 +987,10 @@ func (mrr *MultiResultReader) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: mrr.pgConn.resultReader = ResultReader{ - pgConn: mrr.pgConn, - multiResultReader: mrr, - ctx: mrr.ctx, - cleanupContextDeadline: func() {}, - fieldDescriptions: msg.Fields, + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: msg.Fields, } mrr.rr = &mrr.pgConn.resultReader return true @@ -1066,10 +1029,9 @@ func (mrr *MultiResultReader) Close() error { // ResultReader is a reader for the result of a single query. type ResultReader struct { - pgConn *PgConn - multiResultReader *MultiResultReader - ctx context.Context - cleanupContextDeadline func() + pgConn *PgConn + multiResultReader *MultiResultReader + ctx context.Context fieldDescriptions []pgproto3.FieldDescription rowValues [][]byte @@ -1162,7 +1124,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: - rr.cleanupContextDeadline() + rr.pgConn.doneChanToDeadline.cleanup() rr.pgConn.unlock() return rr.commandTag, rr.err } @@ -1181,7 +1143,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { rr.concludeCommand(nil, err) - rr.cleanupContextDeadline() + rr.pgConn.doneChanToDeadline.cleanup() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() @@ -1238,9 +1200,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.lock() pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } multiResult := &pgConn.multiResultReader @@ -1252,13 +1213,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult default: } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) _, err := pgConn.conn.Write(batch.buf) if err != nil { pgConn.hardClose() - multiResult.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) pgConn.unlock()