@@ -30,6 +30,85 @@ class KeyRerotationPress(BasePress):
def __post_init__(self):
    """
    Check that the wrapped press is a ``ScorerPress``.

    Raises
    ------
    AssertionError
        If ``self.press`` is not an instance of ``ScorerPress``.
    """
    # NOTE(review): presumably called automatically after dataclass
    # initialisation -- the class decorator is not visible here; confirm.
    wrapped_press = self.press
    assert isinstance(wrapped_press, ScorerPress)
3232
33+ @staticmethod
34+ def _rerotate_cos_sin (x , inv_freq , selected_positions ):
35+ """
36+ Compute cosine and sine rotary positional embeddings required to
37+ re-rotate pruned keys back into the canonical RoPE space.
38+
39+ Parameters
40+ ----------
41+ x : torch.Tensor
42+ Any key-like tensor that provides ``dtype`` and ``device``.
43+ Shape ``(bsz, num_key_value_heads, q_len, d)``.
44+ inv_freq : torch.Tensor
45+ ``module.rotary_emb.inv_freq``. Shape ``(d//2,)``.
46+ selected_positions : torch.Tensor
47+ Indices of the *kept* tokens.
48+ Shape ``(bsz, num_key_value_heads, n_kept)``.
49+
50+ Returns
51+ -------
52+ cos, sin : torch.Tensor
53+ Cosine and sine embeddings, each of shape
54+ ``(bsz, num_key_value_heads, n_kept, d)``, matching ``dtype``/``device`` of ``x``.
55+ """
56+ bsz , num_key_value_heads , n_kept = selected_positions .shape
57+ device = selected_positions .device
58+ device_type = x .device .type
59+ dtype = x .dtype
60+ # Original positional indices
61+ idx = torch .arange (0 , n_kept , device = device ) # (n_kept,)
62+ idx = idx .unsqueeze (0 ) # (1, n_kept)
63+ inv_freq = inv_freq [None , None , :, None ].float ().expand (bsz , num_key_value_heads , - 1 , 1 )
64+ idx = idx [:, None , :].float ().expand (bsz , num_key_value_heads , n_kept )
65+ # Compute delta between original and selected positions
66+ delta_pos = idx - selected_positions # (bsz, num_key_value_heads, n_kept)
67+ delta_pos = delta_pos .unsqueeze (2 ) # (bsz, num_key_value_heads, 1, n_kept)
68+
69+ device_type = device_type if isinstance (device_type , str ) and device_type != "mps" else "cpu"
70+
71+ with torch .autocast (device_type = device_type , enabled = False ):
72+ # Compute the new freq by scaling inv_freq by delta
73+ freqs = delta_pos .float () * inv_freq .float () # (bsz, num_key_value_heads, d//2, n_kept)
74+ freqs = freqs .transpose (2 , 3 ) # (bsz, num_key_value_heads, n_kept, d//2)
75+ emb = torch .cat ((freqs , freqs ), dim = - 1 )
76+ # Compute cosine and sine required to re-rotate keys to selected positions
77+ cos = emb .cos ().contiguous ()
78+ sin = emb .sin ().contiguous ()
79+ return cos .to (dtype = dtype ), sin .to (dtype = dtype )
80+
@staticmethod
def rerotate_keys(
    module: nn.Module,
    indices: torch.Tensor,
    keys: torch.Tensor,
) -> torch.Tensor:
    """
    Select the kept keys and re-rotate them to contiguous RoPE positions.

    Parameters
    ----------
    module : nn.Module
        Module providing ``rotary_emb.inv_freq`` and ``head_dim``.
    indices : torch.Tensor
        Indices of the kept tokens after pruning, as accepted by
        ``_rerotate_cos_sin`` (``(bsz, num_heads, n_kept)``).
    keys : torch.Tensor
        The keys tensor to be pruned and rerotated.

    Returns
    -------
    torch.Tensor
        The rerotated keys tensor of shape
        ``(bsz, num_heads, n_kept, d)``.
    """
    inv_freq = module.rotary_emb.inv_freq
    cos_table, sin_table = KeyRerotationPress._rerotate_cos_sin(keys, inv_freq, indices)

    # Broadcast the token indices over the head dimension so that whole
    # key vectors are gathered along the sequence axis (dim 2).
    gather_index = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
    kept_keys = keys.gather(2, gather_index).contiguous()

    # RoPE rotation by the position shift encoded in cos_table/sin_table.
    cos_term = kept_keys * cos_table
    sin_term = rotate_half(kept_keys) * sin_table
    return cos_term + sin_term
111+
33112 def compress (
34113 self ,
35114 module : nn .Module ,
@@ -50,22 +129,7 @@ def compress(
50129 n_kept = int (q_len * (1 - self .press .compression_ratio ))
51130 indices = scores .topk (n_kept , dim = - 1 ).indices
52131 indices = torch .sort (indices , dim = 2 ).values
132+ keys = self .rerotate_keys (module , indices , keys )
53133 indices = indices .unsqueeze (- 1 ).expand (- 1 , - 1 , - 1 , module .head_dim )
54-
55- cos , sin = kwargs ["position_embeddings" ]
56- # Rerotate as follows
57- # 1. keys = RoPE(W_k * hidden_states)
58- # 2. keys_unrotated = RoPE^-1(keys)
59- # 3. keys_pruned = prune(keys_unrotated)
60- # 4. keys = RoPE(keys_pruned)
61-
62- # 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below
63- keys = (keys * cos .unsqueeze (1 )) + (rotate_half (keys ) * (- sin .unsqueeze (1 )))
64- # 3. Prune keys
65- keys = keys .gather (2 , indices ).contiguous ()
66- # 4. Apply RoPE
67- cos , sin = cos [:, :n_kept ], sin [:, :n_kept ]
68- keys = (keys * cos .unsqueeze (1 )) + (rotate_half (keys ) * sin .unsqueeze (1 ))
69-
70134 values = values .gather (2 , indices ).contiguous ()
71135 return keys , values
0 commit comments