@@ -54,6 +54,62 @@ def test_remove_unused_initializers(self):
5454 self .assertEqual (model .graph .node [0 ].op_type , "Mul" )
5555 self .assertEqual (len (model .graph .initializer ), 0 )
5656
57+ def test_remove_unused_inputs_initializers ():
58+ # remove inputs in case they are initializers
59+ # https://github.com/microsoft/onnxscript/issues/2211
60+ model = onnx .parser .parse_model (
61+ """
62+ <ir_version: 10, opset_import: [ "" : 17]>
63+ agraph (float[N] x, float[N] two) => (float[N] z)
64+ <float two = {2.0,2.0}> {
65+ four = Add(two, two)
66+ z = Mul(x, x)
67+ }
68+ """
69+ )
70+ ir_model = onnxscript .ir .serde .deserialize_model (model )
71+ remove_unused_nodes (ir_model )
72+ assert (len (ir_model .graph ._nodes )== 1 )
73+ assert (len (ir_model .graph .inputs )== 1 )
74+ assert (ir_model .graph .node (0 ).op_type == "Mul" )
75+
76+ def test_avoid_remove_unused_inputs_initializers ():
77+ # supress remove inputs in case they are initializers
78+ # if explicitly said
79+ model = onnx .parser .parse_model (
80+ """
81+ <ir_version: 10, opset_import: [ "" : 17]>
82+ agraph (float[N] x, float[N] two) => (float[N] z)
83+ <float two = {2.0,2.0}> {
84+ four = Add(two, two)
85+ z = Mul(x, x)
86+ }
87+ """
88+ )
89+ ir_model = onnxscript .ir .serde .deserialize_model (model )
90+ remove_unused_nodes (ir_model ,False )
91+ assert (len (ir_model .graph ._nodes )== 1 )
92+ assert (len (ir_model .graph .inputs )== 2 )
93+ assert (ir_model .graph .node (0 ).op_type == "Mul" )
94+
95+ def test_avoid_remove_unused_inputs ():
96+ # preserve inputs as part of interface
97+ model = onnx .parser .parse_model (
98+ """
99+ <ir_version: 10, opset_import: [ "" : 17]>
100+ agraph (float[N] x, float[N] two) => (float[N] z)
101+ {
102+ four = Add(two, two)
103+ z = Mul(x, x)
104+ }
105+ """
106+ )
107+ ir_model = onnxscript .ir .serde .deserialize_model (model )
108+ remove_unused_nodes (ir_model )
109+ assert (len (ir_model .graph ._nodes )== 1 )
110+ assert (len (ir_model .graph .inputs )== 2 )
111+ assert (ir_model .graph .node (0 ).op_type == "Mul" )
112+
57113 def test_partially_used_nodes (self ):
58114 model = onnx .parser .parse_model (
59115 """
0 commit comments