forked from NVIDIA/kvpress
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
235 lines (209 loc) · 8.91 KB
/
evaluate.py
File metadata and controls
235 lines (209 loc) · 8.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
from pathlib import Path
from typing import Optional
import torch
from datasets import load_dataset
from fire import Fire
from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from longbench.calculate_metrics import calculate_metrics as longbench_scorer
from longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from loogle.calculate_metrics import calculate_metrics as loogle_scorer
from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
from transformers import pipeline
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
from kvpress import (
AdaKVPress,
ChunkKVPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
DuoAttentionPress,
ExpectedAttentionPress,
FinchPress,
KnormPress,
ObservedAttentionPress,
QFilterPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Benchmark name (the `dataset` argument of `evaluate`) -> dataset identifier
# handed to `datasets.load_dataset`.
# "longbench" and "longbench-e" share the same identifier; they differ only in
# the scorer selected in SCORER_DICT.
DATASET_DICT = {
    "loogle": "simonjegou/loogle",
    "ruler": "simonjegou/ruler",
    "zero_scrolls": "simonjegou/zero_scrolls",
    "infinitebench": "MaxJeblick/InfiniteBench",
    "longbench": "Xnhyacinth/LongBench",
    "longbench-e": "Xnhyacinth/LongBench",
    "longbench-v2": "Xnhyacinth/LongBench-v2",
}

# Benchmark name -> metric function. `evaluate` calls the selected function
# with the results DataFrame and dumps its return value to JSON.
# Keys must mirror DATASET_DICT: `evaluate` asserts membership in both.
SCORER_DICT = {
    "loogle": loogle_scorer,
    "ruler": ruler_scorer,
    "zero_scrolls": zero_scrolls_scorer,
    "infinitebench": infinite_bench_scorer,
    "longbench": longbench_scorer,
    "longbench-e": longbench_scorer_e,
    "longbench-v2": longbenchv2_scorer,
}

# Press name (the `press_name` argument of `evaluate`) -> press instance.
# NOTE: the instances are built once, when this module is imported, and
# `evaluate` sets the compression attributes directly on the looked-up
# instance, so the objects stored here are mutated in place.
PRESS_DICT = {
    "criti_adasnapkv": CriticalAdaKVPress(SnapKVPress()),
    "criti_ada_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "criti_snapkv": CriticalKVPress(SnapKVPress()),
    "criti_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "adasnapkv": AdaKVPress(SnapKVPress()),
    "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
    "expected_attention": ExpectedAttentionPress(),
    "ada_expected_attention_e2": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
    "knorm": KnormPress(),
    "observed_attention": ObservedAttentionPress(),
    "random": RandomPress(),
    "snapkv": SnapKVPress(),
    "streaming_llm": StreamingLLMPress(),
    "think": ThinKPress(),
    "tova": TOVAPress(),
    "duo_attention": DuoAttentionPress(),
    "finch": FinchPress(),
    "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
    "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
    "qfilter": QFilterPress(),
    "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
}
def _append_to_stem(path: Path, tag: str) -> Path:
    """Return ``path`` with ``tag`` inserted between its stem and its file extension."""
    return path.with_name(path.stem + tag + path.suffix)


def evaluate(
    dataset: str,
    data_dir: Optional[str] = None,
    model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    device: Optional[str] = None,
    press_name: str = "expected_attention",
    compression_ratio: float = 0.1,
    fraction: float = 1.0,
    max_new_tokens: Optional[int] = None,
    max_context_length: Optional[int] = None,
    compress_questions: bool = False,
    key_channel_compression_ratio: float = 0.5,
):
    """
    Evaluate a model on a dataset using a press and save the results

    The predictions are written to a CSV file and the metrics to a JSON file with the same stem, both in a
    "results" directory next to this script. The file name encodes the dataset, data_dir, model, press name and
    compression ratio, plus one tag for each non-default option (fraction, max_context_length,
    compress_questions, key channel compression). If the CSV file already exists a warning is logged and the
    files are overwritten. The metrics are also printed.

    Parameters
    ----------
    dataset : str
        Dataset to evaluate (see DATASET_DICT)
    data_dir : str, optional
        Subdirectory of the dataset to evaluate, by default None
    model : str, optional
        Model to use, by default "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device : str, optional
        Model device, by default cuda:0 if available else cpu. For multi-GPU use "auto"
    press_name : str, optional
        Press to use (see PRESS_DICT), by default "expected_attention"
    compression_ratio : float, optional
        Compression ratio for the press, by default 0.1. It is set as the head compression ratio of a
        DuoAttentionPress and is not applied to a ThinKPress (see key_channel_compression_ratio)
    fraction : float, optional
        Fraction of the dataset to evaluate, by default 1.0
    max_new_tokens : int, optional
        Maximum number of new tokens to generate, by default use the default for the task (recommended)
    max_context_length : int, optional
        Maximum number of tokens to use in the context. By default will use the maximum length supported by the model.
    compress_questions : bool, optional
        Whether to compress the questions as well, by default False
    key_channel_compression_ratio : float, optional
        Key channel compression ratio applied to a ThinKPress (standalone or inside a ComposedPress), by default 0.5

    Raises
    ------
    AssertionError
        If dataset is missing from DATASET_DICT or SCORER_DICT, if press_name is missing from PRESS_DICT, or if
        rows sharing the same context do not share a single answer prefix
    """

    # Validate every name-based argument up front so that a typo fails before the dataset or model is loaded
    assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
    assert dataset in SCORER_DICT, f"No scorer found for {dataset}"
    assert press_name in PRESS_DICT, f"No press found for {press_name}"

    # Normalise to str (or None) so that it can be joined into the file name below
    data_dir = str(data_dir) if data_dir else None

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    save_dir = Path(__file__).parent / "results"
    save_dir.mkdir(exist_ok=True)
    save_filename = save_dir / (
        "__".join([dataset, data_dir if data_dir else "", model.replace("/", "--"), press_name, str(compression_ratio)])
        + ".csv"
    )

    # Load dataframe
    df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas()
    if fraction < 1.0:
        # Fixed seed so that the same subset is drawn on every run
        df = df.sample(frac=fraction, random_state=42)
        save_filename = _append_to_stem(save_filename, f"__fraction{fraction:.2f}")

    if max_context_length is not None:
        save_filename = _append_to_stem(save_filename, f"__max_context{max_context_length}")

    if compress_questions:
        # Fold the question into the context so that it goes through the press as well
        df["context"] = df["context"] + df["question"]
        df["question"] = ""
        save_filename = _append_to_stem(save_filename, "__compressed_questions")

    # Load press. The instance stored in PRESS_DICT is configured in place
    press = PRESS_DICT[press_name]

    if isinstance(press, DuoAttentionPress):
        press.head_compression_ratio = compression_ratio
    elif isinstance(press, ComposedPress):
        for ps in press.presses:
            if isinstance(ps, ThinKPress):
                ps.key_channel_compression_ratio = key_channel_compression_ratio
                save_filename = _append_to_stem(save_filename, f"__channel{key_channel_compression_ratio}")
            else:
                ps.compression_ratio = compression_ratio  # type:ignore[attr-defined]
    elif isinstance(press, ThinKPress):
        press.key_channel_compression_ratio = key_channel_compression_ratio
        save_filename = _append_to_stem(save_filename, f"__channel{key_channel_compression_ratio}")
    else:
        press.compression_ratio = compression_ratio  # type:ignore[attr-defined]

    # Check for previous results only now that the file name is final: every option above may have added a tag
    if save_filename.exists():
        logger.warning("Results already exist at %s", save_filename)

    # Initialize pipeline with the correct attention implementation
    model_kwargs = {"torch_dtype": "auto"}
    if isinstance(press, ObservedAttentionPress):
        model_kwargs["attn_implementation"] = "eager"
    else:
        # Best effort: use flash attention when the package is installed, else keep the model default
        try:
            import flash_attn  # noqa: F401

            model_kwargs["attn_implementation"] = "flash_attention_2"
        except ImportError:
            pass

    if device == "auto":
        pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
    else:
        pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

    # Run pipeline on each context
    df["predicted_answer"] = None
    df_context = df.groupby("context")
    # A single answer prefix is passed per pipeline call, so it must be unique within each context
    assert all(df_context["answer_prefix"].nunique() == 1)

    for context, df_ in tqdm(df_context, total=df["context"].nunique()):
        questions = df_["question"].to_list()
        max_new_tokens_ = max_new_tokens if max_new_tokens is not None else df_["max_new_tokens"].iloc[0]
        answer_prefix = df_["answer_prefix"].iloc[0]
        output = pipe(
            context,
            questions=questions,
            answer_prefix=answer_prefix,
            press=press,
            max_new_tokens=max_new_tokens_,
            max_context_length=max_context_length,
        )
        df.loc[df_.index, "predicted_answer"] = output["answers"]
        # Read the ratio back from the press after the call rather than reusing the requested value
        df.loc[df_.index, "compression_ratio"] = press.compression_ratio  # type:ignore[attr-defined]
        torch.cuda.empty_cache()

    # Save answers
    df[["predicted_answer", "compression_ratio"]].to_csv(str(save_filename), index=False)

    # Calculate metrics
    scorer = SCORER_DICT[dataset]
    metrics = scorer(df)
    # Swap only the file extension: a textual replace of ".csv" would also rewrite any ".csv" in the directories
    with open(save_filename.with_suffix(".json"), "w") as f:
        json.dump(metrics, f)
    print(f"Average compression ratio: {df['compression_ratio'].mean():.2f}")
    print(metrics)
# Script entry point: hand `evaluate` to Fire so that it can be driven from the command line.
if __name__ == "__main__":
    Fire(evaluate)