diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py
index a6b24b7141..2dd3fd8e3f 100644
--- a/onnxscript/rewriter/llama_rule_sets.py
+++ b/onnxscript/rewriter/llama_rule_sets.py
@@ -12,6 +12,26 @@
 import onnxscript.rewriter.pattern as orp
 
 
+class SqueezeReshape(orp.RewriteRuleClassBase):
+    """Replaces ``Reshape(Squeeze(x), [-1])`` with ``Identity(x)`` for 1D x.
+
+    This pattern arises from the translation of pytorch symints.
+    """
+
+    def __init__(self):
+        super().__init__("SqueezeReshape1d", remove_nodes=False)
+
+    def pattern(self, op, x):
+        return op.Reshape(op.Squeeze(x), [-1])
+
+    def rewrite(self, op, x: ir.Value):
+        return op.Identity(x)
+
+    def check(self, context, x) -> bool:
+        del context  # Unused
+        return ir_utils.has_rank(x, 1)
+
+
 class CastIdentity(orp.RewriteRuleAsClass):
     """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
 
@@ -259,6 +279,7 @@ def check(cls, context, x, axes1, axes2) -> bool:
 transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
 transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
 unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)
+squeeze_reshape_1d_rule = SqueezeReshape.rule()
 
 
 def llama_p0_rule_set() -> orp.RewriteRuleSet:
diff --git a/onnxscript/rewriter/llama_rule_sets_test.py b/onnxscript/rewriter/llama_rule_sets_test.py
index 0d430760f4..2dd5762767 100644
--- a/onnxscript/rewriter/llama_rule_sets_test.py
+++ b/onnxscript/rewriter/llama_rule_sets_test.py
@@ -452,6 +452,43 @@ def test_llama_p0_rule_set_slice_split(self):
         self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
         self._check_model(model_proto, rewritten_model)
 
+    def test_squeeze_reshape_1d_test(self):
+        rule = llama_rule_sets.squeeze_reshape_1d_rule
+
+        def check(model_script, expected_count) -> None:
+            model_proto = model_script.to_model_proto()
+            ir_model = ir.serde.deserialize_model(model_proto)
+            count = rule.apply_to_model(ir_model)
+            self.assertEqual(count, expected_count)
+            if count > 0:
+                self.assertEqual([x.op_type for x in ir_model.graph], ["Identity"])
+            rewritten_proto = ir.serde.serialize_model(ir_model)
+            self._check_model(model_proto, rewritten_proto)
+
+        op = onnxscript.opset17
+
+        # input of shape [12]
+        @onnxscript.script()
+        def model1(X: ot.FLOAT[12]):
+            return op.Reshape(op.Squeeze(X), [-1])
+
+        check(model1, 1)
+
+        # input of shape [1]
+        @onnxscript.script()
+        def model2(X: ot.FLOAT[1]):
+            return op.Reshape(op.Squeeze(X), [-1])
+
+        check(model2, 1)
+
+        # input of shape [1, 1]
+        # This should NOT be optimized to Identity
+        @onnxscript.script()
+        def model3(X: ot.FLOAT[1, 1]):
+            return op.Reshape(op.Squeeze(X), [-1])
+
+        check(model3, 0)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 1f3e7e8c07..6a40d3e974 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -1627,6 +1627,9 @@ def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> No
         if commute:
             rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules]))
         self.rules = rules
+        # We call remove_unused_nodes at end of rewriting if there is any rule that does
+        # NOT remove nodes (immediately when it is applied)
+        self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules)
 
     def _apply_to_graph_or_function(
         self,
@@ -1759,6 +1762,8 @@ def apply_to_model(
             count += self._apply_to_graph_or_function(
                 model, function, verbose=verbose, tracer=tracer
             )
+        if self.remove_unused_nodes:
+            onnxscript.optimizer.remove_unused_nodes(model)
         if tracer:
             tracer.report()
         return count