Skip to content

Commit 33966b1

Browse files
committed
fix tests
1 parent 2a61151 commit 33966b1

2 files changed

Lines changed: 2 additions & 21 deletions

File tree

onnxscript/rewriter/llama_rule_sets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
291291
"""
292292
return orp.RewriteRuleSet(
293293
[
294-
# cast_cast_rule, Might have precision issues.
294+
# cast_cast_rule, # Might have precision issues.
295295
cast_identity_rule,
296296
expand_identity_rule,
297297
reshape_reshape_rule,

onnxscript/rewriter/llama_rule_sets_test.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,25 +80,6 @@ def _check_model(
8080
opset_imports=[onnx.helper.make_opsetid("", 18)],
8181
),
8282
),
83-
(
84-
"mul_by_one",
85-
_make_model(
86-
onnx.helper.make_graph(
87-
[
88-
onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
89-
],
90-
"name",
91-
[onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
92-
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
93-
[
94-
onnx.numpy_helper.from_array(
95-
np.array([1], dtype=np.float32), name="one"
96-
)
97-
],
98-
),
99-
opset_imports=[onnx.helper.make_opsetid("", 18)],
100-
),
101-
),
10283
(
10384
"canceled_out_transposes",
10485
_make_model(
@@ -180,7 +161,7 @@ def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
180161
]
181162
)
182163
def test_llama_p0_rule_set_cast_cast(self, _: str, model: ir.Model):
183-
rule_set = llama_rule_sets.llama_p0_rule_set()
164+
rule_set = llama_rule_sets.cast_cast_rule
184165
model_proto = ir.serde.serialize_model(model)
185166
rule_set.apply_to_model(model)
186167
rewritten_model = ir.serde.serialize_model(model)

0 commit comments

Comments
 (0)