2
0

Use 0-alloc pgproto3/v2

This commit is contained in:
Jack Christensen
2019-04-18 23:17:28 -05:00
parent 9d30dad837
commit 2383561e4d
5 changed files with 37 additions and 30 deletions
+1 -1
View File
@@ -21,7 +21,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/jackc/pgproto3" "github.com/jackc/pgproto3/v2"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis" "golang.org/x/text/secure/precis"
) )
+1
View File
@@ -6,6 +6,7 @@ require (
github.com/jackc/pgio v1.0.0 github.com/jackc/pgio v1.0.0
github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3 v1.1.0 github.com/jackc/pgproto3 v1.1.0
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf
github.com/pkg/errors v0.8.1 github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.3.0 github.com/stretchr/testify v1.3.0
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
+2
View File
@@ -8,6 +8,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+28 -24
View File
@@ -1,6 +1,7 @@
package pgconn package pgconn
import ( import (
"bytes"
"context" "context"
"crypto/md5" "crypto/md5"
"crypto/tls" "crypto/tls"
@@ -17,7 +18,7 @@ import (
"time" "time"
"github.com/jackc/pgio" "github.com/jackc/pgio"
"github.com/jackc/pgproto3" "github.com/jackc/pgproto3/v2"
) )
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
@@ -436,20 +437,23 @@ func (pgConn *PgConn) ParameterStatus(key string) string {
} }
// CommandTag is the result of an Exec function // CommandTag is the result of an Exec function
type CommandTag string type CommandTag []byte
// RowsAffected returns the number of rows affected. If the CommandTag was not // RowsAffected returns the number of rows affected. If the CommandTag was not
// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. // for a row affecting command (e.g. "CREATE TABLE") then it returns 0.
func (ct CommandTag) RowsAffected() int64 { func (ct CommandTag) RowsAffected() int64 {
s := string(ct) idx := bytes.LastIndexByte([]byte(ct), ' ')
index := strings.LastIndex(s, " ") if idx == -1 {
if index == -1 {
return 0 return 0
} }
n, _ := strconv.ParseInt(s[index+1:], 10, 64) n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64)
return n return n
} }
func (ct CommandTag) String() string {
return string(ct)
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err. // true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error { func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
@@ -756,7 +760,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
result := &pgConn.resultReader result := &pgConn.resultReader
if len(paramValues) > math.MaxUint16 { if len(paramValues) > math.MaxUint16 {
result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@@ -764,7 +768,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand("", ctx.Err()) result.concludeCommand(nil, ctx.Err())
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@@ -783,7 +787,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
_, err := pgConn.conn.Write(buf) _, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
result.concludeCommand("", err) result.concludeCommand(nil, err)
result.cleanupContextDeadline() result.cleanupContextDeadline()
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
@@ -797,7 +801,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
return "", ctx.Err() return nil, ctx.Err()
default: default:
} }
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
@@ -812,7 +816,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
pgConn.hardClose() pgConn.hardClose()
pgConn.unlock() pgConn.unlock()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
// Read results // Read results
@@ -822,7 +826,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -831,7 +835,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
_, err := w.Write(msg.Data) _, err := w.Write(msg.Data)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", err return nil, err
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.unlock() pgConn.unlock()
@@ -854,7 +858,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
select { select {
case <-ctx.Done(): case <-ctx.Done():
return "", ctx.Err() return nil, ctx.Err()
default: default:
} }
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
@@ -867,7 +871,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err := pgConn.conn.Write(buf) _, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
// Read until copy in response or error. // Read until copy in response or error.
@@ -878,7 +882,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -908,7 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
} }
@@ -917,7 +921,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -939,7 +943,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
// Read results // Read results
@@ -947,7 +951,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -1145,7 +1149,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
for !rr.commandConcluded { for !rr.commandConcluded {
_, err := rr.receiveMessage() _, err := rr.receiveMessage()
if err != nil { if err != nil {
return "", rr.err return nil, rr.err
} }
} }
@@ -1153,7 +1157,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
for { for {
msg, err := rr.receiveMessage() msg, err := rr.receiveMessage()
if err != nil { if err != nil {
return "", rr.err return nil, rr.err
} }
switch msg.(type) { switch msg.(type) {
@@ -1176,7 +1180,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
} }
if err != nil { if err != nil {
rr.concludeCommand("", err) rr.concludeCommand(nil, err)
rr.cleanupContextDeadline() rr.cleanupContextDeadline()
rr.closed = true rr.closed = true
if rr.multiResultReader == nil { if rr.multiResultReader == nil {
@@ -1192,7 +1196,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
rr.concludeCommand(CommandTag(msg.CommandTag), nil) rr.concludeCommand(CommandTag(msg.CommandTag), nil)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.concludeCommand("", errorResponseToPgError(msg)) rr.concludeCommand(nil, errorResponseToPgError(msg))
} }
return msg, nil return msg, nil
+5 -5
View File
@@ -475,7 +475,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
} }
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, context.DeadlineExceeded, err)
assert.False(t, pgConn.IsAlive()) assert.False(t, pgConn.IsAlive())
@@ -601,7 +601,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
} }
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, context.DeadlineExceeded, err)
assert.False(t, pgConn.IsAlive()) assert.False(t, pgConn.IsAlive())
} }
@@ -958,7 +958,7 @@ func TestConnCopyToCanceled(t *testing.T) {
defer cancel() defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, context.DeadlineExceeded, err)
assert.Equal(t, pgconn.CommandTag(""), res) assert.Equal(t, pgconn.CommandTag(nil), res)
assert.False(t, pgConn.IsAlive()) assert.False(t, pgConn.IsAlive())
} }
@@ -977,7 +977,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err) require.Error(t, err)
require.Equal(t, context.Canceled, err) require.Equal(t, context.Canceled, err)
assert.Equal(t, pgconn.CommandTag(""), res) assert.Equal(t, pgconn.CommandTag(nil), res)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -1084,7 +1084,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err) require.Error(t, err)
require.Equal(t, context.Canceled, err) require.Equal(t, context.Canceled, err)
assert.Equal(t, pgconn.CommandTag(""), ct) assert.Equal(t, pgconn.CommandTag(nil), ct)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }