Skip to content

Commit 771e158

Browse files
committed
fix format
Signed-off-by: Dominic789654 <xliu29@gmu.edu>
1 parent ad1989d commit 771e158

File tree

8 files changed

+21
-20
lines changed

8 files changed

+21
-20
lines changed

evaluation/evaluate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
1818

1919
from kvpress import (
20-
CriticalKVPress,
21-
CriticalAdaKVPress,
2220
AdaKVPress,
21+
ChunkKVPress,
22+
CriticalAdaKVPress,
23+
CriticalKVPress,
24+
DuoAttentionPress,
2325
ExpectedAttentionPress,
2426
KnormPress,
2527
ObservedAttentionPress,
@@ -28,8 +30,6 @@
2830
StreamingLLMPress,
2931
ThinKPress,
3032
TOVAPress,
31-
DuoAttentionPress,
32-
ChunkKVPress
3333
)
3434

3535
logger = logging.getLogger(__name__)

kvpress/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from kvpress.presses.chunk_press import ChunkPress
1010
from kvpress.presses.chunkkv_press import ChunkKVPress
1111
from kvpress.presses.composed_press import ComposedPress
12+
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
13+
from kvpress.presses.duo_attention_press import DuoAttentionPress
1214
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
1315
from kvpress.presses.key_rerotation_press import KeyRerotationPress
1416
from kvpress.presses.knorm_press import KnormPress
@@ -21,8 +23,6 @@
2123
from kvpress.presses.streaming_llm_press import StreamingLLMPress
2224
from kvpress.presses.think_press import ThinKPress
2325
from kvpress.presses.tova_press import TOVAPress
24-
from kvpress.presses.criticalkv_press import CriticalKVPress, CriticalAdaKVPress
25-
from kvpress.presses.duo_attention_press import DuoAttentionPress
2626

2727
# Patch the attention functions to support head-wise compression
2828
patch_attention_functions()

kvpress/presses/chunkkv_press.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def compress(
5151
assert attentions is None, "ChunkPress does not support attentions."
5252

5353
kv_len = keys.shape[2]
54-
54+
5555
# 1. Calculate global scores first
5656
global_scores = self.press.score(
5757
module,

kvpress/presses/criticalkv_press.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from transformers.models.llama.modeling_llama import repeat_kv
99

1010
from kvpress.presses.base_press import BasePress
11-
from kvpress.presses.scorer_press import ScorerPress
1211
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
12+
from kvpress.presses.scorer_press import ScorerPress
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -49,7 +49,7 @@ def vwl1norm(values, module):
4949
# Future kernel fusion optimization could eliminate this intermediate variables to enhance performance.
5050
head_WoV_norm_list = []
5151
for head in range(V.size(1)):
52-
head_WoV = V[: , head, : , ...].matmul(Wo[head, ...].unsqueeze(0))
52+
head_WoV = V[:, head, :, ...].matmul(Wo[head, ...].unsqueeze(0))
5353
head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
5454
head_WoV_norm_list.append(head_WoV_norm)
5555

kvpress/presses/duo_attention_press.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from io import StringIO
5-
from dataclasses import dataclass, field
64
from contextlib import contextmanager
5+
from dataclasses import dataclass, field
6+
from io import StringIO
77

8-
import torch
9-
import requests # type: ignore[import-untyped]
108
import numpy as np
9+
import requests # type: ignore[import-untyped]
10+
import torch
1111

1212
from kvpress.presses.base_press import BasePress
1313

14-
1514
PATTERNS_DICT = {
1615
"togethercomputer/Llama-2-7B-32K-Instruct": "Llama-2-7B-32K-Instruct/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
1716
"gradientai//Llama-3-8B-Instruct-Gradient-1048k": "Llama-3-8B-Instruct-Gradient-1048k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501

tests/default_presses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from kvpress import (
7+
DuoAttentionPress,
78
ExpectedAttentionPress,
89
KnormPress,
910
RandomPress,
@@ -12,7 +13,6 @@
1213
StreamingLLMPress,
1314
ThinKPress,
1415
TOVAPress,
15-
DuoAttentionPress,
1616
)
1717

1818

tests/presses/test_duo_attention_press.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from kvpress.presses.duo_attention_press import DuoAttentionPress, PATTERNS_DICT
1+
from kvpress.presses.duo_attention_press import PATTERNS_DICT, DuoAttentionPress
22

33

44
def test_load_attention_pattern():

tests/presses/test_presses.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from transformers import DynamicCache
99

1010
from kvpress import (
11-
CriticalKVPress,
12-
CriticalAdaKVPress,
1311
AdaKVPress,
1412
ChunkKVPress,
1513
ChunkPress,
1614
ComposedPress,
15+
CriticalAdaKVPress,
16+
CriticalKVPress,
1717
KeyRerotationPress,
1818
KnormPress,
1919
ObservedAttentionPress,
@@ -57,8 +57,10 @@ def test_chunkkv_press(unit_test_model): # noqa: F811
5757

5858

5959
@pytest.mark.parametrize("press_dict", default_presses)
60-
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress,
61-
CriticalKVPress, CriticalAdaKVPress])
60+
@pytest.mark.parametrize(
61+
"wrapper_press",
62+
[None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress, CriticalKVPress, CriticalAdaKVPress],
63+
)
6264
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
6365
cls = press_dict["cls"]
6466
for kwargs in press_dict["kwargs"]:

0 commit comments

Comments
 (0)