From 2383561e4d1bbf50fde6a214aa04f296764e265f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 23:17:28 -0500 Subject: [PATCH] Use 0-alloc pgproto3/v2 --- auth_scram.go | 2 +- go.mod | 1 + go.sum | 2 ++ pgconn.go | 52 +++++++++++++++++++++++++++----------------------- pgconn_test.go | 10 +++++----- 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index b78a236a..50fbff40 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -21,7 +21,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) diff --git a/go.mod b/go.mod index 09b4471d..232df737 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3 v1.1.0 + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a diff --git a/go.sum b/go.sum index 8872aac1..8e0e2c9f 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pgconn.go b/pgconn.go index 7e8909ea..7bc93435 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "bytes" "context" "crypto/md5" "crypto/tls" @@ -17,7 +18,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" ) var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) @@ -436,20 +437,23 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag string +type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - s := string(ct) - index := strings.LastIndex(s, " ") - if index == -1 { + idx := bytes.LastIndexByte([]byte(ct), ' ') + if idx == -1 { return 0 } - n, _ := strconv.ParseInt(s[index+1:], 10, 64) + n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64) return n } +func (ct CommandTag) String() string { + return string(ct) +} + // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -756,7 +760,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if len(paramValues) > math.MaxUint16 { - result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result @@ -764,7 +768,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) + result.concludeCommand(nil, ctx.Err()) result.closed = true pgConn.unlock() return result @@ -783,7 +787,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - result.concludeCommand("", err) + result.concludeCommand(nil, err) result.cleanupContextDeadline() result.closed = true pgConn.unlock() @@ -797,7 +801,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): pgConn.unlock() - return "", ctx.Err() + return nil, ctx.Err() default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -812,7 +816,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm pgConn.hardClose() pgConn.unlock() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read results @@ -822,7 +826,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -831,7 +835,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := w.Write(msg.Data) if err != nil { pgConn.hardClose() - return "", err + return nil, err } case *pgproto3.ReadyForQuery: pgConn.unlock() @@ -854,7 +858,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-ctx.Done(): - return "", ctx.Err() + return nil, ctx.Err() default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -867,7 +871,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read until copy in response or error. @@ -878,7 +882,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -908,7 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } } @@ -917,7 +921,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -939,7 +943,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read results @@ -947,7 +951,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1145,7 +1149,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return "", rr.err + return nil, rr.err } } @@ -1153,7 +1157,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return "", rr.err + return nil, rr.err } switch msg.(type) { @@ -1176,7 +1180,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - rr.concludeCommand("", err) + rr.concludeCommand(nil, err) rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { @@ -1192,7 +1196,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand("", errorResponseToPgError(msg)) + rr.concludeCommand(nil, errorResponseToPgError(msg)) } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index 3be61be8..2b1e68a3 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -475,7 +475,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) + assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) assert.False(t, pgConn.IsAlive()) @@ -601,7 +601,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) + assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) assert.False(t, pgConn.IsAlive()) } @@ -958,7 +958,7 @@ func TestConnCopyToCanceled(t *testing.T) { defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") assert.Equal(t, context.DeadlineExceeded, err) - assert.Equal(t, pgconn.CommandTag(""), res) + assert.Equal(t, pgconn.CommandTag(nil), res) assert.False(t, pgConn.IsAlive()) } @@ -977,7 +977,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) require.Equal(t, context.Canceled, err) - assert.Equal(t, pgconn.CommandTag(""), res) + assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) } @@ -1084,7 +1084,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) require.Equal(t, context.Canceled, err) - assert.Equal(t, pgconn.CommandTag(""), ct) + assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) }