Skip to content

Commit e59cb83

Browse files
committed
Renamed pattern branches to match kv_range, query_range, and batch_range computation
1 parent 7519653 commit e59cb83

1 file changed

Lines changed: 47 additions & 56 deletions

File tree

  • onnxscript/rewriter/ort_fusions

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 47 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -387,73 +387,64 @@ def pattern(
387387
past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"])
388388
total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"])
389389

390-
# All of the Add node's outputs
391-
current_range_A = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["current_range_A"])
392-
total_seq_len_A = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_A"])
393-
current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"])
394-
total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"])
395-
total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"])
390+
# Create ranges for different dimensions
391+
kv_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["kv_range"])
392+
total_seq_len_for_kv = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"])
393+
query_range = op.Range(0, total_seq_len_0D, 1, _outputs=["query_range"])
394+
total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"])
395+
total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"])
396396

397-
total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"])
397+
#total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"])
398398

399-
# EXPAND BRANCH A
399+
# BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1]
400400
batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"])
401-
mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"])
402-
mask_shape_A_abs = op.Abs(mask_shape_A, _outputs=["mask_shape_A_abs"])
403-
reshaped_range_A = op.Reshape(current_range_A, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_range_A"])
404-
mask_expanded_A = op.Expand(reshaped_range_A, mask_shape_A_abs, _outputs=["mask_expanded_A"])
405-
406-
# EXPAND BRANCH B
407-
mask_shape_B = op.Concat(batch_size, [1], seq_len, total_seq_len_B, axis=0, _outputs=["mask_shape_B"])
408-
mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"])
409-
reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"])
410-
mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"])
411-
412-
# EXPAND BRANCH C
413-
mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"])
414-
mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"])
401+
kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"])
402+
kv_mask_shape_abs = op.Abs(kv_mask_shape, _outputs=["kv_mask_shape_abs"])
403+
reshaped_kv_range = op.Reshape(kv_range, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_kv_range"])
404+
expanded_kv_range = op.Expand(reshaped_kv_range, kv_mask_shape_abs, _outputs=["expanded_kv_range"])
405+
406+
# BRANCH B: Query Range - Creates tensor with query positions [1, 1, 1, total_seq_len]
407+
query_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_query, axis=0, _outputs=["query_mask_shape"])
408+
query_mask_shape_abs = op.Abs(query_mask_shape, _outputs=["query_mask_shape_abs"])
409+
reshaped_query_range = op.Reshape(query_range, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_query_range"])
410+
expanded_query_range = op.Expand(reshaped_query_range, query_mask_shape_abs, _outputs=["expanded_query_range"])
411+
412+
# BRANCH C: Batch Range - Creates tensor with batch indices [batch_size, 1, 1, 1]
413+
batch_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_batch, axis=0, _outputs=["batch_mask_shape"])
414+
batch_mask_shape_abs = op.Abs(batch_mask_shape, _outputs=["batch_mask_shape_abs"])
415415
batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"])
416416
batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"])
417-
reshaped_range_C = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_range_C"])
418-
mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"])
419-
420-
# EXPAND A/B TO AND
421-
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
422-
mask_A_B_greater = op.Greater(mask_expanded_B, mask_expanded_A_sub, _outputs=["mask_A_B_greater"])
423-
mask_A_B_greater_bitwise = op.And(True, mask_A_B_greater, _outputs=["mask_A_B_greater_bitwise"])
424-
mask_A_B_less = op.LessOrEqual(mask_expanded_B, mask_expanded_A, _outputs=["mask_A_B_less"])
425-
mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"])
426-
mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"])
427-
428-
# EXPAND B/C TO AND
429-
unsqueezed_mask_expanded_B = op.Unsqueeze(mask_expanded_B, [-1], _outputs=["unsqueezed_mask_expanded_B"])
430-
unsqueezed_mask_expanded_C = op.Unsqueeze(mask_expanded_C, [-1], _outputs=["unsqueezed_mask_expanded_C"])
431-
mask_B_C_concat = op.Concat(unsqueezed_mask_expanded_C, unsqueezed_mask_expanded_B, axis=-1, _outputs=["mask_B_C_concat"])
417+
reshaped_batch_range = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_batch_range"])
418+
expanded_batch_range = op.Expand(reshaped_batch_range, batch_mask_shape_abs, _outputs=["expanded_batch_range"])
419+
420+
# Combine KV/Query Ranges for Sliding Window Mask
421+
kv_range_offset = op.Sub(expanded_kv_range, 262144, _outputs=["kv_range_offset"])
422+
query_gt_kv_offset = op.Greater(expanded_query_range, kv_range_offset, _outputs=["query_gt_kv_offset"])
423+
query_gt_kv_offset_mask = op.And(True, query_gt_kv_offset, _outputs=["query_gt_kv_offset_mask"])
424+
query_le_kv = op.LessOrEqual(expanded_query_range, expanded_kv_range, _outputs=["query_le_kv"])
425+
sliding_window_mask = op.And(query_gt_kv_offset_mask, query_le_kv, _outputs=["sliding_window_mask"])
426+
sliding_window_mask_final = op.And(True, sliding_window_mask, _outputs=["sliding_window_mask_final"])
427+
428+
# Combine Query/Batch Ranges for Attention Mask Lookup
429+
unsqueezed_query_range = op.Unsqueeze(expanded_query_range, [-1], _outputs=["unsqueezed_query_range"])
430+
unsqueezed_batch_range = op.Unsqueeze(expanded_batch_range, [-1], _outputs=["unsqueezed_batch_range"])
431+
batch_query_indices = op.Concat(unsqueezed_batch_range, unsqueezed_query_range, axis=-1, _outputs=["batch_query_indices"])
432432
attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"])
433-
mask_gatherND = op.GatherND(attention_mask_bool, mask_B_C_concat, batch_dims=0, _outputs=["mask_gatherND"])
434-
435-
mask_A_B_C_combined = op.And(mask_A_B_combined_bitwise, mask_gatherND, _outputs=["mask_A_B_C_combined"])
436-
mask_A_B_C_negated = op.Not(mask_A_B_C_combined, _outputs=["mask_A_B_C_negated"])
437-
mask_A_B_C_fp32 = op.Cast(mask_A_B_C_negated, to=ir.DataType.FLOAT, _outputs=["mask_A_B_C_fp32"])
438-
mask_A_B_C_scaled = op.Mul(mask_A_B_C_fp32, pattern.ANY_VALUE)
433+
attention_lookup = op.GatherND(attention_mask_bool, batch_query_indices, batch_dims=0, _outputs=["attention_lookup"])
434+
435+
# Final Mask Combination
436+
final_attention_mask = op.And(sliding_window_mask_final, attention_lookup, _outputs=["final_attention_mask"])
437+
inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"])
438+
mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"])
439+
scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE)
440+
439441
# Propagation to GQA
440-
mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])
442+
sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"])
441443

442-
gqa_input = pattern.OrValue([mask_sliced, mask_A_B_C_scaled])
444+
gqa_input = pattern.OrValue([sliced_mask, scaled_mask])
443445

444446
return op.GQA(
445447
gqa_input,
446-
pattern.ANY_VALUE, # position_ids_k
447-
pattern.ANY_VALUE, # position_ids_q
448-
pattern.ANY_VALUE, # query
449-
pattern.ANY_VALUE, # key
450-
pattern.ANY_VALUE, # value
451-
pattern.ANY_VALUE, # past_key
452-
pattern.ANY_VALUE, # past_value
453-
pattern.ANY_VALUE, # seqlens_k (optional)
454-
pattern.ANY_VALUE, # total_seq_length (optional)
455-
pattern.ANY_VALUE, # cos
456-
pattern.ANY_VALUE, # sin
457448
_allow_other_inputs=True,
458449
_domain="ai.onnxruntime._fusion",
459450
_outputs=["attn_output", "key_seq", "value_seq"],

0 commit comments

Comments
 (0)