diff --git a/auth_scram.go b/auth_scram.go index d102d305..bdaf3e92 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -249,7 +249,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - serverSignature := computeHMAC(serverKey[:], authMessage) + serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf diff --git a/benchmark_test.go b/benchmark_test.go index 51e11e24..8067c985 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -20,6 +20,7 @@ func BenchmarkConnect(b *testing.B) { } for _, bm := range benchmarks { + bm := bm b.Run(bm.name, func(b *testing.B) { connString := os.Getenv(bm.env) if connString == "" { @@ -54,12 +55,12 @@ func BenchmarkExec(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -101,12 +102,12 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -145,12 +146,12 @@ func BenchmarkExecPrepared(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -191,7 +192,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 0b491bf8..a1b3c863 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -87,6 +87,9 @@ func TestContextWatcherStress(t *testing.T) { if i%2 == 1 { cancel() } + + // To avoid context leak + cancel() } actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) diff --git a/pgconn.go b/pgconn.go index c51742ae..9e4f6253 100644 --- a/pgconn.go +++ b/pgconn.go @@ -241,16 +241,16 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { +func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.Config.Password) + err = pgConn.txPasswordMessage(pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: - err = c.scramAuth(msg.SASLAuthMechanisms) + err = pgConn.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } @@ -514,11 +514,11 @@ readloop: func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: string(msg.Severity), + Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), - Hint: string(msg.Hint), + Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), @@ -527,7 +527,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), - ConstraintName: string(msg.ConstraintName), + ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), @@ -919,7 +919,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: readErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index 310b387b..4389fe99 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -37,6 +37,7 @@ func TestConnect(t *testing.T) { } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { connString := os.Getenv(tt.env) if connString == "" { @@ -194,13 +195,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount += 1 + dialCount++ return net.Dial(network, address) } acceptConnCount := 0 config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount += 1 + acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") } @@ -302,7 +303,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { - resultCount += 1 + resultCount++ multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) @@ -753,12 +754,12 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err)