Skip to content

Commit 98e1327

Browse files
Test head-wise compression (#103)
1 parent fb93b31 commit 98e1327

File tree

5 files changed

+62
-60
lines changed

5 files changed

+62
-60
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[flake8]
2+
exclude = .venv,venv,.git,__pycache__,build,dist, .mypy_cache
23
max-line-length = 120
34
per-file-ignores =
45
__init__.py:F401

evaluation/evaluate_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
output_dir: "./results"
55

6-
model: "meta-llama/LLama-3.1-8B-Instruct"
6+
model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
77
dataset: "ruler" # see DATASET_REGISTRY in evaluate_registry.py
88
data_dir: "4096" # Subdirectory of the dataset (if applicable)
99

tests/default_presses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,6 @@ def load_attention_pattern(model):
7070
{"cls": KeyDiffPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
7171
{
7272
"cls": KVzipPress,
73-
"kwargs": [{"compression_ratio": 0.5}, {"compression_ratio": 0.8}],
73+
"kwargs": [{"compression_ratio": 0.5, "layerwise": False}, {"compression_ratio": 0.8, "layerwise": True}],
7474
},
7575
]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import pytest
4+
import torch
5+
from transformers import DynamicCache
6+
7+
from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress
8+
from tests.fixtures import unit_test_model # noqa: F401
9+
10+
11+
def compute_masked_percentage(module, batch_size, num_key_value_heads, seq_len):
12+
"""
13+
Compute the percentage of masked indices from module.masked_key_indices.
14+
"""
15+
if module.masked_key_indices is None:
16+
return 0.0
17+
18+
batch_indices, head_indices, seq_indices = module.masked_key_indices
19+
num_masked = len(batch_indices)
20+
total_positions = batch_size * num_key_value_heads * seq_len
21+
masked_percentage = num_masked / total_positions
22+
return masked_percentage
23+
24+
25+
@pytest.mark.parametrize("wrapper_press", [AdaKVPress, CriticalAdaKVPress])
26+
@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
27+
def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ratio): # noqa: F811
28+
p = KnormPress(compression_ratio=compression_ratio)
29+
press = wrapper_press(press=p)
30+
with press(unit_test_model):
31+
input_ids = torch.randint(0, 1024, (1, 128))
32+
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
33+
34+
assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
35+
headwise_compression_ratio = 0.0
36+
for layer in unit_test_model.model.layers:
37+
cr = compute_masked_percentage(layer.self_attn, 1, unit_test_model.config.num_key_value_heads, 128)
38+
headwise_compression_ratio += cr
39+
cumulative_compression_ratio = headwise_compression_ratio / len(unit_test_model.model.layers)
40+
assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences
41+
42+
43+
# Only for KVzipPress, since it's the only non-wrapper press with head compression (apart from Duo)
44+
@pytest.mark.parametrize("press", [KVzipPress])
45+
@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
46+
@pytest.mark.parametrize("layerwise", [True, False])
47+
def test_head_compression(unit_test_model, press, compression_ratio, layerwise): # noqa: F811
48+
press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise)
49+
with press(unit_test_model):
50+
input_ids = torch.randint(0, 1024, (1, 128))
51+
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
52+
53+
assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
54+
headwise_compression_ratio = 0.0
55+
for layer in unit_test_model.model.layers:
56+
cr = compute_masked_percentage(layer.self_attn, 1, unit_test_model.config.num_key_value_heads, 128)
57+
headwise_compression_ratio += cr
58+
cumulative_compression_ratio = headwise_compression_ratio / len(unit_test_model.model.layers)
59+
assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences

tests/presses/test_kvzip_press.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)