@@ -387,73 +387,64 @@ def pattern(
387387 past_seq_len_0D = op .Squeeze (past_seq_len , _outputs = ["past_seq_len_0D" ])
388388 total_seq_len_0D = op .Add (past_seq_len_0D , seq_len_0D , _outputs = ["total_seq_len_0D" ])
389389
390- # All of the Add node's outputs
391- current_range_A = op .Range (past_seq_len_0D , total_seq_len_0D , 1 , _outputs = ["current_range_A " ])
392- total_seq_len_A = op .Reshape (total_seq_len_0D , [- 1 ], allowzero = 0 , _outputs = ["total_seq_len_A " ])
393- current_range_B = op .Range (0 , total_seq_len_0D , 1 , _outputs = ["current_range_B " ])
394- total_seq_len_B = op .Reshape (total_seq_len_0D , [- 1 ], allowzero = 0 , _outputs = ["total_seq_len_B " ])
395- total_seq_len_C = op .Reshape (total_seq_len_0D , [- 1 ], allowzero = 0 , _outputs = ["total_seq_len_C " ])
390+ # Position ranges, plus one 1-D Reshape of total_seq_len_0D per branch (each feeds that branch's mask-shape Concat)
391+ kv_range = op .Range (past_seq_len_0D , total_seq_len_0D , 1 , _outputs = ["kv_range " ])
392+ total_seq_len_for_kv = op .Reshape (total_seq_len_0D , [- 1 ], allowzero = 0 , _outputs = ["total_seq_len_for_kv " ])
393+ query_range = op .Range (0 , total_seq_len_0D , 1 , _outputs = ["query_range " ])
394+ total_seq_len_for_query = op .Reshape (total_seq_len_0D , [- 1 ], allowzero = 0 , _outputs = ["total_seq_len_for_query " ])
395+ total_seq_len_for_batch = op .Reshape (total_seq_len_0D , [- 1 ], allowzero = 0 , _outputs = ["total_seq_len_for_batch " ])
396396
397- total_seq_len_final = op .Reshape (total_seq_len_0D , pattern .ANY_VALUE , allowzero = 0 , _outputs = ["total_seq_len_final" ])
397+ # total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"])
398398
399- # EXPAND BRANCH A
399+ # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1]
400400 batch_size = op .Shape (past_kv_cache_2 , end = 1 , start = 0 , _outputs = ["batch_size" ])
401- mask_shape_A = op .Concat (batch_size , [1 ], seq_len , total_seq_len_A , axis = 0 , _outputs = ["mask_shape_A " ])
402- mask_shape_A_abs = op .Abs (mask_shape_A , _outputs = ["mask_shape_A_abs " ])
403- reshaped_range_A = op .Reshape (current_range_A , [1 , 1 , - 1 , 1 ], allowzero = 1 , _outputs = ["reshaped_range_A " ])
404- mask_expanded_A = op .Expand (reshaped_range_A , mask_shape_A_abs , _outputs = ["mask_expanded_A " ])
405-
406- # EXPAND BRANCH B
407- mask_shape_B = op .Concat (batch_size , [1 ], seq_len , total_seq_len_B , axis = 0 , _outputs = ["mask_shape_B " ])
408- mask_shape_B_abs = op .Abs (mask_shape_B , _outputs = ["mask_shape_B_abs " ])
409- reshaped_range_B = op .Reshape (current_range_B , [1 , 1 , 1 , - 1 ], allowzero = 1 , _outputs = ["reshaped_range_B " ])
410- mask_expanded_B = op .Expand (reshaped_range_B , mask_shape_B_abs , _outputs = ["mask_expanded_B " ])
411-
412- # EXPAND BRANCH C
413- mask_shape_C = op .Concat (batch_size , [1 ], seq_len , total_seq_len_C , axis = 0 , _outputs = ["mask_shape_C " ])
414- mask_shape_C_abs = op .Abs (mask_shape_C , _outputs = ["mask_shape_C_abs " ])
401+ kv_mask_shape = op .Concat (batch_size , [1 ], seq_len , total_seq_len_for_kv , axis = 0 , _outputs = ["kv_mask_shape " ])
402+ kv_mask_shape_abs = op .Abs (kv_mask_shape , _outputs = ["kv_mask_shape_abs " ])
403+ reshaped_kv_range = op .Reshape (kv_range , [1 , 1 , - 1 , 1 ], allowzero = 1 , _outputs = ["reshaped_kv_range " ])
404+ expanded_kv_range = op .Expand (reshaped_kv_range , kv_mask_shape_abs , _outputs = ["expanded_kv_range " ])
405+
406+ # BRANCH B: Query Range - Creates tensor with query positions [1, 1, 1, total_seq_len]
407+ query_mask_shape = op .Concat (batch_size , [1 ], seq_len , total_seq_len_for_query , axis = 0 , _outputs = ["query_mask_shape " ])
408+ query_mask_shape_abs = op .Abs (query_mask_shape , _outputs = ["query_mask_shape_abs " ])
409+ reshaped_query_range = op .Reshape (query_range , [1 , 1 , 1 , - 1 ], allowzero = 1 , _outputs = ["reshaped_query_range " ])
410+ expanded_query_range = op .Expand (reshaped_query_range , query_mask_shape_abs , _outputs = ["expanded_query_range " ])
411+
412+ # BRANCH C: Batch Range - Creates tensor with batch indices [batch_size, 1, 1, 1]
413+ batch_mask_shape = op .Concat (batch_size , [1 ], seq_len , total_seq_len_for_batch , axis = 0 , _outputs = ["batch_mask_shape " ])
414+ batch_mask_shape_abs = op .Abs (batch_mask_shape , _outputs = ["batch_mask_shape_abs " ])
415415 batch_size_squeezed = op .Squeeze (batch_size , _outputs = ["batch_size_squeezed" ])
416416 batch_range = op .Range (0 , batch_size_squeezed , 1 , _outputs = ["batch_range" ])
417- reshaped_range_C = op .Reshape (batch_range , [- 1 , 1 , 1 , 1 ], allowzero = 1 , _outputs = ["reshaped_range_C " ])
418- mask_expanded_C = op .Expand (reshaped_range_C , mask_shape_C_abs , _outputs = ["mask_expanded_C " ])
419-
420- # EXPAND A/B TO AND
421- mask_expanded_A_sub = op .Sub (mask_expanded_A , 262144 , _outputs = ["mask_expanded_A_sub " ])
422- mask_A_B_greater = op .Greater (mask_expanded_B , mask_expanded_A_sub , _outputs = ["mask_A_B_greater " ])
423- mask_A_B_greater_bitwise = op .And (True , mask_A_B_greater , _outputs = ["mask_A_B_greater_bitwise " ])
424- mask_A_B_less = op .LessOrEqual (mask_expanded_B , mask_expanded_A , _outputs = ["mask_A_B_less " ])
425- mask_A_B_combined = op .And (mask_A_B_greater_bitwise , mask_A_B_less , _outputs = ["mask_A_B_combined " ])
426- mask_A_B_combined_bitwise = op .And (True , mask_A_B_combined , _outputs = ["mask_A_B_combined_bitwise " ])
427-
428- # EXPAND B/C TO AND
429- unsqueezed_mask_expanded_B = op .Unsqueeze (mask_expanded_B , [- 1 ], _outputs = ["unsqueezed_mask_expanded_B " ])
430- unsqueezed_mask_expanded_C = op .Unsqueeze (mask_expanded_C , [- 1 ], _outputs = ["unsqueezed_mask_expanded_C " ])
431- mask_B_C_concat = op .Concat (unsqueezed_mask_expanded_C , unsqueezed_mask_expanded_B , axis = - 1 , _outputs = ["mask_B_C_concat " ])
417+ reshaped_batch_range = op .Reshape (batch_range , [- 1 , 1 , 1 , 1 ], allowzero = 1 , _outputs = ["reshaped_batch_range " ])
418+ expanded_batch_range = op .Expand (reshaped_batch_range , batch_mask_shape_abs , _outputs = ["expanded_batch_range " ])
419+
420+ # Combine KV/Query Ranges for Sliding Window Mask: expanded_kv_range - 262144 < expanded_query_range <= expanded_kv_range
421+ kv_range_offset = op .Sub (expanded_kv_range , 262144 , _outputs = ["kv_range_offset " ])
422+ query_gt_kv_offset = op .Greater (expanded_query_range , kv_range_offset , _outputs = ["query_gt_kv_offset " ])
423+ query_gt_kv_offset_mask = op .And (True , query_gt_kv_offset , _outputs = ["query_gt_kv_offset_mask " ])
424+ query_le_kv = op .LessOrEqual (expanded_query_range , expanded_kv_range , _outputs = ["query_le_kv " ])
425+ sliding_window_mask = op .And (query_gt_kv_offset_mask , query_le_kv , _outputs = ["sliding_window_mask " ])
426+ sliding_window_mask_final = op .And (True , sliding_window_mask , _outputs = ["sliding_window_mask_final " ])
427+
428+ # Attention Mask Lookup: stack (batch, query) index pairs on a new last axis and GatherND them from the boolean attention mask
429+ unsqueezed_query_range = op .Unsqueeze (expanded_query_range , [- 1 ], _outputs = ["unsqueezed_query_range " ])
430+ unsqueezed_batch_range = op .Unsqueeze (expanded_batch_range , [- 1 ], _outputs = ["unsqueezed_batch_range " ])
431+ batch_query_indices = op .Concat (unsqueezed_batch_range , unsqueezed_query_range , axis = - 1 , _outputs = ["batch_query_indices " ])
432432 attention_mask_bool = op .Cast (attention_mask , to = ir .DataType .BOOL , _outputs = ["attention_mask_bool" ])
433- mask_gatherND = op .GatherND (attention_mask_bool , mask_B_C_concat , batch_dims = 0 , _outputs = ["mask_gatherND" ])
434-
435- mask_A_B_C_combined = op .And (mask_A_B_combined_bitwise , mask_gatherND , _outputs = ["mask_A_B_C_combined" ])
436- mask_A_B_C_negated = op .Not (mask_A_B_C_combined , _outputs = ["mask_A_B_C_negated" ])
437- mask_A_B_C_fp32 = op .Cast (mask_A_B_C_negated , to = ir .DataType .FLOAT , _outputs = ["mask_A_B_C_fp32" ])
438- mask_A_B_C_scaled = op .Mul (mask_A_B_C_fp32 , pattern .ANY_VALUE )
433+ attention_lookup = op .GatherND (attention_mask_bool , batch_query_indices , batch_dims = 0 , _outputs = ["attention_lookup" ])
434+
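435+ # Final Mask Combination: AND the sliding-window mask with the gathered attention mask, invert, cast to float, then scale by any value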
436+ final_attention_mask = op .And (sliding_window_mask_final , attention_lookup , _outputs = ["final_attention_mask" ])
437+ inverted_mask = op .Not (final_attention_mask , _outputs = ["inverted_mask" ])
438+ mask_fp32 = op .Cast (inverted_mask , to = ir .DataType .FLOAT , _outputs = ["mask_fp32" ])
439+ scaled_mask = op .Mul (mask_fp32 , pattern .ANY_VALUE )
440+
439441 # Propagation to GQA
440- mask_sliced = op .Slice (mask_A_B_C_scaled , [0 ], pattern .ANY_VALUE , [3 ], [1 ], _outputs = ["mask_sliced " ])
442+ sliced_mask = op .Slice (scaled_mask , [0 ], pattern .ANY_VALUE , [3 ], [1 ], _outputs = ["sliced_mask " ])
441443
442- gqa_input = pattern .OrValue ([mask_sliced , mask_A_B_C_scaled ])
444+ gqa_input = pattern .OrValue ([sliced_mask , scaled_mask ])
443445
444446 return op .GQA (
445447 gqa_input ,
446- pattern .ANY_VALUE , # position_ids_k
447- pattern .ANY_VALUE , # position_ids_q
448- pattern .ANY_VALUE , # query
449- pattern .ANY_VALUE , # key
450- pattern .ANY_VALUE , # value
451- pattern .ANY_VALUE , # past_key
452- pattern .ANY_VALUE , # past_value
453- pattern .ANY_VALUE , # seqlens_k (optional)
454- pattern .ANY_VALUE , # total_seq_length (optional)
455- pattern .ANY_VALUE , # cos
456- pattern .ANY_VALUE , # sin
457448 _allow_other_inputs = True ,
458449 _domain = "ai.onnxruntime._fusion" ,
459450 _outputs = ["attn_output" , "key_seq" , "value_seq" ],
0 commit comments