Skip to content

Commit e07b248

Browse files
authored
Update unused_removal_test.py
tests for issue #2211
1 parent d0cafdf commit e07b248

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

onnxscript/ir/passes/common/unused_removal_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,62 @@ def test_remove_unused_initializers(self):
5454
self.assertEqual(model.graph.node[0].op_type, "Mul")
5555
self.assertEqual(len(model.graph.initializer), 0)
5656

57+
def test_remove_unused_inputs_initializers():
58+
# remove inputs in case they are initializers
59+
# https://github.com/microsoft/onnxscript/issues/2211
60+
model = onnx.parser.parse_model(
61+
"""
62+
<ir_version: 10, opset_import: [ "" : 17]>
63+
agraph (float[N] x, float[N] two) => (float[N] z)
64+
<float two = {2.0,2.0}> {
65+
four = Add(two, two)
66+
z = Mul(x, x)
67+
}
68+
"""
69+
)
70+
ir_model = onnxscript.ir.serde.deserialize_model(model)
71+
remove_unused_nodes(ir_model)
72+
assert (len(ir_model.graph._nodes)== 1)
73+
assert (len(ir_model.graph.inputs)== 1)
74+
assert (ir_model.graph.node(0).op_type== "Mul")
75+
76+
def test_avoid_remove_unused_inputs_initializers():
77+
# supress remove inputs in case they are initializers
78+
# if explicitly said
79+
model = onnx.parser.parse_model(
80+
"""
81+
<ir_version: 10, opset_import: [ "" : 17]>
82+
agraph (float[N] x, float[N] two) => (float[N] z)
83+
<float two = {2.0,2.0}> {
84+
four = Add(two, two)
85+
z = Mul(x, x)
86+
}
87+
"""
88+
)
89+
ir_model = onnxscript.ir.serde.deserialize_model(model)
90+
remove_unused_nodes(ir_model,False)
91+
assert (len(ir_model.graph._nodes)== 1)
92+
assert (len(ir_model.graph.inputs)== 2)
93+
assert (ir_model.graph.node(0).op_type== "Mul")
94+
95+
def test_avoid_remove_unused_inputs():
96+
# preserve inputs as part of interface
97+
model = onnx.parser.parse_model(
98+
"""
99+
<ir_version: 10, opset_import: [ "" : 17]>
100+
agraph (float[N] x, float[N] two) => (float[N] z)
101+
{
102+
four = Add(two, two)
103+
z = Mul(x, x)
104+
}
105+
"""
106+
)
107+
ir_model = onnxscript.ir.serde.deserialize_model(model)
108+
remove_unused_nodes(ir_model)
109+
assert (len(ir_model.graph._nodes)== 1)
110+
assert (len(ir_model.graph.inputs)== 2)
111+
assert (ir_model.graph.node(0).op_type== "Mul")
112+
57113
def test_partially_used_nodes(self):
58114
model = onnx.parser.parse_model(
59115
"""

0 commit comments

Comments
 (0)