Skip to content

Commit b00036c

Browse files
authored
Update unused_removal_test.py
errors fix
1 parent b3f9cde commit b00036c

File tree

1 file changed

+55
-55
lines changed

1 file changed

+55
-55
lines changed

onnxscript/ir/passes/common/unused_removal_test.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
class RemoveUnusedTest(unittest.TestCase):
1414
using_ir: bool
1515

16-
def remove_unused_nodes(self, model: onnx.ModelProto):
16+
def remove_unused_nodes(self, model: onnx.ModelProto, remove_initialized_inputs: bool=False):
1717
if self.using_ir:
1818
model_ir = ir.serde.deserialize_model(model)
19-
onnxscript.optimizer.remove_unused_nodes(model_ir)
19+
onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs)
2020
model = ir.serde.serialize_model(model_ir)
2121
return model
22-
onnxscript.optimizer.remove_unused_nodes(model)
22+
onnxscript.optimizer.remove_unused_nodes(model, remove_initialized_inputs)
2323
return model
2424

2525
def test_remove_unused_nodes(self):
@@ -54,61 +54,61 @@ def test_remove_unused_initializers(self):
5454
self.assertEqual(model.graph.node[0].op_type, "Mul")
5555
self.assertEqual(len(model.graph.initializer), 0)
5656

57-
def test_remove_unused_inputs_initializers():
58-
# remove inputs in case they are initializers
59-
# https://github.com/microsoft/onnxscript/issues/2211
60-
model = onnx.parser.parse_model(
57+
def test_remove_unused_inputs_initializers():
58+
# remove inputs in case they are initializers
59+
# if explicitly said
60+
# https://github.com/microsoft/onnxscript/issues/2211
61+
model = onnx.parser.parse_model(
62+
"""
63+
<ir_version: 10, opset_import: [ "" : 17]>
64+
agraph (float[N] x, float[N] two) => (float[N] z)
65+
<float two = {2.0,2.0}> {
66+
four = Add(two, two)
67+
z = Mul(x, x)
68+
}
6169
"""
62-
<ir_version: 10, opset_import: [ "" : 17]>
63-
agraph (float[N] x, float[N] two) => (float[N] z)
64-
<float two = {2.0,2.0}> {
65-
four = Add(two, two)
66-
z = Mul(x, x)
67-
}
68-
"""
69-
)
70-
ir_model = onnxscript.ir.serde.deserialize_model(model)
71-
remove_unused_nodes(ir_model)
72-
assert (len(ir_model.graph._nodes)== 1)
73-
assert (len(ir_model.graph.inputs)== 1)
74-
assert (ir_model.graph.node(0).op_type== "Mul")
75-
76-
def test_avoid_remove_unused_inputs_initializers():
77-
# supress remove inputs in case they are initializers
78-
# if explicitly said
79-
model = onnx.parser.parse_model(
70+
)
71+
ir_model = onnxscript.ir.serde.deserialize_model(model)
72+
ir_model = self.remove_unused_nodes(ir_model,True)
73+
assert (len(ir_model.graph._nodes)== 1)
74+
assert (len(ir_model.graph.inputs)== 1)
75+
assert (ir_model.graph.node(0).op_type== "Mul")
76+
77+
def test_avoid_remove_unused_inputs_initializers():
78+
# supress remove inputs in case they are initializers until explicitly said
79+
model = onnx.parser.parse_model(
80+
"""
81+
<ir_version: 10, opset_import: [ "" : 17]>
82+
agraph (float[N] x, float[N] two) => (float[N] z)
83+
<float two = {2.0,2.0}> {
84+
four = Add(two, two)
85+
z = Mul(x, x)
86+
}
8087
"""
81-
<ir_version: 10, opset_import: [ "" : 17]>
82-
agraph (float[N] x, float[N] two) => (float[N] z)
83-
<float two = {2.0,2.0}> {
84-
four = Add(two, two)
85-
z = Mul(x, x)
86-
}
87-
"""
88-
)
89-
ir_model = onnxscript.ir.serde.deserialize_model(model)
90-
remove_unused_nodes(ir_model,False)
91-
assert (len(ir_model.graph._nodes)== 1)
92-
assert (len(ir_model.graph.inputs)== 2)
93-
assert (ir_model.graph.node(0).op_type== "Mul")
94-
95-
def test_avoid_remove_unused_inputs():
96-
# preserve inputs as part of interface
97-
model = onnx.parser.parse_model(
88+
)
89+
ir_model = onnxscript.ir.serde.deserialize_model(model)
90+
ir_model = self.remove_unused_nodes(ir_model)
91+
assert (len(ir_model.graph._nodes)== 1)
92+
assert (len(ir_model.graph.inputs)== 2)
93+
assert (ir_model.graph.node(0).op_type== "Mul")
94+
95+
def test_avoid_remove_unused_inputs():
96+
# preserve inputs as part of interface
97+
model = onnx.parser.parse_model(
98+
"""
99+
<ir_version: 10, opset_import: [ "" : 17]>
100+
agraph (float[N] x, float[N] two) => (float[N] z)
101+
{
102+
four = Add(two, two)
103+
z = Mul(x, x)
104+
}
98105
"""
99-
<ir_version: 10, opset_import: [ "" : 17]>
100-
agraph (float[N] x, float[N] two) => (float[N] z)
101-
{
102-
four = Add(two, two)
103-
z = Mul(x, x)
104-
}
105-
"""
106-
)
107-
ir_model = onnxscript.ir.serde.deserialize_model(model)
108-
remove_unused_nodes(ir_model)
109-
assert (len(ir_model.graph._nodes)== 1)
110-
assert (len(ir_model.graph.inputs)== 2)
111-
assert (ir_model.graph.node(0).op_type== "Mul")
106+
)
107+
ir_model = onnxscript.ir.serde.deserialize_model(model)
108+
ir_model = self.remove_unused_nodes(ir_model,True)
109+
assert (len(ir_model.graph._nodes)== 1)
110+
assert (len(ir_model.graph.inputs)== 2)
111+
assert (ir_model.graph.node(0).op_type== "Mul")
112112

113113
def test_partially_used_nodes(self):
114114
model = onnx.parser.parse_model(

0 commit comments

Comments
 (0)