1414class TestLiftConstantsToInitializersPass (unittest .TestCase ):
1515 @parameterized .parameterized .expand (
1616 [
17- (ir .DataType .FLOAT ,),
18- (ir .DataType .INT64 ,),
17+ (ir .DataType .FLOAT , True ),
18+ (ir .DataType .FLOAT , False ),
19+ (ir .DataType .INT64 , True ),
20+ (ir .DataType .INT64 , False ),
1921 ]
2022 )
21- def test_pass_with_lifting_float_and_int_constants_to_initializers (self , ir_dtype ):
23+ def test_pass_with_lifting_float_and_int_constants_to_initializers (
24+ self , ir_dtype : ir .DataType , lift_all_constants : bool
25+ ):
2226 inputs = [
2327 ir .Value (name = "input_a" , type = ir .TensorType (ir_dtype ), shape = ir .Shape ((2 , 3 ))),
2428 ir .Value (
@@ -51,7 +55,9 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp
5155 self .assertEqual (len ([node for node in model .graph if node .op_type == "Constant" ]), 1 )
5256
5357 # Perform lift constants to initializers
54- result = constant_manipulation .LiftConstantsToInitializersPass ()(model )
58+ result = constant_manipulation .LiftConstantsToInitializersPass (
59+ lift_all_constants = lift_all_constants
60+ )(model )
5561 self .assertTrue (result .modified )
5662 # Check that the constant node is lifted to an initializer
5763 self .assertEqual (len (result .model .graph .initializers ), 1 )
@@ -67,7 +73,15 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp
6773 len ([node for node in result .model .graph if node .op_type == "Constant" ]), 0
6874 )
6975
70- def test_pass_with_lifting_constants_to_initializers_within_subgraph (self ):
76+ @parameterized .parameterized .expand (
77+ [
78+ (True ,),
79+ (False ,),
80+ ]
81+ )
82+ def test_pass_with_lifting_constants_to_initializers_within_subgraph (
83+ self , lift_all_constants : bool
84+ ):
7185 input_value = ir .Value (
7286 name = "input" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
7387 )
@@ -115,7 +129,9 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
115129 graph = main_graph ,
116130 ir_version = 10 ,
117131 )
118- result = constant_manipulation .LiftConstantsToInitializersPass ()(model )
132+ result = constant_manipulation .LiftConstantsToInitializersPass (
133+ lift_all_constants = lift_all_constants
134+ )(model )
119135 self .assertTrue (result .modified )
120136 # Check that the constant node is lifted to the subgraph initializers
121137 for node in ir .traversal .RecursiveGraphIterator (result .model .graph ):
@@ -136,16 +152,26 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
136152
137153 @parameterized .parameterized .expand (
138154 [
139- (1.0 , "value_float" , np .float32 ),
140- (1 , "value_int" , np .int64 ),
141- ("hello world!" , "value_string" , np .bytes_ ),
142- ([1.0 , 2.0 , 3.0 ], "value_floats" , np .float32 ),
143- ([1 , 2 , 3 ], "value_ints" , np .int64 ),
144- (["hello world!" , "thank you." ], "value_strings" , np .bytes_ ),
155+ (1.0 , "value_float" , np .float32 , True ),
156+ (1.0 , "value_float" , np .float32 , False ),
157+ (1 , "value_int" , np .int64 , True ),
158+ (1 , "value_int" , np .int64 , False ),
159+ ("hello world!" , "value_string" , np .bytes_ , True ),
160+ ("hello world!" , "value_string" , np .bytes_ , False ),
161+ ([1.0 , 2.0 , 3.0 ], "value_floats" , np .float32 , True ),
162+ ([1.0 , 2.0 , 3.0 ], "value_floats" , np .float32 , False ),
163+ ([1 , 2 , 3 ], "value_ints" , np .int64 , True ),
164+ ([1 , 2 , 3 ], "value_ints" , np .int64 , False ),
165+ (["hello world!" , "thank you." ], "value_strings" , np .bytes_ , True ),
166+ (["hello world!" , "thank you." ], "value_strings" , np .bytes_ , False ),
145167 ]
146168 )
147169 def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings (
148- self , value , constant_attribute , np_dtype
170+ self ,
171+ value : float | int | str | list [float ] | list [int ] | list [str ],
172+ constant_attribute : str ,
173+ np_dtype : type [np .dtype ],
174+ lift_all_constants : bool ,
149175 ):
150176 input_value = ir .Value (
151177 name = "input" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
@@ -179,11 +205,47 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
179205 self .assertEqual (len ([node for node in model .graph if node .op_type == "Constant" ]), 1 )
180206
181207 # Perform lift constants to initializers
182- result = constant_manipulation .LiftConstantsToInitializersPass ()(model )
183- self .assertTrue (result .modified )
184- # Check that the constant node is lifted to an initializer
185- self .assertEqual (len (result .model .graph .initializers ), 1 )
186- np .testing .assert_array_equal (
187- result .model .graph .initializers ["val_1" ].const_value .numpy (),
188- np .array (constant_value , dtype = np_dtype ),
208+ result = constant_manipulation .LiftConstantsToInitializersPass (
209+ lift_all_constants = lift_all_constants
210+ )(model )
211+ if lift_all_constants :
212+ self .assertTrue (result .modified )
213+ # Check that the constant node is lifted to an initializer
214+ self .assertEqual (len (result .model .graph .initializers ), 1 )
215+ np .testing .assert_array_equal (
216+ result .model .graph .initializers ["val_1" ].const_value .numpy (),
217+ np .array (constant_value , dtype = np_dtype ),
218+ )
219+ else :
220+ self .assertFalse (result .modified )
221+ # Check that the constant node is not lifted to an initializer
222+ self .assertEqual (len (result .model .graph .initializers ), 0 )
223+
224+ def test_not_lifting_constants_to_initializers_when_it_is_output (self ):
225+ input_value = ir .Value (
226+ name = "input" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
189227 )
228+ identity_node_input = ir .node ("Identity" , inputs = [input_value ], num_outputs = 1 )
229+
230+ constant_value = ir .tensor (np .random .rand (2 , 3 ).astype (np .float32 ))
231+ const_node = ir .node (
232+ "Constant" ,
233+ inputs = [],
234+ attributes = {"value" : constant_value },
235+ num_outputs = 1 ,
236+ )
237+
238+ model = ir .Model (
239+ graph = ir .Graph (
240+ inputs = [input_value ],
241+ outputs = [identity_node_input .outputs [0 ], const_node .outputs [0 ]],
242+ nodes = [identity_node_input , const_node ],
243+ opset_imports = {"" : 20 },
244+ ),
245+ ir_version = 10 ,
246+ )
247+
248+ result = constant_manipulation .LiftConstantsToInitializersPass ()(model )
249+ self .assertFalse (result .modified )
250+ # Check that the constant node is not lifted to an initializer
251+ self .assertEqual (len (result .model .graph .initializers ), 0 )
0 commit comments