@@ -18,7 +18,9 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase):
1818 (ir .DataType .INT64 , np .int64 ),
1919 ]
2020 )
21- def test_pass_with_lifting_constants_to_initializers (self , ir_dtype , numpy_dtype ):
21+ def test_pass_with_lifting_float_and_int_constants_to_initializers (
22+ self , ir_dtype , numpy_dtype
23+ ):
2224 inputs = [
2325 ir .Value (name = "input_a" , type = ir .TensorType (ir_dtype ), shape = ir .Shape ((2 , 3 ))),
2426 ir .Value (
@@ -29,10 +31,11 @@ def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype
2931 ]
3032
3133 constant_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (numpy_dtype ))
32- attribute = ir .convenience .convert_attributes ({"value" : constant_tensor })
33- const_node = ir .Node ("" , "Constant" , inputs = [], attributes = attribute , num_outputs = 1 )
34- add_node = ir .Node ("" , "Add" , inputs = [inputs [0 ], const_node .outputs [0 ]])
35- mul_node = ir .Node ("" , "Mul" , inputs = [add_node .outputs [0 ], inputs [1 ]])
34+ const_node = ir .node (
35+ "Constant" , inputs = [], attributes = {"value" : constant_tensor }, num_outputs = 1
36+ )
37+ add_node = ir .node ("Add" , inputs = [inputs [0 ], const_node .outputs [0 ]])
38+ mul_node = ir .node ("Mul" , inputs = [add_node .outputs [0 ], inputs [1 ]])
3639
3740 model = ir .Model (
3841 graph = ir .Graph (
@@ -72,40 +75,34 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
7275 )
7376
7477 then_constant_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (np .float32 ))
75- attribute = ir .convenience .convert_attributes ({"value" : then_constant_tensor })
76- then_const_node = ir .Node (
77- "" , "Constant" , inputs = [], attributes = attribute , num_outputs = 1
78+ then_const_node = ir .node (
79+ "Constant" , inputs = [], attributes = {"value" : then_constant_tensor }, num_outputs = 1
7880 )
7981 # then branch adds the constant to the input
8082 # else branch multiplies the input by the constant
81- add_node = ir .Node ( "" , "Add" , inputs = [input_value , then_const_node .outputs [0 ]])
83+ add_node = ir .node ( "Add" , inputs = [input_value , then_const_node .outputs [0 ]])
8284 then_graph = ir .Graph (
8385 inputs = [input_value ],
8486 outputs = [add_node .outputs [0 ]],
8587 nodes = [then_const_node , add_node ],
8688 opset_imports = {"" : 20 },
8789 )
8890 else_constant_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (np .float32 ))
89- attribute = ir .convenience .convert_attributes ({"value" : else_constant_tensor })
90- else_const_node = ir .Node (
91- "" , "Constant" , inputs = [], attributes = attribute , num_outputs = 1
91+ else_const_node = ir .node (
92+ "Constant" , inputs = [], attributes = {"value" : else_constant_tensor }, num_outputs = 1
9293 )
93- mul_node = ir .Node ( "" , "Mul" , inputs = [input_value , else_const_node .outputs [0 ]])
94+ mul_node = ir .node ( "Mul" , inputs = [input_value , else_const_node .outputs [0 ]])
9495 else_graph = ir .Graph (
9596 inputs = [input_value ],
9697 outputs = [mul_node .outputs [0 ]],
9798 nodes = [else_const_node , mul_node ],
9899 opset_imports = {"" : 20 },
99100 )
100101 # create a conditional node that uses the then and else graphs
101- attribute = ir .convenience .convert_attributes (
102- {"then_branch" : then_graph , "else_branch" : else_graph }
103- )
104- cond_node = ir .Node (
105- "" ,
102+ cond_node = ir .node (
106103 "If" ,
107104 inputs = [input_value ],
108- attributes = attribute ,
105+ attributes = { "then_branch" : then_graph , "else_branch" : else_graph } ,
109106 num_outputs = 1 ,
110107 )
111108 # construnct the model
@@ -128,15 +125,69 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
128125 raise AssertionError (
129126 f"Constant node '{ node .name } ' was not lifted to initializers"
130127 )
131- if node .op_type == "Add" :
132- self .assertEqual (len (node .graph .initializers ), 1 )
133- self .assertEqual (
134- node .graph .initializers ["val_0" ].const_value ,
135- then_constant_tensor ,
136- )
137- if node .op_type == "Mul" :
138- self .assertEqual (len (node .graph .initializers ), 1 )
139- self .assertEqual (
140- node .graph .initializers ["val_0" ].const_value ,
141- else_constant_tensor ,
142- )
128+ self .assertEqual (len (else_graph .initializers ), 1 )
129+ self .assertEqual (len (then_graph .initializers ), 1 )
130+ self .assertEqual (
131+ else_graph .initializers ["val_0" ].const_value ,
132+ else_constant_tensor ,
133+ )
134+ self .assertEqual (
135+ then_graph .initializers ["val_0" ].const_value ,
136+ then_constant_tensor ,
137+ )
138+
139+ @parameterized .parameterized .expand (
140+ [
141+ (1.0 , "value_float" , np .float32 ),
142+ (1 , "value_int" , np .int64 ),
143+ ("hello world!" , "value_string" , np .bytes_ ),
144+ ([1.0 , 2.0 , 3.0 ], "value_floats" , np .float32 ),
145+ ([1 , 2 , 3 ], "value_ints" , np .int64 ),
146+ (["hello world!" , "thank you." ], "value_strings" , np .bytes_ ),
147+ ]
148+ )
149+ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings (
150+ self , value , constant_attribute , np_dtype
151+ ):
152+ input_value = ir .Value (
153+ name = "input" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
154+ )
155+
156+ constant_value = value
157+ const_node = ir .node (
158+ "Constant" ,
159+ inputs = [],
160+ attributes = {constant_attribute : constant_value },
161+ num_outputs = 1 ,
162+ )
163+ identity_node_constant = ir .node (
164+ "Identity" , inputs = [const_node .outputs [0 ]], num_outputs = 1
165+ )
166+ identity_node_input = ir .node ("Identity" , inputs = [input_value ], num_outputs = 1 )
167+
168+ model = ir .Model (
169+ graph = ir .Graph (
170+ inputs = [input_value ],
171+ outputs = [identity_node_input .outputs [0 ], identity_node_constant .outputs [0 ]],
172+ nodes = [identity_node_input , const_node , identity_node_constant ],
173+ opset_imports = {"" : 20 },
174+ ),
175+ ir_version = 10 ,
176+ )
177+
178+ # Check that the initializer is not in the graph yet
179+ assert len (model .graph .initializers ) == 0
180+ # And 1 constant node
181+ assert len ([node for node in model .graph if node .op_type == "Constant" ]) == 1
182+
183+ # Perform lift constants to initializers
184+ result = constant_manipulation .LiftConstantsToInitializersPass ()(model )
185+ assert result .modified
186+ # Check that the constant node is lifted to an initializer
187+ assert len (result .model .graph .initializers ) == 1
188+ self .assertTrue (
189+ np .array_equal (
190+ result .model .graph .initializers ["val_1" ].const_value .raw ,
191+ np .array (constant_value , dtype = np_dtype ),
192+ )
193+ )
0 commit comments