2
0

0 alloc context to deadline

This commit is contained in:
Jack Christensen
2019-04-19 14:24:51 -05:00
parent 2383561e4d
commit 16412e56e2
2 changed files with 95 additions and 83 deletions
+51
View File
@@ -0,0 +1,51 @@
package pgconn
import (
"time"
)
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
type setDeadliner interface {
SetDeadline(time.Time) error
}
type chanToSetDeadline struct {
cleanupChan chan struct{}
conn setDeadliner
deadlineWasSet bool
cleanupComplete bool
}
func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) {
if this.cleanupChan == nil {
this.cleanupChan = make(chan struct{})
}
this.conn = conn
this.deadlineWasSet = false
this.cleanupComplete = false
if doneChan != nil {
go func() {
select {
case <-doneChan:
conn.SetDeadline(deadlineTime)
this.deadlineWasSet = true
<-this.cleanupChan
case <-this.cleanupChan:
}
}()
} else {
this.cleanupComplete = true
}
}
func (this *chanToSetDeadline) cleanup() {
if !this.cleanupComplete {
this.cleanupChan <- struct{}{}
if this.deadlineWasSet {
this.conn.SetDeadline(time.Time{})
}
this.cleanupComplete = true
}
}
+44 -83
View File
@@ -15,14 +15,11 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/jackc/pgio" "github.com/jackc/pgio"
"github.com/jackc/pgproto3/v2" "github.com/jackc/pgproto3/v2"
) )
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description. // detailed field description.
@@ -100,9 +97,10 @@ type PgConn struct {
bufferingReceiveErr error bufferingReceiveErr error
// Reusable / preallocated resources // Reusable / preallocated resources
wbuf []byte // write buffer wbuf []byte // write buffer
resultReader ResultReader resultReader ResultReader
multiResultReader MultiResultReader multiResultReader MultiResultReader
doneChanToDeadline chanToSetDeadline
} }
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
@@ -382,8 +380,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close() defer pgConn.conn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer cleanupContext() defer pgConn.doneChanToDeadline.cleanup()
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil { if err != nil {
@@ -463,38 +461,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
return err return err
} }
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to
// call multiple times.
func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) {
if ctx.Done() != nil {
deadlineWasSet := false
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.SetDeadline(deadlineTime)
deadlineWasSet = true
<-doneChan
case <-doneChan:
}
}()
finished := false
return func() {
if !finished {
doneChan <- struct{}{}
if deadlineWasSet {
conn.SetDeadline(time.Time{})
}
finished = true
}
}
}
return func() {}
}
type PreparedStatementDescription struct { type PreparedStatementDescription struct {
Name string Name string
SQL string SQL string
@@ -512,8 +478,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
return nil, ctx.Err() return nil, ctx.Err()
default: default:
} }
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer cleanupContextDeadline() defer pgConn.doneChanToDeadline.cleanup()
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@@ -599,8 +565,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
} }
defer cancelConn.Close() defer cancelConn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) var doneChanToDeadline chanToSetDeadline
defer cleanupContext() doneChanToDeadline.start(ctx.Done(), cancelConn)
defer doneChanToDeadline.cleanup()
buf := make([]byte, 16) buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[0:4], 16)
@@ -624,16 +591,16 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// received. // received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.lock() pgConn.lock()
defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock()
return ctx.Err() return ctx.Err()
default: default:
} }
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline() pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.unlock() defer pgConn.doneChanToDeadline.cleanup()
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
@@ -657,9 +624,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.lock() pgConn.lock()
pgConn.multiResultReader = MultiResultReader{ pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn, pgConn: pgConn,
ctx: ctx, ctx: ctx,
cleanupContextDeadline: func() {},
} }
multiResult := &pgConn.multiResultReader multiResult := &pgConn.multiResultReader
@@ -671,7 +637,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
return multiResult return multiResult
default: default:
} }
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf) buf = (&pgproto3.Query{String: sql}).Encode(buf)
@@ -679,7 +645,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
_, err := pgConn.conn.Write(buf) _, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
multiResult.cleanupContextDeadline() pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err) multiResult.err = preferContextOverNetTimeoutError(ctx, err)
pgConn.unlock() pgConn.unlock()
@@ -753,9 +719,8 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
pgConn.lock() pgConn.lock()
pgConn.resultReader = ResultReader{ pgConn.resultReader = ResultReader{
pgConn: pgConn, pgConn: pgConn,
ctx: ctx, ctx: ctx,
cleanupContextDeadline: func() {},
} }
result := &pgConn.resultReader result := &pgConn.resultReader
@@ -774,7 +739,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result return result
default: default:
} }
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
return result return result
} }
@@ -788,7 +753,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
result.concludeCommand(nil, err) result.concludeCommand(nil, err)
result.cleanupContextDeadline() pgConn.doneChanToDeadline.cleanup()
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
} }
@@ -804,8 +769,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, ctx.Err() return nil, ctx.Err()
default: default:
} }
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer cleanupContextDeadline() defer pgConn.doneChanToDeadline.cleanup()
// Send copy to command // Send copy to command
buf := pgConn.wbuf buf := pgConn.wbuf
@@ -861,8 +826,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
return nil, ctx.Err() return nil, ctx.Err()
default: default:
} }
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer cleanupContextDeadline() defer pgConn.doneChanToDeadline.cleanup()
// Send copy to command // Send copy to command
buf := pgConn.wbuf buf := pgConn.wbuf
@@ -967,9 +932,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct { type MultiResultReader struct {
pgConn *PgConn pgConn *PgConn
ctx context.Context ctx context.Context
cleanupContextDeadline func()
rr *ResultReader rr *ResultReader
@@ -993,7 +957,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
msg, err := mrr.pgConn.ReceiveMessage() msg, err := mrr.pgConn.ReceiveMessage()
if err != nil { if err != nil {
mrr.cleanupContextDeadline() mrr.pgConn.doneChanToDeadline.cleanup()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.closed = true mrr.closed = true
mrr.pgConn.hardClose() mrr.pgConn.hardClose()
@@ -1002,7 +966,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
mrr.cleanupContextDeadline() mrr.pgConn.doneChanToDeadline.cleanup()
mrr.closed = true mrr.closed = true
mrr.pgConn.unlock() mrr.pgConn.unlock()
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
@@ -1023,11 +987,10 @@ func (mrr *MultiResultReader) NextResult() bool {
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.RowDescription: case *pgproto3.RowDescription:
mrr.pgConn.resultReader = ResultReader{ mrr.pgConn.resultReader = ResultReader{
pgConn: mrr.pgConn, pgConn: mrr.pgConn,
multiResultReader: mrr, multiResultReader: mrr,
ctx: mrr.ctx, ctx: mrr.ctx,
cleanupContextDeadline: func() {}, fieldDescriptions: msg.Fields,
fieldDescriptions: msg.Fields,
} }
mrr.rr = &mrr.pgConn.resultReader mrr.rr = &mrr.pgConn.resultReader
return true return true
@@ -1066,10 +1029,9 @@ func (mrr *MultiResultReader) Close() error {
// ResultReader is a reader for the result of a single query. // ResultReader is a reader for the result of a single query.
type ResultReader struct { type ResultReader struct {
pgConn *PgConn pgConn *PgConn
multiResultReader *MultiResultReader multiResultReader *MultiResultReader
ctx context.Context ctx context.Context
cleanupContextDeadline func()
fieldDescriptions []pgproto3.FieldDescription fieldDescriptions []pgproto3.FieldDescription
rowValues [][]byte rowValues [][]byte
@@ -1162,7 +1124,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
switch msg.(type) { switch msg.(type) {
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
rr.cleanupContextDeadline() rr.pgConn.doneChanToDeadline.cleanup()
rr.pgConn.unlock() rr.pgConn.unlock()
return rr.commandTag, rr.err return rr.commandTag, rr.err
} }
@@ -1181,7 +1143,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
if err != nil { if err != nil {
rr.concludeCommand(nil, err) rr.concludeCommand(nil, err)
rr.cleanupContextDeadline() rr.pgConn.doneChanToDeadline.cleanup()
rr.closed = true rr.closed = true
if rr.multiResultReader == nil { if rr.multiResultReader == nil {
rr.pgConn.hardClose() rr.pgConn.hardClose()
@@ -1238,9 +1200,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.lock() pgConn.lock()
pgConn.multiResultReader = MultiResultReader{ pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn, pgConn: pgConn,
ctx: ctx, ctx: ctx,
cleanupContextDeadline: func() {},
} }
multiResult := &pgConn.multiResultReader multiResult := &pgConn.multiResultReader
@@ -1252,13 +1213,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
return multiResult return multiResult
default: default:
} }
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
_, err := pgConn.conn.Write(batch.buf) _, err := pgConn.conn.Write(batch.buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
multiResult.cleanupContextDeadline() pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err) multiResult.err = preferContextOverNetTimeoutError(ctx, err)
pgConn.unlock() pgConn.unlock()