# SPDX-License-Identifier: Apache-2.0
33
44
5- from dataclasses import dataclass , field
65from contextlib import contextmanager
6+ from dataclasses import dataclass , field
77
88import torch
99from torch .nn import functional as F
10+ from transformers .models .llama .modeling_llama import rotate_half
1011
1112from kvpress .presses .base_press import BasePress
1213from kvpress .presses .snapkv_press import SnapKVPress
13- from transformers .models .llama .modeling_llama import rotate_half
1414
1515
1616@dataclass
@@ -20,7 +20,9 @@ class FinchPress(BasePress):
2020 without chunked prefilling.
2121
2222 Finch starts with SnapKV-style compression, but the window size is not fixed. Instead, the user must provide
23- a second <bos_token> between the context and the window (input = context + tokenizer.bos_token + question)
23+ a delimiter token between the context and the window (input = context + delimiter_token + question).
24+ The delimiter token is set by the user via the update_model_and_tokenizer method.
25+
2426
2527 The following options are also available:
2628 - normalizing scores using the number of non-zero attention weights in the window
@@ -32,6 +34,8 @@ class FinchPress(BasePress):
3234 chunk_length : int = None
3335 normalize_scores : bool = True
3436 rerotate_keys : bool = True
37+ delimiter_token : str = field (default = None , init = False ) # To be set by the update_model_and_tokenizer method
38+ delimiter_token_id : int = field (default = None , init = False ) # To be set by the update_model_and_tokenizer method
3539 window_size : int = field (default = None , init = False )
3640
3741 def score (self , module , hidden_states , keys , values , attentions , kwargs ):
@@ -109,23 +113,40 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
109113
110114 def embed_token_forward_hook (self , module , input , output ):
111115 """
112- Forward hook to detect a second <bos_token> delimiting the context and the window
116+ Forward hook to detect a delimiter token between the context and the window
113117 """
114- if input [0 ][ 0 , 0 ] == self .bos_token_id : # prefilling
118+ if input [0 ]. shape [ 1 ] > 1 and self .delimiter_token_id in input [ 0 ][ 0 ] : # prefilling
115119 assert len (input [0 ]) == 1 , "Only batch size 1 is supported."
116- try :
117- context_length = int ( torch . nonzero ( input [0 ][0 ] == self .bos_token_id )[ 1 ]. item ())
118- self . window_size = len ( input [ 0 ][ 0 ]) - 1 - context_length
119- assert self . window_size > 0 , "No window detected (window size must be > 0)."
120- # Remove the second <bos_token> from the output
121- output = torch . cat ([ output [:, : context_length ], output [:, context_length + 1 :]], dim = 1 )
122- except IndexError :
123- raise IndexError ( "A second <bos_token> must delimit the context and the question." )
120+ # Find the delimiter token and compute the window size
121+ delim_tokens = input [0 ][0 ] == self .delimiter_token_id
122+ assert delim_tokens . sum () == 1 , "Only one delimiter token should be present."
123+ context_length = int ( torch . nonzero ( delim_tokens )[ 0 ]. item ())
124+ self . window_size = len ( input [ 0 ][ 0 ]) - 1 - context_length
125+ assert self . window_size > 0 , "No window detected (window size must be > 0)."
126+ # Remove the delimiter token from the output
127+ output = output [:, ~ delim_tokens ]
124128 return output
125129
130+ def update_model_and_tokenizer (self , model , tokenizer , delimiter_token : str = "<|finch_sep|>" ):
131+ """
132+ Set the delimiter token and update the tokenizer accordingly.
133+ This method should be called before calling the press.
134+ """
135+ self .delimiter_token = delimiter_token
136+ if delimiter_token not in tokenizer .get_vocab ():
137+ tokenizer .add_special_tokens ({"additional_special_tokens" : [delimiter_token ]})
138+ self .delimiter_token_id = tokenizer .convert_tokens_to_ids (delimiter_token ) # type: ignore
139+ # update model embeddings
140+ model .resize_token_embeddings (len (tokenizer ))
141+ return tokenizer
142+
126143 @contextmanager
127144 def __call__ (self , model ):
128- self .bos_token_id = model .generation_config .bos_token_id
145+ # The user should set the delimiter_token_id before calling the press.
146+ if self .delimiter_token_id is None :
147+ raise ValueError ("""No delimiter token ID provided.
148+ Use the update_model_and_tokenizer method before calling the press.""" )
149+
129150 with super ().__call__ (model ):
130151 try :
131152 hook = model .model .embed_tokens .register_forward_hook (self .embed_token_forward_hook )