2
0

Merge branch 'bakape-master'

This commit is contained in:
Jack Christensen
2020-01-11 18:10:33 -06:00
3 changed files with 185 additions and 134 deletions
+1
View File
@@ -1,2 +1,3 @@
.envrc .envrc
vendor/ vendor/
.vscode
+38 -9
View File
@@ -39,16 +39,28 @@ func BenchmarkConnect(b *testing.B) {
} }
func BenchmarkExec(b *testing.B) { func BenchmarkExec(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
benchmarks := []struct {
name string
ctx context.Context
}{
// Using an empty context other than context.Background() to compare
// performance
{"background context", context.Background()},
{"empty context", context.TODO()},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err) require.Nil(b, err)
defer closeConn(b, conn) defer closeConn(b, conn)
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
for mrr.NextResult() { for mrr.NextResult() {
rr := mrr.ResultReader() rr := mrr.ResultReader()
@@ -80,6 +92,8 @@ func BenchmarkExec(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
} }
})
}
} }
func BenchmarkExecPossibleToCancel(b *testing.B) { func BenchmarkExecPossibleToCancel(b *testing.B) {
@@ -130,19 +144,32 @@ func BenchmarkExecPossibleToCancel(b *testing.B) {
} }
func BenchmarkExecPrepared(b *testing.B) { func BenchmarkExecPrepared(b *testing.B) {
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
benchmarks := []struct {
name string
ctx context.Context
}{
// Using an empty context other than context.Background() to compare
// performance
{"background context", context.Background()},
{"empty context", context.TODO()},
}
for _, bm := range benchmarks {
bm := bm
b.Run(bm.name, func(b *testing.B) {
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.Nil(b, err) require.Nil(b, err)
defer closeConn(b, conn) defer closeConn(b, conn)
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
require.Nil(b, err) require.Nil(b, err)
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
rowCount := 0 rowCount := 0
for rr.NextRow() { for rr.NextRow() {
@@ -165,6 +192,8 @@ func BenchmarkExecPrepared(b *testing.B) {
b.Fatalf("unexpected rowCount: %d", rowCount) b.Fatalf("unexpected rowCount: %d", rowCount)
} }
} }
})
}
} }
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
+25 -4
View File
@@ -362,6 +362,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()} return &contextAlreadyDoneError{err: ctx.Err()}
@@ -369,6 +370,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
@@ -392,6 +394,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, &contextAlreadyDoneError{err: ctx.Err()}
@@ -399,6 +402,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
@@ -489,8 +493,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close() defer pgConn.conn.Close()
if ctx != context.Background() {
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
// Ignore any errors sending Terminate message and waiting for server to close connection. // Ignore any errors sending Terminate message and waiting for server to close connection.
// This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully
@@ -600,6 +606,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, &contextAlreadyDoneError{err: ctx.Err()}
@@ -607,6 +614,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@@ -693,12 +701,14 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
} }
defer cancelConn.Close() defer cancelConn.Close()
if ctx != context.Background() {
contextWatcher := ctxwatch.NewContextWatcher( contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) }, func() { cancelConn.SetDeadline(time.Time{}) },
) )
contextWatcher.Watch(ctx) contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch() defer contextWatcher.Unwatch()
}
buf := make([]byte, 16) buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[0:4], 16)
@@ -726,6 +736,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@@ -734,6 +745,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
for { for {
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
@@ -766,7 +778,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
ctx: ctx, ctx: ctx,
} }
multiResult := &pgConn.multiResultReader multiResult := &pgConn.multiResultReader
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
@@ -776,6 +788,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
}
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf) buf = (&pgproto3.Query{String: sql}).Encode(buf)
@@ -822,7 +835,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result) pgConn.execExtendedSuffix(buf, result)
return result return result
} }
@@ -848,7 +861,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result) pgConn.execExtendedSuffix(buf, result)
return result return result
} }
@@ -873,6 +886,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result return result
} }
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
@@ -882,11 +896,12 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
}
return result return result
} }
func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf)
@@ -907,6 +922,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, err return nil, err
} }
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
@@ -915,6 +931,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command // Send copy to command
buf := pgConn.wbuf buf := pgConn.wbuf
@@ -966,6 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, &contextAlreadyDoneError{err: ctx.Err()}
@@ -973,6 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command // Send copy to command
buf := pgConn.wbuf buf := pgConn.wbuf
@@ -1359,6 +1378,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
} }
multiResult := &pgConn.multiResultReader multiResult := &pgConn.multiResultReader
if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
@@ -1368,6 +1388,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
}
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)