Skip to content

Commit 634148e

Browse files
authored
Introduce pattern.any_value (#2175)
Introduce `pattern.ANY_VALUE` as a convenience when writing patterns: used as an input in a pattern, it matches any value in that one position. It is more precise than using `_allow_other_inputs=True`, which will allow any number of other inputs.
1 parent e659cb4 commit 634148e

3 files changed

Lines changed: 46 additions & 10 deletions

File tree

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,17 @@ def pattern(
8787
shape_B111,
8888
):
8989
# Reshape query from (B, S, D) to (B, S, H, D/H)
90-
query_BSHDh = op.Reshape(query_BSD, _allow_other_inputs=True, _outputs=["query_BSHDh"])
90+
query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
9191
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
9292
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
9393

9494
# Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
95-
key_BSHkvDh = op.Reshape(key_BSDkv, _allow_other_inputs=True, _outputs=["key_BSHkvDh"])
95+
key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"])
9696
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
9797
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
9898

9999
# Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H)
100-
value_BSHkvDh = op.Reshape(
101-
value_BSDkv, _allow_other_inputs=True, _outputs=["value_BSHkvDh"]
102-
)
100+
value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"])
103101
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
104102
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
105103

@@ -129,18 +127,18 @@ def pattern(
129127

130128
key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
131129
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
132-
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, _allow_other_inputs=True)
130+
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
133131
key_seq_BHTDh = op.Reshape(
134-
key_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["key_seq_BHTDh"]
132+
key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"]
135133
)
136134

137135
# Concatenate past_value cache and current value, expand across heads
138136
# that share key/value.
139137
value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
140138
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
141-
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, _allow_other_inputs=True)
139+
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
142140
value_seq_BHTDh = op.Reshape(
143-
value_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["value_seq_BHTDh"]
141+
value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
144142
)
145143

146144
mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
@@ -158,7 +156,7 @@ def pattern(
158156
attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3])
159157
# Reshape back to (B, S, D)
160158
attention_BSD = op.Reshape(
161-
attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"]
159+
attention_BSHDh, pattern.ANY_VALUE, _outputs=["attention_BSD"]
162160
)
163161
return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh
164162

onnxscript/rewriter/pattern.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,20 @@ def _is_pattern_variable(x: Any) -> bool:
634634
return type(x) is ValuePattern
635635

636636

637+
class AnyValue(ValuePattern):
638+
"""Represents a pattern that matches against any value."""
639+
640+
def __init__(self) -> None:
641+
super().__init__(None)
642+
643+
def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue:
644+
# A single instance of AnyValue suffices.
645+
return self
646+
647+
648+
ANY_VALUE = AnyValue()
649+
650+
637651
class Constant(ValuePattern):
638652
"""Represents a pattern that matches against a scalar constant value."""
639653

@@ -1108,6 +1122,9 @@ def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bo
11081122

11091123
def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool:
11101124
"""Match an IR value against a ValuePattern instance."""
1125+
if isinstance(pattern_value, AnyValue):
1126+
return True
1127+
11111128
if not self._bind_value(pattern_value, value):
11121129
return False
11131130

onnxscript/rewriter/pattern_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,27 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]:
667667
onnxscript.optimizer.inline(model)
668668
self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul", "Add", "Mul"])
669669

670+
def test_any_value(self):
671+
def source_pattern(op, x):
672+
return op.Add(x, op.Mul(0, pattern.ANY_VALUE))
673+
674+
def replacement(op, x):
675+
return op.Identity(x)
676+
677+
rule = pattern.RewriteRule(source_pattern, replacement)
678+
679+
@script()
680+
def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
681+
zero = op.Constant(value_float=0.0)
682+
return op.Add(x, op.Mul(zero, y))
683+
684+
model_proto = test_model.to_model_proto()
685+
model = ir.serde.deserialize_model(model_proto)
686+
self.assertEqual([x.op_type for x in model.graph], ["Constant", "Mul", "Add"])
687+
rule.apply_to_model(model)
688+
self.assertEqual(len(model.graph), 2)
689+
self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"])
690+
670691

671692
class PatternBuilderTest(unittest.TestCase):
672693
def test_pattern_builder_context(self):

0 commit comments

Comments
 (0)