11# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4+ from cachetools import cached , LRUCache # type: ignore[import-untyped]
45from contextlib import contextmanager
56from dataclasses import dataclass , field
67from io import StringIO
78
89import numpy as np
910import requests # type: ignore[import-untyped]
1011import torch
12+ from datasets import load_dataset
13+ from transformers import AutoTokenizer
14+ from transformers .models .llama .modeling_llama import apply_rotary_pos_emb
1115
1216from kvpress .presses .base_press import BasePress
1317
2024 "mistralai/Mistral-7B-Instruct-v0.3" : "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10" , # noqa: E501
2125}
2226
27+ cache = LRUCache (maxsize = 128 )
28+
2329
2430@dataclass
2531class DuoAttentionPress (BasePress ):
@@ -29,12 +35,15 @@ class DuoAttentionPress(BasePress):
2935 Splits attention heads into two types:
3036 - Retrieval heads: use the full KV cache
3137 - Streaming heads: use only sink and recent tokens.
32-
33- Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/
3438 The higher the head_compression_ratio, the more streaming heads are used.
39+
40+ Head classification is based on scores.
41+ - If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/
42+ - (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly
3543 """
3644
3745 head_compression_ratio : float = 0.0
46+ on_the_fly_scoring : bool = False
3847 compression_ratio_ : float = field (init = False , default = None )
3948 recent_size : int = field (init = False , default = None )
4049 sink_size : int = field (init = False , default = None )
@@ -45,7 +54,10 @@ def __post_init_from_model__(self, model):
4554 Initialize sink_size, recent_size, and streaming_mask from a model
4655 """
4756 # Load attention pattern from the DuoAttention repo
48- self .sink_size , self .recent_size , head_scores = self .load_attention_pattern (model )
57+ if self .on_the_fly_scoring :
58+ self .sink_size , self .recent_size , head_scores = 128 , 256 , duo_attention_on_the_fly (model )
59+ else :
60+ self .sink_size , self .recent_size , head_scores = self .load_attention_pattern (model )
4961
5062 # Define retrieval and streaming heads through a binary mask
5163 n_pruned = round (head_scores .size * self .head_compression_ratio )
@@ -82,6 +94,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
8294 return keys , values
8395
8496 @staticmethod
97+ @cached (cache , key = lambda model : model .config .name_or_path )
8598 def load_attention_pattern (model ):
8699 """
87100 Load the attention pattern from the DuoAttention repo
@@ -108,3 +121,68 @@ def __call__(self, model):
108121 self .__post_init_from_model__ (model )
109122 with super ().__call__ (model ):
110123 yield
124+
125+
@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
    """
    New experimental method to quickly compute DuoAttention scores:
    - Compute the mean query and key on num_samples random samples from BookSum
    - Repeat the mean query and key q_len times and apply RoPE to get (Q, K)
    - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve"
    These scores could also be saved to avoid recomputing them but this method is still experimental

    Parameters
    ----------
    model:
        Decoder-only transformer with rotary embeddings (assumes the Llama-style layout
        ``model.model.layers[i].self_attn`` and ``model.model.rotary_emb`` — TODO confirm for other families).
    num_samples : int
        Number of BookSum chapters used to estimate the mean query / key (default 50).
    q_len : int
        Synthetic sequence length used to probe how fast attention decays with distance (default 500).

    Returns
    -------
    numpy.ndarray
        Scores of shape (num_hidden_layers, num_key_value_heads), averaged over samples.
    """

    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    num_heads = model.config.num_attention_heads
    num_key_value_heads = model.config.num_key_value_heads
    num_key_value_groups = num_heads // num_key_value_heads

    # Truncate inputs to the model's context window: BookSum chapters frequently exceed it,
    # and an over-length forward pass crashes (RoPE / position-embedding overflow).
    max_length = getattr(model.config, "max_position_embeddings", None)

    # Load data
    dataset = load_dataset("kmfoda/booksum", split="train").to_pandas()
    texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist()

    # Initialize variables
    position_ids = torch.arange(q_len).unsqueeze(0)
    scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads))

    # Compute scores
    for text in texts:
        with torch.no_grad():
            # Hidden states that are the *inputs* of each decoder layer (drop the final output state)
            inputs = tokenizer(
                text, return_tensors="pt", truncation=max_length is not None, max_length=max_length
            ).to(model.device)
            hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1])

            for layer_idx, h in enumerate(hidden_states):
                module = model.model.layers[layer_idx]
                d = module.self_attn.head_dim
                h = module.input_layernorm(h)

                # Mean query over the sequence, repeated q_len times: (1, num_heads, q_len, d)
                q = module.self_attn.q_proj(h)
                q = q.view(1, q.shape[1], -1, d)
                q = q.mean(dim=1, keepdim=True)
                q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Mean key over the sequence, repeated q_len times: (1, num_key_value_heads, q_len, d)
                k = module.self_attn.k_proj(h)
                k = k.view(1, k.shape[1], -1, d)
                k = k.mean(dim=1, keepdim=True)
                k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Apply RoPE so positions differ even though the underlying vectors are identical
                cos, sin = model.model.rotary_emb(h, position_ids.to(h.device))
                q, k = apply_rotary_pos_emb(q, k, cos, sin)
                k = k.repeat_interleave(num_key_value_groups, dim=1)

                # Attention weights of the last query position over all q_len key positions
                attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5)
                attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze()

                # Score: area under the cumulated attention curve, then average query heads
                # belonging to the same KV group -> one score per key/value head
                s = torch.cumsum(attn_weights, dim=1).mean(1)
                s = s.view(-1, num_key_value_groups).mean(1)

                # Running average over the num_samples texts
                scores[layer_idx] += s.cpu() / num_samples
    return scores.numpy()