-
Notifications
You must be signed in to change notification settings - Fork 128
Expand file tree
/
Copy pathevaluate_registry.py
More file actions
114 lines (108 loc) · 4.62 KB
/
evaluate_registry.py
File metadata and controls
114 lines (108 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from benchmarks.aime25.calculate_metrics import calculate_metrics as aime25_scorer
from benchmarks.infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics as longbench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from benchmarks.longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from benchmarks.loogle.calculate_metrics import calculate_metrics as loogle_scorer
from benchmarks.math500.calculate_metrics import calculate_metrics as math500_scorer
from benchmarks.needle_in_haystack.calculate_metrics import calculate_metrics as needle_in_haystack_scorer
from benchmarks.ruler.calculate_metrics import calculate_metrics as ruler_scorer
from benchmarks.zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
from kvpress import (
AdaKVPress,
BlockPress,
ChunkKVPress,
CompactorPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
DecodingPress,
DuoAttentionPress,
ExpectedAttentionPress,
FinchPress,
KeyDiffPress,
KnormPress,
KVzapPress,
KVzipPress,
ObservedAttentionPress,
PyramidKVPress,
QFilterPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
DMSPress,
ThinKPress,
TOVAPress,
CURPress,
LagKVPress,
)
# These dictionaries define the available datasets, scorers, and KVPress methods for evaluation.
# Benchmark name -> dataset identifier string.
# The identifiers have an "owner/name" shape (presumably Hugging Face Hub
# dataset repositories -- confirm against the loading code).
DATASET_REGISTRY = {
    # Long-context benchmarks
    "loogle": "simonjegou/loogle",
    "ruler": "simonjegou/ruler",
    "zero_scrolls": "simonjegou/zero_scrolls",
    "infinitebench": "MaxJeblick/InfiniteBench",
    # "longbench" and "longbench-e" intentionally point at the same identifier;
    # they differ only in the scorer registered for them.
    "longbench": "Xnhyacinth/LongBench",
    "longbench-e": "Xnhyacinth/LongBench",
    "longbench-v2": "simonjegou/LongBench-v2",
    "needle_in_haystack": "alessiodevoto/paul_graham_essays",
    # Datasets used for evaluating decoding compression
    "aime25": "alessiodevoto/aime25",
    "math500": "alessiodevoto/math500",
}
# Benchmark name -> metric-computation function imported from that benchmark's
# `calculate_metrics` module. The keys are the same set of names used in
# DATASET_REGISTRY.
SCORER_REGISTRY = {
    "loogle": loogle_scorer,
    "ruler": ruler_scorer,
    "zero_scrolls": zero_scrolls_scorer,
    "infinitebench": infinite_bench_scorer,
    "longbench": longbench_scorer,
    # The "-e" variant uses `calculate_metrics_e` from the same longbench module.
    "longbench-e": longbench_scorer_e,
    "longbench-v2": longbenchv2_scorer,
    "needle_in_haystack": needle_in_haystack_scorer,
    "aime25": aime25_scorer,
    "math500": math500_scorer,
}
# Method name -> press object, built once when this module is imported.
# Every entry constructs its own fresh instance(s); "no_press" maps to None
# instead of a press object.
PRESS_REGISTRY = {
    "adakv_snapkv": AdaKVPress(SnapKVPress()),
    "block_keydiff": BlockPress(press=KeyDiffPress(), block_size=128),
    "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
    "critical_adakv_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "critical_adakv_snapkv": CriticalAdaKVPress(SnapKVPress()),
    "critical_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "critical_snapkv": CriticalKVPress(SnapKVPress()),
    "cur": CURPress(),
    "duo_attention": DuoAttentionPress(),
    "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
    # NOTE(review): unlike the other bare-named entries, this one is wrapped in
    # AdaKVPress with epsilon=1e-2 -- confirm that is the intended default.
    "expected_attention": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
    "finch": FinchPress(),
    "keydiff": KeyDiffPress(),
    "kvzip": KVzipPress(),
    "kvzip_plus": KVzipPress(kvzip_plus_normalization=True),
    # KVzap variants: the same scorer press, bare or wrapped by DMSPress / AdaKVPress.
    "kvzap_linear": DMSPress(press=KVzapPress(model_type="linear")),
    "kvzap_mlp": DMSPress(press=KVzapPress(model_type="mlp")),
    "kvzap_mlp_head": KVzapPress(model_type="mlp"),
    "kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")),
    "lagkv": LagKVPress(),
    "knorm": KnormPress(),
    "observed_attention": ObservedAttentionPress(),
    "pyramidkv": PyramidKVPress(),
    "qfilter": QFilterPress(),
    "random": RandomPress(),
    "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
    "snapkv": SnapKVPress(),
    "streaming_llm": StreamingLLMPress(),
    "think": ThinKPress(),
    "tova": TOVAPress(),
    "compactor": CompactorPress(),
    "adakv_compactor": AdaKVPress(CompactorPress()),
    # Entry with no press object at all.
    "no_press": None,
    # Entries that wrap a base press in DecodingPress.
    "decoding_knorm": DecodingPress(base_press=KnormPress()),
    "decoding_streaming_llm": DecodingPress(base_press=StreamingLLMPress()),
    "decoding_tova": DecodingPress(base_press=TOVAPress()),
    "decoding_qfilter": DecodingPress(base_press=QFilterPress()),
    "decoding_adakv_expected_attention_e2": DecodingPress(
        base_press=AdaKVPress(ExpectedAttentionPress(epsilon=1e-2))
    ),
    "decoding_adakv_snapkv": DecodingPress(base_press=AdaKVPress(SnapKVPress())),
    "decoding_keydiff": DecodingPress(base_press=KeyDiffPress()),
}