Skip to content

Commit 318025b

Browse files
committed
Fix test
1 parent 22f5877 commit 318025b

1 file changed

Lines changed: 5 additions & 18 deletions

File tree

onnxscript/ir/passes/common/shape_inference_test.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
import numpy as np
88

99
from onnxscript import ir
10-
from onnxscript.ir.passes.common import shape_inference
10+
from onnxscript.ir.passes.common import _c_api_utils, shape_inference
1111

1212

1313
class TestShapeInferencePass(unittest.TestCase):
14+
def test_pass_is_in_place(self):
15+
self.assertTrue(shape_inference.ShapeInferencePass().in_place)
16+
1417
def test_pass(self):
1518
# Create a simple ONNX model with shape inference
1619
# Define the model
@@ -51,7 +54,7 @@ def test_pass_with_initializers(self):
5154
# _BIG_TENSOR_SIZE_LIMIT is in bytes, but we create big_dim as size
5255
# of a tensor. This is fine as we just need to create a big tensor whose size
5356
# passes _BIG_TENSOR_SIZE_LIMIT
54-
big_dim = shape_inference._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access
57+
big_dim = _c_api_utils._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access
5558
inputs = [
5659
ir.Value(
5760
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
@@ -129,22 +132,6 @@ def test_pass_with_initializers(self):
129132
ir.DataType.FLOAT,
130133
)
131134

132-
# Check that the original model is not modified
133-
self.assertIsNone(val_add.shape)
134-
self.assertIsNone(val_add.dtype)
135-
self.assertIsNone(val_mul.shape)
136-
self.assertIsNone(val_mul.dtype)
137-
self.assertEqual(len(model.graph.inputs), 2)
138-
self.assertEqual(len(model.graph.initializers), 2)
139-
self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value)
140-
self.assertEqual(len(model.graph.outputs), 1)
141-
self.assertEqual(model.graph.outputs[0].shape, None)
142-
self.assertEqual(model.graph.outputs[0].dtype, None)
143-
# Check that the initializer is not modified
144-
self.assertIs(
145-
model.graph.initializers["initializer"].const_value, initializer.const_value
146-
)
147-
148135

149136
if __name__ == "__main__":
150137
unittest.main()

0 commit comments

Comments
 (0)