diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index ddb42a31da..71f107328b 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -12,6 +12,7 @@ cast_constant_of_shape, collapse_slices, gemm_to_matmul_add, + llama_rule_sets, no_op, ) @@ -23,6 +24,7 @@ gemm_to_matmul_add.rule, *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, + *llama_rule_sets.llama_p0_rule_set().rules, ] diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index a87d01e785..d6c4177ae8 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -7,8 +7,7 @@ import numpy as np -import onnxscript.ir as ir -from onnxscript.optimizer import basic_constant_propagation +from onnxscript import ir, optimizer def display_nodes(nodes: Sequence[ir.Node]) -> None: @@ -54,7 +53,7 @@ def visit(node: ir.Node, depth): def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: node = value.producer() if node is not None: - basic_constant_propagation([node]) + optimizer.basic_constant_propagation([node]) return value.const_value diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 17df20267c..dd8c2aedaf 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -6,10 +6,9 @@ import onnx.numpy_helper -import onnxscript.ir as ir -import onnxscript.rewriter._ir_utils as ir_utils -import onnxscript.rewriter.no_op as no_op -import onnxscript.rewriter.pattern as orp +from onnxscript import ir +from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter import pattern as orp class SqueezeReshape(orp.RewriteRuleClassBase): @@ -292,15 +291,11 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet: """ return orp.RewriteRuleSet( [ - no_op.mul_by_1_rule, - no_op.add_0_rule, - no_op.add_0_rule, - no_op.div_by_1_rule, - cast_cast_rule, + # cast_cast_rule, # Might have precision issues. cast_identity_rule, expand_identity_rule, reshape_reshape_rule, - slice_split_rule, + slice_split_rule, # Affect collapse slices rules? transpose_identity_rule, transpose_transpose_rule, unsqueeze_unsqueeze_rule, diff --git a/onnxscript/rewriter/llama_rule_sets_test.py b/onnxscript/rewriter/llama_rule_sets_test.py index 2dd5762767..29bbcb6004 100644 --- a/onnxscript/rewriter/llama_rule_sets_test.py +++ b/onnxscript/rewriter/llama_rule_sets_test.py @@ -80,25 +80,6 @@ def _check_model( opset_imports=[onnx.helper.make_opsetid("", 18)], ), ), - ( - "mul_by_one", - _make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Mul", ["X", "one"], ["Y"]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [None])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [None])], - [ - onnx.numpy_helper.from_array( - np.array([1], dtype=np.float32), name="one" - ) - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ), - ), ( "canceled_out_transposes", _make_model( @@ -180,7 +161,7 @@ def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model): ] ) def test_llama_p0_rule_set_cast_cast(self, _: str, model: ir.Model): - rule_set = llama_rule_sets.llama_p0_rule_set() + rule_set = llama_rule_sets.cast_cast_rule model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model)