Skip to content

Commit 1077da7

Browse files
gramalingamCopilot
andauthored
Unify failure-handling in rewrite-rule (#2866)
Unify failure-handling in rewrite-rule --------- Signed-off-by: G Ramalingam <grama@microsoft.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 0cced88 commit 1077da7

File tree

3 files changed

+133
-8
lines changed

3 files changed

+133
-8
lines changed

docs/tutorial/rewriter/node_value_checkers.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,16 @@ This means you should be careful when designing patterns with multiple alternati
179179

180180
## Error Handling
181181

182-
Checkers can return either:
183-
- `True`: Check passed, continue matching
184-
- `False`: Check failed, pattern does not match
185-
- `MatchResult`: More detailed result with potential failure reasons
182+
Both check functions (including condition functions and node/value-level checkers) and
183+
rewrite functions support the same conventions for indicating failure:
186184

187-
If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.
185+
- **`MatchResult` with `.fail()`** *(recommended)*: Return `MatchResult().fail("reason", source)` to indicate failure with a descriptive reason and optional source node/value. This provides the most useful diagnostic information for debugging.
186+
- **Raise `MatchFailureError`** *(recommended)*: Import it as `from onnxscript.rewriter.rewriter import MatchFailureError` and raise `MatchFailureError("reason", source1, source2, ...)` to indicate failure associated with one or more `ir.Node` or `ir.Value` objects. Each source should be passed as a separate positional argument (do not pass a list as a single argument). This is especially convenient in utility functions called from a check or rewrite, since it avoids having to explicitly propagate failure status through the call chain.
187+
- **Return `None` or `False`**: These indicate failure without providing a reason. They are supported but not recommended, since a failure reason is valuable for debugging why a rule did not apply.
188+
189+
Including a descriptive failure reason is strongly encouraged. The rewriter's tracing infrastructure
190+
uses these reasons to report why rules failed to match, which is essential for diagnosing
191+
issues when developing or debugging rewrite rules.
192+
193+
For **check functions**, success is indicated by returning `True` or a truthy `MatchResult`.
194+
For **rewrite functions**, success is indicated by returning one or more `ir.Value` results.

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,34 @@ def __init__(self, function) -> None:
217217

218218
def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
219219
context = RewriterContext()
220-
new_outputs = self._function(context, **match.bindings)
221-
if new_outputs is None:
222-
return None # Failed to create replacement subgraph
220+
try:
221+
new_outputs = self._function(context, **match.bindings)
222+
except _basics.MatchFailureError as e:
223+
match.fail(e.reason, list(e.failure_sources))
224+
return None
225+
# Support the same failure conventions as check functions for uniformity:
226+
# - None or False: indicates failure without a reason (not recommended)
227+
# - Falsy MatchResult: failure with reason/source info (recommended)
228+
# - MatchFailureError exception: failure with reason/source info (recommended)
229+
if new_outputs is None or new_outputs is False:
230+
return None
231+
if isinstance(new_outputs, _basics.MatchResult):
232+
if not new_outputs:
233+
# A falsy MatchResult is the recommended way to signal failure with
234+
# reason/source information from a replacement function.
235+
match.fail(
236+
new_outputs.reason,
237+
new_outputs.failure_nodes_and_values,
238+
)
239+
return None
240+
# A truthy MatchResult should never be returned from a replacement
241+
# function. Treat this as a programmer error to avoid silent failures.
242+
raise TypeError(
243+
"Replacement function returned a truthy MatchResult. "
244+
"Replacement functions should either return None/False for a "
245+
"generic failure, return a *falsy* MatchResult to provide "
246+
"failure details, or raise MatchFailureError."
247+
)
223248
if not isinstance(new_outputs, Sequence):
224249
new_outputs = [new_outputs]
225250
return ReplacementSubgraph(

onnxscript/rewriter/pattern_base_test.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from onnxscript import ir
88
from onnxscript.rewriter import pattern
9+
from onnxscript.rewriter._basics import MatchFailureError
910

1011

1112
class PatternTest(unittest.TestCase):
@@ -249,5 +250,97 @@ def rewrite(self, op, x):
249250
self.assertEqual(rule.name, "SimpleIdentityRule")
250251

251252

253+
class RewriteFailureConventionsTest(unittest.TestCase):
254+
"""Test that rewrite functions support the same failure conventions as check functions.
255+
256+
The check side supports three failure conventions:
257+
- Return a falsy MatchResult (via .fail())
258+
- Raise MatchFailureError
259+
- Return None or False
260+
261+
The rewrite side should support all three as well, for uniformity.
262+
"""
263+
264+
_IDENTITY_MODEL_TEXT = """
265+
<ir_version: 7, opset_import: [ "" : 17]>
266+
agraph (float[N] x) => (float[N] z)
267+
{
268+
z = Identity(x)
269+
}
270+
"""
271+
272+
def _apply_rewrite_rule(self, rewrite_fn):
273+
"""Helper that builds a RewriteRule with the given rewrite function and applies it."""
274+
275+
def identity_pattern(op, x):
276+
return op.Identity(x)
277+
278+
rule = pattern.RewriteRule(identity_pattern, rewrite_fn, name="TestRule")
279+
model = ir.from_onnx_text(self._IDENTITY_MODEL_TEXT)
280+
count = rule.apply_to_model(model)
281+
return count
282+
283+
def test_rewrite_returning_none_is_treated_as_failure(self):
284+
"""Returning None from rewrite indicates failure (no replacement made)."""
285+
286+
def rewrite_returns_none(op, x):
287+
return None
288+
289+
count = self._apply_rewrite_rule(rewrite_returns_none)
290+
self.assertEqual(count, 0)
291+
292+
def test_rewrite_returning_false_is_treated_as_failure(self):
293+
"""Returning False from rewrite indicates failure (deprecated but supported)."""
294+
295+
def rewrite_returns_false(op, x):
296+
return False
297+
298+
count = self._apply_rewrite_rule(rewrite_returns_false)
299+
self.assertEqual(count, 0)
300+
301+
def test_rewrite_raising_match_failure_error_is_treated_as_failure(self):
302+
"""Raising MatchFailureError from rewrite indicates failure."""
303+
304+
def rewrite_raises_error(op, x):
305+
raise MatchFailureError("Cannot rewrite this node")
306+
307+
count = self._apply_rewrite_rule(rewrite_raises_error)
308+
self.assertEqual(count, 0)
309+
310+
def test_rewrite_returning_falsy_match_result_is_treated_as_failure(self):
311+
"""Returning a falsy MatchResult from rewrite indicates failure."""
312+
313+
def rewrite_returns_failed_match_result(op, x):
314+
result = pattern.MatchResult()
315+
return result.fail("Rewrite not applicable")
316+
317+
count = self._apply_rewrite_rule(rewrite_returns_failed_match_result)
318+
self.assertEqual(count, 0)
319+
320+
def test_rewrite_returning_ir_value_succeeds(self):
321+
"""Returning an ir.Value from rewrite is success (the normal case)."""
322+
323+
# Use a non-self-referential pattern to avoid infinite rewrite loops
324+
def add_zero_pattern(op, x):
325+
return op.Add(x, 0.0)
326+
327+
def identity_replacement(op, x):
328+
return op.Identity(x)
329+
330+
rule = pattern.RewriteRule(add_zero_pattern, identity_replacement, name="TestRule")
331+
model = ir.from_onnx_text(
332+
"""
333+
<ir_version: 7, opset_import: [ "" : 17]>
334+
agraph (float[N] x) => (float[N] z)
335+
{
336+
c0 = Constant<value_float = 0.0>()
337+
z = Add(x, c0)
338+
}
339+
"""
340+
)
341+
count = rule.apply_to_model(model)
342+
self.assertGreaterEqual(count, 1)
343+
344+
252345
if __name__ == "__main__":
253346
unittest.main()

0 commit comments

Comments
 (0)