|
6 | 6 |
|
7 | 7 | from onnxscript import ir |
8 | 8 | from onnxscript.rewriter import pattern |
| 9 | +from onnxscript.rewriter._basics import MatchFailureError |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class PatternTest(unittest.TestCase): |
@@ -249,5 +250,98 @@ def rewrite(self, op, x): |
249 | 250 | self.assertEqual(rule.name, "SimpleIdentityRule") |
250 | 251 |
|
251 | 252 |
|
| 253 | +class RewriteFailureConventionsTest(unittest.TestCase): |
| 254 | + """Test that rewrite functions support the same failure conventions as check functions. |
| 255 | +
|
| 256 | + The check side supports three failure conventions: |
| 257 | + - Return a falsy MatchResult (via .fail()) |
| 258 | + - Raise MatchFailureError |
| 259 | + - Return None or False |
| 260 | +
|
| 261 | + The rewrite side should support all three as well, for uniformity. |
| 262 | + """ |
| 263 | + |
| 264 | + _IDENTITY_MODEL_TEXT = """ |
| 265 | + <ir_version: 7, opset_import: [ "" : 17]> |
| 266 | + agraph (float[N] x) => (float[N] z) |
| 267 | + { |
| 268 | + z = Identity(x) |
| 269 | + } |
| 270 | + """ |
| 271 | + |
| 272 | + def _apply_rewrite_rule(self, rewrite_fn): |
| 273 | + """Helper that builds a RewriteRule with the given rewrite function and applies it.""" |
| 274 | + |
| 275 | + def identity_pattern(op, x): |
| 276 | + return op.Identity(x) |
| 277 | + |
| 278 | + rule = pattern.RewriteRule(identity_pattern, rewrite_fn, name="TestRule") |
| 279 | + model = ir.from_onnx_text(self._IDENTITY_MODEL_TEXT) |
| 280 | + count = rule.apply_to_model(model) |
| 281 | + return count |
| 282 | + |
| 283 | + def test_rewrite_returning_none_is_treated_as_failure(self): |
| 284 | + """Returning None from rewrite indicates failure (no replacement made).""" |
| 285 | + |
| 286 | + def rewrite_returns_none(op, x): |
| 287 | + return None |
| 288 | + |
| 289 | + count = self._apply_rewrite_rule(rewrite_returns_none) |
| 290 | + self.assertEqual(count, 0) |
| 291 | + |
| 292 | + def test_rewrite_returning_false_is_treated_as_failure(self): |
| 293 | + """Returning False from rewrite indicates failure (deprecated but supported).""" |
| 294 | + |
| 295 | + def rewrite_returns_false(op, x): |
| 296 | + return False |
| 297 | + |
| 298 | + count = self._apply_rewrite_rule(rewrite_returns_false) |
| 299 | + self.assertEqual(count, 0) |
| 300 | + |
| 301 | + def test_rewrite_raising_match_failure_error_is_treated_as_failure(self): |
| 302 | + """Raising MatchFailureError from rewrite indicates failure.""" |
| 303 | + |
| 304 | + def rewrite_raises_error(op, x): |
| 305 | + raise MatchFailureError("Cannot rewrite this node") |
| 306 | + |
| 307 | + count = self._apply_rewrite_rule(rewrite_raises_error) |
| 308 | + self.assertEqual(count, 0) |
| 309 | + |
| 310 | + def test_rewrite_returning_falsy_match_result_is_treated_as_failure(self): |
| 311 | + """Returning a falsy MatchResult from rewrite indicates failure.""" |
| 312 | + |
| 313 | + def rewrite_returns_failed_match_result(op, x): |
| 314 | + result = pattern.MatchResult() |
| 315 | + return result.fail("Rewrite not applicable") |
| 316 | + |
| 317 | + count = self._apply_rewrite_rule(rewrite_returns_failed_match_result) |
| 318 | + self.assertEqual(count, 0) |
| 319 | + |
| 320 | + def test_rewrite_returning_ir_value_succeeds(self): |
| 321 | + """Returning an ir.Value from rewrite is success (the normal case).""" |
| 322 | + |
| 323 | + # Use a non-self-referential pattern to avoid infinite rewrite loops |
| 324 | + def add_zero_pattern(op, x): |
| 325 | + zero = pattern.Constant(0.0) |
| 326 | + return op.Add(x, zero) |
| 327 | + |
| 328 | + def identity_replacement(op, x): |
| 329 | + return op.Identity(x) |
| 330 | + |
| 331 | + rule = pattern.RewriteRule(add_zero_pattern, identity_replacement, name="TestRule") |
| 332 | + model = ir.from_onnx_text( |
| 333 | + """ |
| 334 | + <ir_version: 7, opset_import: [ "" : 17]> |
| 335 | + agraph (float[N] x) => (float[N] z) |
| 336 | + { |
| 337 | + c0 = Constant<value_float = 0.0>() |
| 338 | + z = Add(x, c0) |
| 339 | + } |
| 340 | + """ |
| 341 | + ) |
| 342 | + count = rule.apply_to_model(model) |
| 343 | + self.assertGreaterEqual(count, 1) |
| 344 | + |
| 345 | + |
252 | 346 | if __name__ == "__main__": |
253 | 347 | unittest.main() |
0 commit comments