Skip to content

Commit 82c4016

Browse files
committed
address reviews
1 parent 2c87912 commit 82c4016

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414
class TestLiftConstantsToInitializersPass(unittest.TestCase):
1515
@parameterized.parameterized.expand(
1616
[
17-
(ir.DataType.FLOAT, np.float32),
18-
(ir.DataType.INT64, np.int64),
17+
(ir.DataType.FLOAT,),
18+
(ir.DataType.INT64,),
1919
]
2020
)
21-
def test_pass_with_lifting_float_and_int_constants_to_initializers(
22-
self, ir_dtype, numpy_dtype
23-
):
21+
def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
2422
inputs = [
2523
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
2624
ir.Value(
@@ -30,7 +28,7 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
3028
),
3129
]
3230

33-
constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype))
31+
constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy()))
3432
const_node = ir.node(
3533
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
3634
)
@@ -127,11 +125,11 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
127125
)
128126
self.assertEqual(len(else_graph.initializers), 1)
129127
self.assertEqual(len(then_graph.initializers), 1)
130-
self.assertEqual(
128+
self.assertIs(
131129
else_graph.initializers["val_0"].const_value,
132130
else_constant_tensor,
133131
)
134-
self.assertEqual(
132+
self.assertIs(
135133
then_graph.initializers["val_0"].const_value,
136134
then_constant_tensor,
137135
)

0 commit comments

Comments
 (0)