1414class TestLiftConstantsToInitializersPass (unittest .TestCase ):
1515 @parameterized .parameterized .expand (
1616 [
17- (ir .DataType .FLOAT , np . float32 ),
18- (ir .DataType .INT64 , np . int64 ),
17+ (ir .DataType .FLOAT ,),
18+ (ir .DataType .INT64 ,),
1919 ]
2020 )
21- def test_pass_with_lifting_float_and_int_constants_to_initializers (
22- self , ir_dtype , numpy_dtype
23- ):
21+ def test_pass_with_lifting_float_and_int_constants_to_initializers (self , ir_dtype ):
2422 inputs = [
2523 ir .Value (name = "input_a" , type = ir .TensorType (ir_dtype ), shape = ir .Shape ((2 , 3 ))),
2624 ir .Value (
@@ -30,7 +28,7 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
3028 ),
3129 ]
3230
33- constant_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (numpy_dtype ))
31+ constant_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (ir_dtype . numpy () ))
3432 const_node = ir .node (
3533 "Constant" , inputs = [], attributes = {"value" : constant_tensor }, num_outputs = 1
3634 )
@@ -127,11 +125,11 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
127125 )
128126 self .assertEqual (len (else_graph .initializers ), 1 )
129127 self .assertEqual (len (then_graph .initializers ), 1 )
130- self .assertEqual (
128+ self .assertIs (
131129 else_graph .initializers ["val_0" ].const_value ,
132130 else_constant_tensor ,
133131 )
134- self .assertEqual (
132+ self .assertIs (
135133 then_graph .initializers ["val_0" ].const_value ,
136134 then_constant_tensor ,
137135 )
0 commit comments