Skip to content

Commit f53863a

Browse files
Copilotjustinchuby
andcommitted
Update llama_rule_sets_test.py to use basic_rules
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 7ccdbfb commit f53863a

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

onnxscript/rewriter/llama_rule_sets_test.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3+
"""Test file for basic optimization rules (formerly llama_rule_sets).
4+
5+
.. deprecated::
6+
This test file is deprecated. New tests should be added to basic_rules_test.py.
7+
This file is kept for backward compatibility and to ensure existing tests pass.
8+
"""
39
from __future__ import annotations
410

511
import unittest
@@ -12,7 +18,7 @@
1218

1319
import onnxscript
1420
import onnxscript.onnx_types as ot
15-
import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
21+
import onnxscript.rewriter.basic_rules as basic_rules
1622
from onnxscript import ir
1723
from onnxscript.onnx_opset import opset18
1824

@@ -29,7 +35,7 @@ def _make_model(*args, **kwargs) -> ir.Model:
2935
return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs))
3036

3137

32-
class LlamaRuleSetsTest(unittest.TestCase):
38+
class BasicRuleSetsTest(unittest.TestCase):
3339
def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
3440
feeds: dict[str, Any] = {}
3541
for i in model.graph.input:
@@ -98,7 +104,7 @@ def _check_model(
98104
]
99105
)
100106
def test_llama_p0_rule_set_identity(self, _: str, model: ir.Model):
101-
rule_set = llama_rule_sets.llama_p0_rule_set()
107+
rule_set = basic_rules.basic_optimization_rules()
102108
model_proto = ir.serde.serialize_model(model)
103109
rule_set.apply_to_model(model)
104110
rewritten_model = ir.serde.serialize_model(model)
@@ -126,7 +132,7 @@ def test_llama_p0_rule_set_identity(self, _: str, model: ir.Model):
126132
]
127133
)
128134
def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
129-
rule_set = llama_rule_sets.llama_p0_rule_set()
135+
rule_set = basic_rules.basic_optimization_rules()
130136
model_proto = ir.serde.serialize_model(model)
131137
rule_set.apply_to_model(model)
132138
rewritten_model = ir.serde.serialize_model(model)
@@ -153,10 +159,10 @@ def cast_cast_model(x):
153159
]
154160
)
155161
def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
156-
rule_set = llama_rule_sets.cast_cast_rule
162+
rule = basic_rules.cast_cast_rule
157163
model_proto = self._double_cast_model(type1, type2, type3)
158164
model = ir.serde.deserialize_model(model_proto)
159-
rule_set.apply_to_model(model)
165+
rule.apply_to_model(model)
160166
rewritten_model = ir.serde.serialize_model(model)
161167

162168
self.assertEqual(["Cast"], [n.op_type for n in model.graph])
@@ -173,7 +179,7 @@ def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
173179
]
174180
)
175181
def test_llama_p0_rule_set_cast_identity(self, _: str, model: ir.Model):
176-
rule_set = llama_rule_sets.llama_p0_rule_set()
182+
rule_set = basic_rules.basic_optimization_rules()
177183
model_proto = ir.serde.serialize_model(model)
178184
rule_set.apply_to_model(model)
179185
rewritten_model = ir.serde.serialize_model(model)
@@ -229,7 +235,7 @@ def test_llama_p0_rule_set_cast_identity(self, _: str, model: ir.Model):
229235
def test_llama_p0_rule_set_expand_identity(
230236
self, _: str, model: ir.Model, expected_nodes: tuple[str, ...]
231237
):
232-
rule_set = llama_rule_sets.llama_p0_rule_set()
238+
rule_set = basic_rules.basic_optimization_rules()
233239
model_proto = ir.serde.serialize_model(model)
234240
rule_set.apply_to_model(model)
235241
rewritten_model = ir.serde.serialize_model(model)
@@ -311,7 +317,7 @@ def test_llama_p0_rule_set_expand_identity(
311317
]
312318
)
313319
def test_llama_p0_rule_set_unsqueeze_unsqueeze(self, _: str, model: ir.Model):
314-
rule_set = llama_rule_sets.llama_p0_rule_set()
320+
rule_set = basic_rules.basic_optimization_rules()
315321
model_proto = ir.serde.serialize_model(model)
316322
rule_set.apply_to_model(model)
317323
rewritten_model = ir.serde.serialize_model(model)
@@ -370,7 +376,7 @@ def test_llama_p0_rule_set_unsqueeze_unsqueeze(self, _: str, model: ir.Model):
370376
]
371377
)
372378
def test_llama_p0_rule_set_reshape_reshape(self, _: str, model: ir.Model):
373-
rule_set = llama_rule_sets.llama_p0_rule_set()
379+
rule_set = basic_rules.basic_optimization_rules()
374380
model_proto = ir.serde.serialize_model(model)
375381
rule_set.apply_to_model(model)
376382
rewritten_model = ir.serde.serialize_model(model)
@@ -421,15 +427,15 @@ def _slides_split_models(cls):
421427
def test_llama_p0_rule_set_slice_split(self):
422428
for model_proto in self._slides_split_models():
423429
ir_model = ir.serde.deserialize_model(model_proto)
424-
rule_set = llama_rule_sets.llama_p0_rule_set()
430+
rule_set = basic_rules.basic_optimization_rules()
425431
rule_set.apply_to_model(ir_model)
426432
rewritten_model = ir.serde.serialize_model(ir_model)
427433

428434
self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
429435
self._check_model(model_proto, rewritten_model)
430436

431437
def test_squeeze_reshape_1d_test(self):
432-
rule = llama_rule_sets.squeeze_reshape_1d_rule
438+
rule = basic_rules.squeeze_reshape_1d_rule
433439

434440
def check(model_script, expected_count) -> None:
435441
model_proto = model_script.to_model_proto()

0 commit comments

Comments
 (0)