Skip to content

Commit 678349d

Browse files
committed
Unify error handling in rewrite rules
Signed-off-by: G Ramalingam <grama@microsoft.com>
1 parent 19e5284 commit 678349d

2 files changed

Lines changed: 111 additions & 3 deletions

File tree

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,23 @@ def __init__(self, function) -> None:
217217

218218
def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
219219
context = RewriterContext()
220-
new_outputs = self._function(context, **match.bindings)
221-
if new_outputs is None:
222-
return None # Failed to create replacement subgraph
220+
try:
221+
new_outputs = self._function(context, **match.bindings)
222+
except _basics.MatchFailureError as e:
223+
match.fail(e.reason, list(e.failure_sources))
224+
return None
225+
# Support the same failure conventions as check functions for uniformity:
226+
# - None or False: simple failure indicator (deprecated for False, but supported)
227+
# - Falsy MatchResult: failure with reason/source info
228+
if new_outputs is None or new_outputs is False:
229+
return None
230+
if isinstance(new_outputs, _basics.MatchResult):
231+
if not new_outputs:
232+
match.fail(
233+
new_outputs.reason,
234+
new_outputs.failure_nodes_and_values,
235+
)
236+
return None
223237
if not isinstance(new_outputs, Sequence):
224238
new_outputs = [new_outputs]
225239
return ReplacementSubgraph(

onnxscript/rewriter/pattern_base_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from onnxscript import ir
88
from onnxscript.rewriter import pattern
9+
from onnxscript.rewriter._basics import MatchFailureError
910

1011

1112
class PatternTest(unittest.TestCase):
@@ -249,5 +250,98 @@ def rewrite(self, op, x):
249250
self.assertEqual(rule.name, "SimpleIdentityRule")
250251

251252

253+
class RewriteFailureConventionsTest(unittest.TestCase):
254+
"""Test that rewrite functions support the same failure conventions as check functions.
255+
256+
The check side supports three failure conventions:
257+
- Return a falsy MatchResult (via .fail())
258+
- Raise MatchFailureError
259+
- Return None or False
260+
261+
The rewrite side should support all three as well, for uniformity.
262+
"""
263+
264+
_IDENTITY_MODEL_TEXT = """
265+
<ir_version: 7, opset_import: [ "" : 17]>
266+
agraph (float[N] x) => (float[N] z)
267+
{
268+
z = Identity(x)
269+
}
270+
"""
271+
272+
def _apply_rewrite_rule(self, rewrite_fn):
273+
"""Helper that builds a RewriteRule with the given rewrite function and applies it."""
274+
275+
def identity_pattern(op, x):
276+
return op.Identity(x)
277+
278+
rule = pattern.RewriteRule(identity_pattern, rewrite_fn, name="TestRule")
279+
model = ir.from_onnx_text(self._IDENTITY_MODEL_TEXT)
280+
count = rule.apply_to_model(model)
281+
return count
282+
283+
def test_rewrite_returning_none_is_treated_as_failure(self):
284+
"""Returning None from rewrite indicates failure (no replacement made)."""
285+
286+
def rewrite_returns_none(op, x):
287+
return None
288+
289+
count = self._apply_rewrite_rule(rewrite_returns_none)
290+
self.assertEqual(count, 0)
291+
292+
def test_rewrite_returning_false_is_treated_as_failure(self):
293+
"""Returning False from rewrite indicates failure (deprecated but supported)."""
294+
295+
def rewrite_returns_false(op, x):
296+
return False
297+
298+
count = self._apply_rewrite_rule(rewrite_returns_false)
299+
self.assertEqual(count, 0)
300+
301+
def test_rewrite_raising_match_failure_error_is_treated_as_failure(self):
302+
"""Raising MatchFailureError from rewrite indicates failure."""
303+
304+
def rewrite_raises_error(op, x):
305+
raise MatchFailureError("Cannot rewrite this node")
306+
307+
count = self._apply_rewrite_rule(rewrite_raises_error)
308+
self.assertEqual(count, 0)
309+
310+
def test_rewrite_returning_falsy_match_result_is_treated_as_failure(self):
311+
"""Returning a falsy MatchResult from rewrite indicates failure."""
312+
313+
def rewrite_returns_failed_match_result(op, x):
314+
result = pattern.MatchResult()
315+
return result.fail("Rewrite not applicable")
316+
317+
count = self._apply_rewrite_rule(rewrite_returns_failed_match_result)
318+
self.assertEqual(count, 0)
319+
320+
def test_rewrite_returning_ir_value_succeeds(self):
321+
"""Returning an ir.Value from rewrite is success (the normal case)."""
322+
323+
# Use a non-self-referential pattern to avoid infinite rewrite loops
324+
def add_zero_pattern(op, x):
325+
zero = pattern.Constant(0.0)
326+
return op.Add(x, zero)
327+
328+
def identity_replacement(op, x):
329+
return op.Identity(x)
330+
331+
rule = pattern.RewriteRule(add_zero_pattern, identity_replacement, name="TestRule")
332+
model = ir.from_onnx_text(
333+
"""
334+
<ir_version: 7, opset_import: [ "" : 17]>
335+
agraph (float[N] x) => (float[N] z)
336+
{
337+
c0 = Constant<value_float = 0.0>()
338+
z = Add(x, c0)
339+
}
340+
"""
341+
)
342+
count = rule.apply_to_model(model)
343+
self.assertGreaterEqual(count, 1)
344+
345+
252346
if __name__ == "__main__":
253347
unittest.main()

0 commit comments

Comments
 (0)