Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,26 @@
import onnxscript.rewriter.pattern as orp


class SqueezeReshape(orp.RewriteRuleClassBase):
    """Replaces ``Reshape(Squeeze(x), [-1])`` with ``Identity(x)`` for 1D x.

    This pattern arises from the translation of pytorch symints.
    """

    def __init__(self):
        # remove_nodes=False: matched nodes are left in place; the rule set
        # runs an unused-node cleanup pass after rewriting instead.
        super().__init__("SqueezeReshape1d", remove_nodes=False)

    def pattern(self, op, x):
        return op.Reshape(op.Squeeze(x), [-1])

    def check(self, context, x) -> bool:
        del context  # Unused
        # Safe only when x is known to be rank 1; for other ranks the
        # Squeeze/Reshape pair is not a no-op.
        return ir_utils.has_rank(x, 1)

    def rewrite(self, op, x: ir.Value):
        return op.Identity(x)


class CastIdentity(orp.RewriteRuleAsClass):
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""

Expand Down Expand Up @@ -259,6 +279,7 @@ def check(cls, context, x, axes1, axes2) -> bool:
transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)
# SqueezeReshape derives from RewriteRuleClassBase, which appears to provide a
# rule() factory directly — unlike the older make_rewrite_rule_from_class path.
squeeze_reshape_1d_rule = SqueezeReshape.rule()


def llama_p0_rule_set() -> orp.RewriteRuleSet:
Expand Down
37 changes: 37 additions & 0 deletions onnxscript/rewriter/llama_rule_sets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,43 @@ def test_llama_p0_rule_set_slice_split(self):
self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model_proto, rewritten_model)

def test_squeeze_reshape_1d_test(self):
    """Verify squeeze_reshape_1d_rule fires on 1D inputs and only on them."""
    rule = llama_rule_sets.squeeze_reshape_1d_rule
    op = onnxscript.opset17

    def run_and_verify(script_fn, expected_count) -> None:
        # Apply the rule and verify both the match count and the result graph.
        proto = script_fn.to_model_proto()
        model = ir.serde.deserialize_model(proto)
        applied = rule.apply_to_model(model)
        self.assertEqual(applied, expected_count)
        if applied:
            # The whole graph should collapse to a single Identity node.
            self.assertEqual(["Identity"], [node.op_type for node in model.graph])
        self._check_model(proto, ir.serde.serialize_model(model))

    # Static 1D input of shape [12]: the pair rewrites to Identity.
    @onnxscript.script()
    def model1(X: ot.FLOAT[12]):
        return op.Reshape(op.Squeeze(X), [-1])

    run_and_verify(model1, 1)

    # Shape [1] is still rank 1, so the rewrite applies as well.
    @onnxscript.script()
    def model2(X: ot.FLOAT[1]):
        return op.Reshape(op.Squeeze(X), [-1])

    run_and_verify(model2, 1)

    # Rank-2 input: Squeeze changes the shape, so this must NOT be
    # optimized to Identity.
    @onnxscript.script()
    def model3(X: ot.FLOAT[1, 1]):
        return op.Reshape(op.Squeeze(X), [-1])

    run_and_verify(model3, 0)


if __name__ == "__main__":
unittest.main(verbosity=2)
5 changes: 5 additions & 0 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,9 @@ def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> No
if commute:
rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules]))
self.rules = rules
# We call remove_unused_nodes at end of rewriting if there is any rule that does
# NOT remove nodes (immediately when it is applied)
self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules)

def _apply_to_graph_or_function(
self,
Expand Down Expand Up @@ -1759,6 +1762,8 @@ def apply_to_model(
count += self._apply_to_graph_or_function(
model, function, verbose=verbose, tracer=tracer
)
if self.remove_unused_nodes:
onnxscript.optimizer.remove_unused_nodes(model)
if tracer:
tracer.report()
return count
Expand Down