Skip to content

Commit a2e9d2a

Browse files
committed
assert to self.assert
1 parent 82c4016 commit a2e9d2a

1 file changed

Lines changed: 7 additions & 9 deletions

File tree

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,18 +174,16 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
174174
)
175175

176176
# Check that the initializer is not in the graph yet
177-
assert len(model.graph.initializers) == 0
177+
self.assertEqual(len(model.graph.initializers), 0)
178178
# And 1 constant node
179-
assert len([node for node in model.graph if node.op_type == "Constant"]) == 1
179+
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
180180

181181
# Perform lift constants to initializers
182182
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
183-
assert result.modified
183+
self.assertTrue(result.modified)
184184
# Check that the constant node is lifted to an initializer
185-
assert len(result.model.graph.initializers) == 1
186-
self.assertTrue(
187-
np.array_equal(
188-
result.model.graph.initializers["val_1"].const_value.raw,
189-
np.array(constant_value, dtype=np_dtype),
190-
)
185+
self.assertEqual(len(result.model.graph.initializers), 1)
186+
np.testing.assert_array_equal(
187+
result.model.graph.initializers["val_1"].const_value.raw,
188+
np.array(constant_value, dtype=np_dtype),
191189
)

0 commit comments

Comments
 (0)