@@ -80,25 +80,6 @@ def _check_model(
8080 opset_imports = [onnx .helper .make_opsetid ("" , 18 )],
8181 ),
8282 ),
83- (
84- "mul_by_one" ,
85- _make_model (
86- onnx .helper .make_graph (
87- [
88- onnx .helper .make_node ("Mul" , ["X" , "one" ], ["Y" ]),
89- ],
90- "name" ,
91- [onnx .helper .make_tensor_value_info ("X" , FLOAT , [None ])],
92- [onnx .helper .make_tensor_value_info ("Y" , FLOAT , [None ])],
93- [
94- onnx .numpy_helper .from_array (
95- np .array ([1 ], dtype = np .float32 ), name = "one"
96- )
97- ],
98- ),
99- opset_imports = [onnx .helper .make_opsetid ("" , 18 )],
100- ),
101- ),
10283 (
10384 "canceled_out_transposes" ,
10485 _make_model (
@@ -180,7 +161,7 @@ def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
180161 ]
181162 )
182163 def test_llama_p0_rule_set_cast_cast (self , _ : str , model : ir .Model ):
183- rule_set = llama_rule_sets .llama_p0_rule_set ()
164+ rule_set = llama_rule_sets .cast_cast_rule
184165 model_proto = ir .serde .serialize_model (model )
185166 rule_set .apply_to_model (model )
186167 rewritten_model = ir .serde .serialize_model (model )
0 commit comments