@@ -170,6 +170,9 @@ def aten_scaled_dot_product_attention_23(
170170 if is_causal :
171171 attn_mask = _causal_attention_mask (query , key , op23 )
172172
173+ if enable_gqa :
174+ key , value = _attention_repeat_kv_for_group_query (query , key , value , op23 )
175+
173176 if attn_mask is None :
174177 return _aten_scaled_dot_product_attention_no_mask_onnx (
175178 query , key , value , scale , dropout_p , op23
@@ -180,6 +183,67 @@ def aten_scaled_dot_product_attention_23(
180183 )
181184
182185
def _attention_repeat_kv_for_group_query(
    query: TFloat, key: TFloat, value: TFloat, op: Opset
) -> tuple[TFloat, TFloat]:
    """Expand key and value for group query attention (GQA / MQA).

    repeat_interleave is applied on key and value along the head dimension so
    that they match the number of heads in query.

    Args:
        query: Tensor of shape [B, q_num_heads, q_S, E].
        key: Tensor of shape [B, kv_num_heads, kv_S, E].
        value: Tensor of shape [B, kv_num_heads, kv_S, Ev].
            key and value must have the same number of heads, and
            q_num_heads must be a multiple of kv_num_heads.
        op: Opset used to emit the ONNX operators.

    Returns:
        Tuple of (expanded_key, expanded_value) where:
            - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
            - expanded_value: Tensor of shape [B, q_num_heads, kv_S, Ev]
    """
    # Trace-time validation: requires the head dimensions to be statically
    # known. NOTE(review): the strict `>` rejects q_num_heads == kv_num_heads
    # (group size 1); callers are expected to gate on enable_gqa — confirm.
    assert (
        query.shape[1] > key.shape[1] == value.shape[1]
        and query.shape[1] % key.shape[1] == 0
    ), (
        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
    )

    # NOTE: QKV are expected to be 4D tensors

    batch_size = op.Shape(query, start=0, end=1)  # [B]
    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]

    # Group size: how many query heads share each kv head.
    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
    # NOTE(review): axes is emitted as a scalar constant; the ONNX Unsqueeze
    # spec describes axes as a 1-D tensor — confirm backends accept a scalar.
    two = op.Constant(value_int=2)
    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hk, 1, T, Dv]

    # Broadcast each kv head across its group along the inserted axis.
    k_expand_shape = op.Concat(
        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
    )
    k_expand = op.Expand(k_unsqueezed, k_expand_shape)  # [B, Hk, Hq/Hk, T, Dk]
    v_expand_shape = op.Concat(
        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
    )
    v_expand = op.Expand(v_unsqueezed, v_expand_shape)  # [B, Hk, Hq/Hk, T, Dv]

    # Collapse (Hk, Hq/Hk) back into a single head dimension of size Hq.
    k_attention_shape = op.Concat(
        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
    )
    v_attention_shape = op.Concat(
        batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0
    )

    expanded_key = op.Reshape(k_expand, k_attention_shape)  # [B, Hq, T, Dk]
    expanded_value = op.Reshape(v_expand, v_attention_shape)  # [B, Hq, T, Dv]

    return expanded_key, expanded_value
245+
246+
183247def _attention_scale (query : TFloat , op : Opset ) -> TFloat :
184248 """Calculate the scale factor for the attention result.
185249
0 commit comments