11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT License.
3+ """Test file for basic optimization rules (formerly llama_rule_sets).
4+
5+ .. deprecated::
6+ This test file is deprecated. New tests should be added to basic_rules_test.py.
7+ This file is kept for backward compatibility and to ensure existing tests pass.
8+ """
39from __future__ import annotations
410
511import unittest
1218
1319import onnxscript
1420import onnxscript .onnx_types as ot
15- import onnxscript .rewriter .llama_rule_sets as llama_rule_sets
21+ import onnxscript .rewriter .basic_rules as basic_rules
1622from onnxscript import ir
1723from onnxscript .onnx_opset import opset18
1824
@@ -29,7 +35,7 @@ def _make_model(*args, **kwargs) -> ir.Model:
2935 return ir .serde .deserialize_model (onnx .helper .make_model (* args , ** kwargs ))
3036
3137
32- class LlamaRuleSetsTest (unittest .TestCase ):
38+ class BasicRuleSetsTest (unittest .TestCase ):
3339 def _get_random_inputs (self , model : onnx .ModelProto ) -> dict [str , Any ]:
3440 feeds : dict [str , Any ] = {}
3541 for i in model .graph .input :
@@ -98,7 +104,7 @@ def _check_model(
98104 ]
99105 )
100106 def test_llama_p0_rule_set_identity (self , _ : str , model : ir .Model ):
101- rule_set = llama_rule_sets . llama_p0_rule_set ()
107+ rule_set = basic_rules . basic_optimization_rules ()
102108 model_proto = ir .serde .serialize_model (model )
103109 rule_set .apply_to_model (model )
104110 rewritten_model = ir .serde .serialize_model (model )
@@ -126,7 +132,7 @@ def test_llama_p0_rule_set_identity(self, _: str, model: ir.Model):
126132 ]
127133 )
128134 def test_llama_p0_rule_set_transpose_transpose (self , _ : str , model : ir .Model ):
129- rule_set = llama_rule_sets . llama_p0_rule_set ()
135+ rule_set = basic_rules . basic_optimization_rules ()
130136 model_proto = ir .serde .serialize_model (model )
131137 rule_set .apply_to_model (model )
132138 rewritten_model = ir .serde .serialize_model (model )
@@ -153,10 +159,10 @@ def cast_cast_model(x):
153159 ]
154160 )
155161 def test_llama_p0_rule_set_cast_cast (self , _ : str , type1 , type2 , type3 ):
156- rule_set = llama_rule_sets .cast_cast_rule
162+ rule = basic_rules .cast_cast_rule
157163 model_proto = self ._double_cast_model (type1 , type2 , type3 )
158164 model = ir .serde .deserialize_model (model_proto )
159- rule_set .apply_to_model (model )
165+ rule .apply_to_model (model )
160166 rewritten_model = ir .serde .serialize_model (model )
161167
162168 self .assertEqual (["Cast" ], [n .op_type for n in model .graph ])
@@ -173,7 +179,7 @@ def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
173179 ]
174180 )
175181 def test_llama_p0_rule_set_cast_identity (self , _ : str , model : ir .Model ):
176- rule_set = llama_rule_sets . llama_p0_rule_set ()
182+ rule_set = basic_rules . basic_optimization_rules ()
177183 model_proto = ir .serde .serialize_model (model )
178184 rule_set .apply_to_model (model )
179185 rewritten_model = ir .serde .serialize_model (model )
@@ -229,7 +235,7 @@ def test_llama_p0_rule_set_cast_identity(self, _: str, model: ir.Model):
229235 def test_llama_p0_rule_set_expand_identity (
230236 self , _ : str , model : ir .Model , expected_nodes : tuple [str , ...]
231237 ):
232- rule_set = llama_rule_sets . llama_p0_rule_set ()
238+ rule_set = basic_rules . basic_optimization_rules ()
233239 model_proto = ir .serde .serialize_model (model )
234240 rule_set .apply_to_model (model )
235241 rewritten_model = ir .serde .serialize_model (model )
@@ -311,7 +317,7 @@ def test_llama_p0_rule_set_expand_identity(
311317 ]
312318 )
313319 def test_llama_p0_rule_set_unsqueeze_unsqueeze (self , _ : str , model : ir .Model ):
314- rule_set = llama_rule_sets . llama_p0_rule_set ()
320+ rule_set = basic_rules . basic_optimization_rules ()
315321 model_proto = ir .serde .serialize_model (model )
316322 rule_set .apply_to_model (model )
317323 rewritten_model = ir .serde .serialize_model (model )
@@ -370,7 +376,7 @@ def test_llama_p0_rule_set_unsqueeze_unsqueeze(self, _: str, model: ir.Model):
370376 ]
371377 )
372378 def test_llama_p0_rule_set_reshape_reshape (self , _ : str , model : ir .Model ):
373- rule_set = llama_rule_sets . llama_p0_rule_set ()
379+ rule_set = basic_rules . basic_optimization_rules ()
374380 model_proto = ir .serde .serialize_model (model )
375381 rule_set .apply_to_model (model )
376382 rewritten_model = ir .serde .serialize_model (model )
@@ -421,15 +427,15 @@ def _slides_split_models(cls):
421427 def test_llama_p0_rule_set_slice_split (self ):
422428 for model_proto in self ._slides_split_models ():
423429 ir_model = ir .serde .deserialize_model (model_proto )
424- rule_set = llama_rule_sets . llama_p0_rule_set ()
430+ rule_set = basic_rules . basic_optimization_rules ()
425431 rule_set .apply_to_model (ir_model )
426432 rewritten_model = ir .serde .serialize_model (ir_model )
427433
428434 self .assertEqual (["Split" ], [n .op_type for n in rewritten_model .graph .node ])
429435 self ._check_model (model_proto , rewritten_model )
430436
431437 def test_squeeze_reshape_1d_test (self ):
432- rule = llama_rule_sets .squeeze_reshape_1d_rule
438+ rule = basic_rules .squeeze_reshape_1d_rule
433439
434440 def check (model_script , expected_count ) -> None :
435441 model_proto = model_script .to_model_proto ()
0 commit comments