Skip to content

Commit 97408ee

Browse files
Fix FinchPress for Qwen (#82)
1 parent f7d77d3 commit 97408ee

File tree

4 files changed

+48
-22
lines changed

4 files changed

+48
-22
lines changed

evaluation/evaluate.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
CriticalKVPress,
2828
DuoAttentionPress,
2929
ExpectedAttentionPress,
30+
FinchPress,
3031
KnormPress,
3132
ObservedAttentionPress,
33+
PyramidKVPress,
34+
QFilterPress,
3235
RandomPress,
3336
SnapKVPress,
3437
StreamingLLMPress,
3538
ThinKPress,
3639
TOVAPress,
37-
QFilterPress,
38-
PyramidKVPress,
39-
FinchPress,
4040
)
4141

4242
logger = logging.getLogger(__name__)
@@ -197,7 +197,11 @@ def evaluate(
197197
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
198198

199199
if isinstance(press, FinchPress):
200-
df["context"] = df["context"] + pipe.tokenizer.bos_token
200+
assert compress_questions is True, "FinchPress requires compress_questions to be set to True"
201+
# FinchPress uses a delimiter token to separate context and question
202+
# So we need to update the tokenizer and the model embeddings.
203+
press.update_model_and_tokenizer(pipe.model, pipe.tokenizer)
204+
df["context"] = df["context"] + press.delimiter_token
201205

202206
if compress_questions:
203207
df["context"] = df["context"] + df["question"]

kvpress/presses/finch_press.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44

5-
from dataclasses import dataclass, field
65
from contextlib import contextmanager
6+
from dataclasses import dataclass, field
77

88
import torch
99
from torch.nn import functional as F
10+
from transformers.models.llama.modeling_llama import rotate_half
1011

1112
from kvpress.presses.base_press import BasePress
1213
from kvpress.presses.snapkv_press import SnapKVPress
13-
from transformers.models.llama.modeling_llama import rotate_half
1414

1515

1616
@dataclass
@@ -20,7 +20,9 @@ class FinchPress(BasePress):
2020
without chunked prefilling.
2121
2222
Finch starts with SnapKV-style compression, but the window size is not fixed. Instead, the user must provide
23-
a second <bos_token> between the context and the window (input = context + tokenizer.bos_token + question)
23+
a delimiter token between the context and the window (input = context + delimiter_token + question).
24+
The delimiter token is set by the user via the update_model_and_tokenizer method.
25+
2426
2527
The options are also available
2628
- normalizing scores using the number of non-zero attention weights in the window
@@ -32,6 +34,8 @@ class FinchPress(BasePress):
3234
chunk_length: int = None
3335
normalize_scores: bool = True
3436
rerotate_keys: bool = True
37+
delimiter_token: str = field(default=None, init=False) # To be set by the update_model_and_tokenizer method
38+
delimiter_token_id: int = field(default=None, init=False) # To be set by the update_model_and_tokenizer method
3539
window_size: int = field(default=None, init=False)
3640

3741
def score(self, module, hidden_states, keys, values, attentions, kwargs):
@@ -109,23 +113,40 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
109113

110114
def embed_token_forward_hook(self, module, input, output):
111115
"""
112-
Forward hook to detect a second <bos_token> delimiting the context and the window
116+
Forward hook to detect a delimiter token between the context and the window
113117
"""
114-
if input[0][0, 0] == self.bos_token_id: # prefilling
118+
if input[0].shape[1] > 1 and self.delimiter_token_id in input[0][0]: # prefilling
115119
assert len(input[0]) == 1, "Only batch size 1 is supported."
116-
try:
117-
context_length = int(torch.nonzero(input[0][0] == self.bos_token_id)[1].item())
118-
self.window_size = len(input[0][0]) - 1 - context_length
119-
assert self.window_size > 0, "No window detected (window size must be > 0)."
120-
# Remove the second <bos_token> from the output
121-
output = torch.cat([output[:, :context_length], output[:, context_length + 1 :]], dim=1)
122-
except IndexError:
123-
raise IndexError("A second <bos_token> must delimit the context and the question.")
120+
# Find the delimiter token and compute the window size
121+
delim_tokens = input[0][0] == self.delimiter_token_id
122+
assert delim_tokens.sum() == 1, "Only one delimiter token should be present."
123+
context_length = int(torch.nonzero(delim_tokens)[0].item())
124+
self.window_size = len(input[0][0]) - 1 - context_length
125+
assert self.window_size > 0, "No window detected (window size must be > 0)."
126+
# Remove the delimiter token from the output
127+
output = output[:, ~delim_tokens]
124128
return output
125129

130+
def update_model_and_tokenizer(self, model, tokenizer, delimiter_token : str = "<|finch_sep|>"):
131+
"""
132+
Set the delimiter token and update the tokenizer accordingly.
133+
This method should be called before calling the press.
134+
"""
135+
self.delimiter_token = delimiter_token
136+
if delimiter_token not in tokenizer.get_vocab():
137+
tokenizer.add_special_tokens({"additional_special_tokens": [delimiter_token]})
138+
self.delimiter_token_id = tokenizer.convert_tokens_to_ids(delimiter_token) # type: ignore
139+
# update model embeddings
140+
model.resize_token_embeddings(len(tokenizer))
141+
return tokenizer
142+
126143
@contextmanager
127144
def __call__(self, model):
128-
self.bos_token_id = model.generation_config.bos_token_id
145+
# The user should set the delimiter_token_id before calling the press.
146+
if self.delimiter_token_id is None:
147+
raise ValueError("""No delimiter token ID provided.
148+
Use the update_model_and_tokenizer method before calling the press.""")
149+
129150
with super().__call__(model):
130151
try:
131152
hook = model.model.embed_tokens.register_forward_hook(self.embed_token_forward_hook)

tests/presses/test_finch_press.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import torch
5+
56
from kvpress import FinchPress
67
from tests.fixtures import unit_test_model # noqa: F401
78

@@ -13,9 +14,8 @@ def test_finch_press(unit_test_model): # noqa: F811
1314
FinchPress(0.5, normalize_scores=False),
1415
FinchPress(0.2, chunk_length=5),
1516
]:
17+
press.delimiter_token_id = unit_test_model.config.eos_token_id
1618
with press(unit_test_model):
17-
bos = unit_test_model.generation_config.bos_token_id
1819
input_ids = torch.arange(10, 20)
19-
input_ids[0] = bos
20-
input_ids[8] = bos
20+
input_ids[8] = press.delimiter_token_id
2121
unit_test_model(input_ids.unsqueeze(0))

tests/presses/test_pyramidkv_press.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import pytest
5-
from kvpress.presses.pyramidkv_press import PyramidKVPress
65
import torch.nn as nn
76

7+
from kvpress.presses.pyramidkv_press import PyramidKVPress
8+
89

910
class MockConfig:
1011
def __init__(self, num_hidden_layers):

0 commit comments

Comments (0)