2
0

0 alloc context to deadline

This commit is contained in:
Jack Christensen
2019-04-19 14:24:51 -05:00
parent 2383561e4d
commit 16412e56e2
2 changed files with 95 additions and 83 deletions
+51
View File
@@ -0,0 +1,51 @@
package pgconn
import (
"time"
)
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
type setDeadliner interface {
SetDeadline(time.Time) error
}
type chanToSetDeadline struct {
cleanupChan chan struct{}
conn setDeadliner
deadlineWasSet bool
cleanupComplete bool
}
func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) {
if this.cleanupChan == nil {
this.cleanupChan = make(chan struct{})
}
this.conn = conn
this.deadlineWasSet = false
this.cleanupComplete = false
if doneChan != nil {
go func() {
select {
case <-doneChan:
conn.SetDeadline(deadlineTime)
this.deadlineWasSet = true
<-this.cleanupChan
case <-this.cleanupChan:
}
}()
} else {
this.cleanupComplete = true
}
}
func (this *chanToSetDeadline) cleanup() {
if !this.cleanupComplete {
this.cleanupChan <- struct{}{}
if this.deadlineWasSet {
this.conn.SetDeadline(time.Time{})
}
this.cleanupComplete = true
}
}
+44 -83
View File
@@ -15,14 +15,11 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgio"
"github.com/jackc/pgproto3/v2"
)
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
@@ -100,9 +97,10 @@ type PgConn struct {
bufferingReceiveErr error
// Reusable / preallocated resources
wbuf []byte // write buffer
resultReader ResultReader
multiResultReader MultiResultReader
wbuf []byte // write buffer
resultReader ResultReader
multiResultReader MultiResultReader
doneChanToDeadline chanToSetDeadline
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
@@ -382,8 +380,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContext()
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil {
@@ -463,38 +461,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
return err
}
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to
// call multiple times.
func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) {
if ctx.Done() != nil {
deadlineWasSet := false
doneChan := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.SetDeadline(deadlineTime)
deadlineWasSet = true
<-doneChan
case <-doneChan:
}
}()
finished := false
return func() {
if !finished {
doneChan <- struct{}{}
if deadlineWasSet {
conn.SetDeadline(time.Time{})
}
finished = true
}
}
}
return func() {}
}
type PreparedStatementDescription struct {
Name string
SQL string
@@ -512,8 +478,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
return nil, ctx.Err()
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@@ -599,8 +565,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
}
defer cancelConn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, cancelConn)
defer cleanupContext()
var doneChanToDeadline chanToSetDeadline
doneChanToDeadline.start(ctx.Done(), cancelConn)
defer doneChanToDeadline.cleanup()
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
@@ -624,16 +591,16 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.lock()
defer pgConn.unlock()
select {
case <-ctx.Done():
pgConn.unlock()
return ctx.Err()
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
defer pgConn.unlock()
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
for {
msg, err := pgConn.ReceiveMessage()
@@ -657,9 +624,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.lock()
pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn,
ctx: ctx,
cleanupContextDeadline: func() {},
pgConn: pgConn,
ctx: ctx,
}
multiResult := &pgConn.multiResultReader
@@ -671,7 +637,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
return multiResult
default:
}
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
@@ -679,7 +645,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
multiResult.cleanupContextDeadline()
pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
pgConn.unlock()
@@ -753,9 +719,8 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
pgConn.lock()
pgConn.resultReader = ResultReader{
pgConn: pgConn,
ctx: ctx,
cleanupContextDeadline: func() {},
pgConn: pgConn,
ctx: ctx,
}
result := &pgConn.resultReader
@@ -774,7 +739,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
default:
}
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
return result
}
@@ -788,7 +753,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
if err != nil {
pgConn.hardClose()
result.concludeCommand(nil, err)
result.cleanupContextDeadline()
pgConn.doneChanToDeadline.cleanup()
result.closed = true
pgConn.unlock()
}
@@ -804,8 +769,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, ctx.Err()
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
// Send copy to command
buf := pgConn.wbuf
@@ -861,8 +826,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
return nil, ctx.Err()
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
// Send copy to command
buf := pgConn.wbuf
@@ -967,9 +932,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn
ctx context.Context
cleanupContextDeadline func()
pgConn *PgConn
ctx context.Context
rr *ResultReader
@@ -993,7 +957,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
msg, err := mrr.pgConn.ReceiveMessage()
if err != nil {
mrr.cleanupContextDeadline()
mrr.pgConn.doneChanToDeadline.cleanup()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.closed = true
mrr.pgConn.hardClose()
@@ -1002,7 +966,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
mrr.cleanupContextDeadline()
mrr.pgConn.doneChanToDeadline.cleanup()
mrr.closed = true
mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
@@ -1023,11 +987,10 @@ func (mrr *MultiResultReader) NextResult() bool {
switch msg := msg.(type) {
case *pgproto3.RowDescription:
mrr.pgConn.resultReader = ResultReader{
pgConn: mrr.pgConn,
multiResultReader: mrr,
ctx: mrr.ctx,
cleanupContextDeadline: func() {},
fieldDescriptions: msg.Fields,
pgConn: mrr.pgConn,
multiResultReader: mrr,
ctx: mrr.ctx,
fieldDescriptions: msg.Fields,
}
mrr.rr = &mrr.pgConn.resultReader
return true
@@ -1066,10 +1029,9 @@ func (mrr *MultiResultReader) Close() error {
// ResultReader is a reader for the result of a single query.
type ResultReader struct {
pgConn *PgConn
multiResultReader *MultiResultReader
ctx context.Context
cleanupContextDeadline func()
pgConn *PgConn
multiResultReader *MultiResultReader
ctx context.Context
fieldDescriptions []pgproto3.FieldDescription
rowValues [][]byte
@@ -1162,7 +1124,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
switch msg.(type) {
case *pgproto3.ReadyForQuery:
rr.cleanupContextDeadline()
rr.pgConn.doneChanToDeadline.cleanup()
rr.pgConn.unlock()
return rr.commandTag, rr.err
}
@@ -1181,7 +1143,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
if err != nil {
rr.concludeCommand(nil, err)
rr.cleanupContextDeadline()
rr.pgConn.doneChanToDeadline.cleanup()
rr.closed = true
if rr.multiResultReader == nil {
rr.pgConn.hardClose()
@@ -1238,9 +1200,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.lock()
pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn,
ctx: ctx,
cleanupContextDeadline: func() {},
pgConn: pgConn,
ctx: ctx,
}
multiResult := &pgConn.multiResultReader
@@ -1252,13 +1213,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
return multiResult
default:
}
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.hardClose()
multiResult.cleanupContextDeadline()
pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
pgConn.unlock()