11# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4+ from time import time
45from contextlib import contextmanager
56from dataclasses import dataclass , field
67from io import StringIO
78
9+
810import numpy as np
911import requests # type: ignore[import-untyped]
1012import torch
13+ from datasets import load_dataset
14+ from transformers import AutoTokenizer
15+ from transformers .models .llama .modeling_llama import apply_rotary_pos_emb
1116
1217from kvpress .presses .base_press import BasePress
1318
@@ -29,12 +34,15 @@ class DuoAttentionPress(BasePress):
    Splits attention heads into two types:
        - Retrieval heads: use the full KV cache
        - Streaming heads: use only sink and recent tokens.
    The higher the head_compression_ratio, the more streaming heads are used.

    Head classification is based on scores.
        - If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/
        - (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly
    """
3643
    # Fraction of heads converted to streaming heads (0.0 = all heads keep the full KV cache)
    head_compression_ratio: float = 0.0
    # (experimental) compute head scores on the fly instead of loading pre-computed ones
    on_the_fly_scoring: bool = False
    # The attributes below are derived from the model in __post_init_from_model__,
    # not supplied by callers (init=False); None until that initialization runs
    compression_ratio_: float = field(init=False, default=None)
    recent_size: int = field(init=False, default=None)
    sink_size: int = field(init=False, default=None)
    def __post_init_from_model__(self, model):
        """
        Initialize sink_size, recent_size, and streaming_mask from a model
        """
        # Re-run only when the model changed since the last call: the result is
        # cached per model name through the _post_init_model_name attribute
        if getattr(self, "_post_init_model_name", None) != model.config.name_or_path:
            # Load attention pattern from the DuoAttention repo
            # (sink_size and recent_size always come from the repo pattern)
            self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)
            if self.on_the_fly_scoring:
                # Experimental: replace only the pre-computed head scores
                head_scores = duo_attention_on_the_fly(model)

            # Define retrieval and streaming heads through a binary mask:
            # the n_pruned heads with the LOWEST scores become streaming heads
            n_pruned = round(head_scores.size * self.head_compression_ratio)
            self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device)
            if n_pruned > 0:
                # argsort over the flattened array, then map flat indices back to (layer, head)
                indices = np.argsort(head_scores, axis=None)[:n_pruned]
                self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True
            self._post_init_model_name = model.config.name_or_path
5668
5769 @property
5870 def compression_ratio (self ) -> float :
    def __call__(self, model):
        # Ensure sizes and streaming mask are initialized for the current model
        # before the parent press hooks into it
        self.__post_init_from_model__(model)
        with super().__call__(model):
            # NOTE(review): yield inside `with` indicates a generator-based context
            # manager; the @contextmanager decorator is not visible in this chunk — confirm upstream
            yield
123+
124+
def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
    """
    New experimental method to quickly compute DuoAttention scores:
    - Compute the mean query and key on num_samples random samples from BookSum
    - Repeat the mean query and key q_len times and apply RoPE to get (Q, K)
    - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve"
    These scores could also be saved to avoid recomputing them but this method is still experimental

    Parameters
    ----------
    model:
        Causal LM with Llama-style attention (q_proj / k_proj per layer and a
        model-level rotary_emb) — confirm before using with other architectures
    num_samples : int
        Number of BookSum chapters to average over (clamped to the dataset size)
    q_len : int
        Length of the synthetic repeated query/key sequence used for scoring

    Returns
    -------
    numpy array of shape (num_hidden_layers, num_key_value_heads): one score per KV head
    """

    start = time()
    print(f"Starting computation of DuoAttention scores based on {num_samples} samples")
    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    num_heads = model.config.num_attention_heads
    num_key_value_heads = model.config.num_key_value_heads
    num_key_value_groups = num_heads // num_key_value_heads

    # Load data; clamp num_samples so pandas .sample never raises on a small dataset
    dataset = load_dataset("kmfoda/booksum", split="train").to_pandas()
    num_samples = min(num_samples, len(dataset))
    texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist()

    # Initialize variables; position_ids is moved to the model device once,
    # instead of once per layer per sample
    position_ids = torch.arange(q_len).unsqueeze(0).to(model.device)
    scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads))

    # Compute scores; the whole per-text computation runs under no_grad so the
    # per-layer projections below do not build autograd graphs
    for text in texts:
        with torch.no_grad():
            # Compute hidden states (drop the last one: it feeds the LM head, not an attention layer)
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1])

            for layer_idx, h in enumerate(hidden_states):
                module = model.model.layers[layer_idx]
                d = module.self_attn.head_dim
                h = module.input_layernorm(h)

                # Mean query, repeated q_len times -> (1, n_heads, q_len, d)
                q = module.self_attn.q_proj(h)
                q = q.view(1, q.shape[1], -1, d)
                q = q.mean(dim=1, keepdim=True)
                q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Mean key, repeated q_len times -> (1, n_kv_heads, q_len, d)
                k = module.self_attn.k_proj(h)
                k = k.view(1, k.shape[1], -1, d)
                k = k.mean(dim=1, keepdim=True)
                k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Apply RoPE, then expand keys to match the number of query heads (GQA)
                cos, sin = model.model.rotary_emb(h, position_ids)
                q, k = apply_rotary_pos_emb(q, k, cos, sin)
                k = k.repeat_interleave(num_key_value_groups, dim=1)

                # Compute attention weights for the last token.
                # Index explicitly instead of squeeze(): squeeze() would also drop
                # the head dimension when num_heads == 1 and break cumsum(dim=1)
                attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5)
                attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32)[0, :, 0, :]

                # Compute score: area under the cumulated attention curve,
                # then average query heads within each KV group
                s = torch.cumsum(attn_weights, dim=1).mean(1).cpu()
                s = s.view(-1, num_key_value_groups).mean(1)

                # Store the scores (num_samples == len(texts) after clamping)
                scores[layer_idx] += s / num_samples
    print(f"Finished computation of DuoAttention scores in {time() - start:.2f}s")
    return scores.numpy()
0 commit comments