diff --git a/.gitignore b/.gitignore index 7a6353d6..6eb9d442 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .envrc +vendor/ \ No newline at end of file diff --git a/auth_scram.go b/auth_scram.go index d102d305..bdaf3e92 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -249,7 +249,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - serverSignature := computeHMAC(serverKey[:], authMessage) + serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf diff --git a/benchmark_test.go b/benchmark_test.go index 51e11e24..8067c985 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -20,6 +20,7 @@ func BenchmarkConnect(b *testing.B) { } for _, bm := range benchmarks { + bm := bm b.Run(bm.name, func(b *testing.B) { connString := os.Getenv(bm.env) if connString == "" { @@ -54,12 +55,12 @@ func BenchmarkExec(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -101,12 +102,12 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -145,12 +146,12 @@ func BenchmarkExecPrepared(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -191,7 +192,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } diff --git a/go.mod b/go.mod index 9401dce8..b1c84049 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 1b6862a0..50dfc2fd 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,16 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pgconn.go b/pgconn.go index c51742ae..9e4f6253 100644 --- a/pgconn.go +++ b/pgconn.go @@ -241,16 +241,16 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { +func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.Config.Password) + err = pgConn.txPasswordMessage(pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: - err = c.scramAuth(msg.SASLAuthMechanisms) + err = pgConn.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } @@ -514,11 +514,11 @@ readloop: func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: string(msg.Severity), + Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), - Hint: string(msg.Hint), + Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), @@ -527,7 +527,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), - ConstraintName: string(msg.ConstraintName), + ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), @@ -919,7 +919,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: readErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index 310b387b..4389fe99 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -37,6 +37,7 @@ func TestConnect(t *testing.T) { } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { connString := os.Getenv(tt.env) if connString == "" { @@ -194,13 +195,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount += 1 + dialCount++ return net.Dial(network, address) } acceptConnCount := 0 config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount += 1 + acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") } @@ -302,7 +303,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { - resultCount += 1 + resultCount++ multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) @@ -753,12 +754,12 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err)