|
7 | 7 | import numpy as np |
8 | 8 |
|
9 | 9 | from onnxscript import ir |
10 | | -from onnxscript.ir.passes.common import shape_inference |
| 10 | +from onnxscript.ir.passes.common import _c_api_utils, shape_inference |
11 | 11 |
|
12 | 12 |
|
13 | 13 | class TestShapeInferencePass(unittest.TestCase): |
| 14 | + def test_pass_is_in_place(self): |
| 15 | + self.assertTrue(shape_inference.ShapeInferencePass().in_place) |
| 16 | + |
14 | 17 | def test_pass(self): |
15 | 18 | # Create a simple ONNX model with shape inference |
16 | 19 | # Define the model |
@@ -51,7 +54,7 @@ def test_pass_with_initializers(self): |
51 | 54 | # _BIG_TENSOR_SIZE_LIMIT is in bytes, but we create big_dim as size |
52 | 55 | # of a tensor. This is fine as we just need to create a big tensor whose size |
53 | 56 | # passes _BIG_TENSOR_SIZE_LIMIT |
54 | | - big_dim = shape_inference._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access |
| 57 | + big_dim = _c_api_utils._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access |
55 | 58 | inputs = [ |
56 | 59 | ir.Value( |
57 | 60 | name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) |
@@ -129,22 +132,6 @@ def test_pass_with_initializers(self): |
129 | 132 | ir.DataType.FLOAT, |
130 | 133 | ) |
131 | 134 |
|
132 | | - # Check that the original model is not modified |
133 | | - self.assertIsNone(val_add.shape) |
134 | | - self.assertIsNone(val_add.dtype) |
135 | | - self.assertIsNone(val_mul.shape) |
136 | | - self.assertIsNone(val_mul.dtype) |
137 | | - self.assertEqual(len(model.graph.inputs), 2) |
138 | | - self.assertEqual(len(model.graph.initializers), 2) |
139 | | - self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value) |
140 | | - self.assertEqual(len(model.graph.outputs), 1) |
141 | | - self.assertEqual(model.graph.outputs[0].shape, None) |
142 | | - self.assertEqual(model.graph.outputs[0].dtype, None) |
143 | | - # Check that the initializer is not modified |
144 | | - self.assertIs( |
145 | | - model.graph.initializers["initializer"].const_value, initializer.const_value |
146 | | - ) |
147 | | - |
148 | 135 |
|
149 | 136 | if __name__ == "__main__": |
150 | 137 | unittest.main() |
0 commit comments