Commits
45 commits
73dcb45
first commit
FaureElia Mar 12, 2025
81b53f9
created setup
FaureElia Mar 12, 2025
8f87626
first basic implementation of FinchPress
FaureElia Mar 13, 2025
1db9786
adjusted FinchPress
FaureElia Mar 14, 2025
98bc55a
introduce new parameter condition_len (still have to modify compute_f…
FaureElia Mar 15, 2025
008e7e9
Adapted compute_finch_attention to condition_len. TODO: Normalization…
miriam-16 Mar 17, 2025
5f2d089
fix question len, adjusted the question ids for generation.
giulio98 Mar 15, 2025
c512283
fix: parameter condition_len in function
miriam-16 Mar 17, 2025
87ba07e
add normalization
FaureElia Mar 17, 2025
fa0e28d
clean FinchPress
FaureElia Mar 18, 2025
c7dff0e
normalization
giulio98 Mar 18, 2025
8970c97
fix len question ids
giulio98 Mar 18, 2025
84ce456
add self
giulio98 Mar 18, 2025
f0efe61
chunked forward -- not working(gibberish output), check shapes and le…
giulio98 Mar 18, 2025
15db060
Fixed attention mask and binary mask in compute_finch_attention
FaureElia Mar 22, 2025
26c64b0
updated gitignore
FaureElia Mar 22, 2025
a6ebc11
n_kept value at last iteration (TODO: further checks)
miriam-16 Mar 22, 2025
e2f3dd6
fixed pipeline for other presses and finished adjusting number of tok…
FaureElia Mar 23, 2025
54c991b
Fix: n_kept at last iteration considers condition_len
miriam-16 Mar 23, 2025
91e9f05
adjusted n_kept, at final chunk keeps 90% of context+full question
FaureElia Mar 23, 2025
c48ea98
added docstring and general code cleaning
miriam-16 Mar 24, 2025
7ec514a
reorder indices, add normalization parameter, add sink tokens
giulio98 Mar 25, 2025
657a27d
partially fixed finch press
FaureElia Mar 26, 2025
d0bb5d0
fixed finch press, aligned to original code
FaureElia Mar 28, 2025
fdde74d
preparing for pr
giulio98 Apr 4, 2025
c1acb0d
add ChunkKV
Dominic789654 Mar 5, 2025
8533beb
Update copyright date (#60)
SimJeg Mar 13, 2025
e71f4e9
Add QFilterPress (#54)
NathanGodey Mar 17, 2025
4f917cc
Add longbench benchmark
Xnhyacinth Mar 19, 2025
8d95b14
Add DuoAttention on the fly (#63)
SimJeg Mar 19, 2025
cace3a5
Resolve stash conflict in tests/default_presses.py
giulio98 Apr 4, 2025
a833ae3
Fix: remove conflict markers after saving changes
giulio98 Apr 4, 2025
39417ad
test passes
giulio98 Apr 4, 2025
2cdd69c
add readme
giulio98 Apr 4, 2025
0e4d4c8
commit
giulio98 Apr 4, 2025
ea42501
Merge branch 'main' into features/finch_press
giulio98 Apr 4, 2025
c570844
make style
giulio98 Apr 4, 2025
9afd7f8
remove leftovers
giulio98 Apr 4, 2025
e8d84fa
make style check pass
giulio98 Apr 6, 2025
689c9c5
fix: readme
giulio98 Apr 6, 2025
f628ca4
unmodify qfilter_press.py duo_attention_press.py and longbech/calcula…
giulio98 Apr 14, 2025
c5a9ca7
typo duo attention press import
giulio98 Apr 14, 2025
f8034e5
use SnapKV window attention for finch
giulio98 Apr 14, 2025
f830665
update past_key_values using last_output from original forward
giulio98 Apr 14, 2025
3f2ed03
add import snapkv
giulio98 Apr 15, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -67,6 +67,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `TOVAPress` ([source](kvpress/presses/tova_press.py), [paper](https://arxiv.org/abs/2401.06104)): attention weight of the last query averaged across heads
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during the pre-filling phase
- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.
- `FinchPress` ([source](kvpress/presses/finch_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): average attention weight that context tokens receive from the question (prompt-conditioned compression).

Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
@@ -82,7 +83,6 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.


For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

## Evaluation
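The FinchPress entry above scores context tokens by the attention they receive from the question. A toy pure-Python sketch of that scoring idea (simplified stand-in for the actual tensor implementation; shapes, names, and the top-k selection are illustrative assumptions, not the library's API):

```python
# Toy sketch of prompt-conditioned scoring as described for FinchPress:
# each context key is scored by the softmax attention weight it receives
# from the question ("condition") queries, averaged over those queries,
# and only the top-scoring keys are kept in the compressed cache.
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def finch_scores(context_keys, question_queries):
    """Average attention weight each context key receives from the question queries."""
    scores = [0.0] * len(context_keys)
    for q in question_queries:
        logits = [sum(qi * ki for qi, ki in zip(q, k)) for k in context_keys]
        for i, w in enumerate(softmax(logits)):
            scores[i] += w
    return [s / len(question_queries) for s in scores]


keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # toy context keys
queries = [[2.0, 0.0], [1.0, 1.0]]               # toy question queries
scores = finch_scores(keys, queries)
keep = sorted(range(len(keys)), key=lambda i: -scores[i])[:2]  # keep top-2 keys
```

Because each softmax row sums to 1, the averaged scores also sum to 1, so `keep` simply selects the context positions the question attends to most.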
4 changes: 3 additions & 1 deletion evaluation/evaluate.py
@@ -27,14 +27,15 @@
CriticalKVPress,
DuoAttentionPress,
ExpectedAttentionPress,
FinchPress,
KnormPress,
ObservedAttentionPress,
QFilterPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
QFilterPress,
)

logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@
"think": ThinKPress(),
"tova": TOVAPress(),
"duo_attention": DuoAttentionPress(),
"finch": FinchPress(),
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
"qfilter": QFilterPress(),
1 change: 1 addition & 0 deletions evaluation/longbench/calculate_metrics.py
@@ -4,6 +4,7 @@
import re
import string
from collections import Counter

import numpy as np
from rouge import Rouge

4 changes: 3 additions & 1 deletion kvpress/__init__.py
@@ -12,18 +12,19 @@
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
@@ -49,6 +50,7 @@
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"FinchPress",
"ChunkKVPress",
"QFilterPress",
]
13 changes: 12 additions & 1 deletion kvpress/pipeline.py
@@ -12,6 +12,7 @@
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
@@ -161,6 +162,14 @@ def _forward(
context_ids = input_tensors["context_ids"].to(self.model.device)
context_length = context_ids.shape[1]

if isinstance(press, FinchPress) or isinstance(getattr(press, "press", None), FinchPress):
# finch press cannot be done with multiple questions
assert len(input_tensors["questions_ids"]) == 1, "Finch press cannot be done with multiple questions"
question_ids = input_tensors["questions_ids"][0].to(self.model.device)
context_ids = torch.cat((context_ids, question_ids[:, :-1]), dim=1)
press.condition_len = len(question_ids[:, :-1][0])
input_tensors["questions_ids"][0] = question_ids[:, -1:]

# Prefilling using the press on the context
if cache is None:
cache = DynamicCache()
@@ -182,7 +191,9 @@
answer = self.generate_answer(
question_ids=question_ids.to(self.model.device),
cache=cache,
context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length),
context_length=(
cache.get_seq_length() if isinstance(press, (KeyRerotationPress, FinchPress)) else context_length
),
max_new_tokens=max_new_tokens,
)
answers.append(answer)
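The pipeline.py hunk above appends the question (minus its final token) to the context so FinchPress can score context tokens against it, records that length as `condition_len`, and leaves only the last question token for generation. A minimal sketch of that split using plain Python lists in place of token-id tensors (the helper name `finch_split` is illustrative, not part of the library):

```python
# Sketch of the question-conditioning split done in the pipeline for FinchPress:
# context_ids <- context + question[:-1], press.condition_len <- len(question) - 1,
# and only the final question token is fed to generation afterwards.

def finch_split(context_ids, question_ids):
    """Return (conditioned_context, condition_len, generation_ids)."""
    condition = question_ids[:-1]           # all question tokens except the last
    conditioned_context = context_ids + condition
    condition_len = len(condition)          # stored as press.condition_len
    generation_ids = question_ids[-1:]      # generation starts from the last token
    return conditioned_context, condition_len, generation_ids


ctx, cond_len, gen = finch_split([101, 102, 103, 104], [201, 202, 203])
print(ctx)       # [101, 102, 103, 104, 201, 202]
print(cond_len)  # 2
print(gen)       # [203]
```

This also explains why `context_length` is taken from `cache.get_seq_length()` for FinchPress: after compression the cache no longer matches the original context length.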
2 changes: 1 addition & 1 deletion kvpress/presses/duo_attention_press.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from cachetools import cached, LRUCache # type: ignore[import-untyped]
from contextlib import contextmanager
from dataclasses import dataclass, field
from io import StringIO

import numpy as np
import requests # type: ignore[import-untyped]
import torch
from cachetools import LRUCache, cached # type: ignore[import-untyped]
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb