diff --git a/docs/ir/ir_api/core.md b/docs/ir/ir_api/core.md index fb3f98edd6..ad11a9a751 100644 --- a/docs/ir/ir_api/core.md +++ b/docs/ir/ir_api/core.md @@ -16,6 +16,7 @@ ir.load ir.save ir.from_proto + ir.from_onnx_text ir.to_proto ir.tensor ir.node diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 3c96f0eeeb..b5daebe235 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -71,6 +71,7 @@ "TensorProtoTensor", # Conversion functions "from_proto", + "from_onnx_text", "to_proto", # Convenience constructors "tensor", @@ -144,7 +145,7 @@ TypeProtocol, ValueProtocol, ) -from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto +from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto def __set_module() -> None: diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py index 7a64a8d4b4..1a4be6ce8e 100644 --- a/onnxscript/ir/passes/common/inliner_test.py +++ b/onnxscript/ir/passes/common/inliner_test.py @@ -8,7 +8,6 @@ from typing import Callable, Sequence import onnx -from onnx import parser from onnxscript import ir from onnxscript.ir.passes.common import inliner @@ -44,14 +43,12 @@ def _check( self, input_model: str, expected_model: str, renameable: Sequence[str] | None = None ) -> None: name_check = _name_checker(renameable) - model_proto = parser.parse_model(input_model) - model_ir = ir.serde.deserialize_model(model_proto) + model_ir = ir.from_onnx_text(input_model) inliner.InlinePass()(model_ir) proto = ir.serde.serialize_model(model_ir) text = onnx.printer.to_text(proto) print(text) - expected_proto = parser.parse_model(expected_model) - expected_ir = ir.serde.deserialize_model(expected_proto) + expected_ir = ir.from_onnx_text(expected_model) self.assertEqual(len(model_ir.graph), len(expected_ir.graph)) for node, expected_node in zip(model_ir.graph, expected_ir.graph): # TODO: handle node renaming diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index ede4e14974..b5be445aef 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -21,6 +21,7 @@ "TensorProtoTensor", # Deserialization "from_proto", + "from_onnx_text", "deserialize_attribute", "deserialize_dimension", "deserialize_function", @@ -190,6 +191,15 @@ def from_proto(proto: object) -> object: ) +def from_onnx_text(model_text: str, /) -> _core.Model: + """Convert the ONNX textual representation to an IR model. + + Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html + """ + proto = onnx.parser.parse_model(model_text) + return deserialize_model(proto) + + @typing.overload def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap] @typing.overload diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 81ed911c9e..5a98cb5d51 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -13,16 +13,10 @@ from onnxscript.optimizer import _constant_folding -def _create_model(model_text: str) -> ir.Model: - """Create a model from the given text.""" - model = onnx.parser.parse_model(model_text) - return ir.serde.deserialize_model(model) - - class FoldConstantsTest(unittest.TestCase): def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs): if isinstance(model, str): - model = _create_model(model) + model = ir.from_onnx_text(model) _constant_folding.fold_constants( model, onnx_shape_inference=onnx_shape_inference, **kwargs ) @@ -552,7 +546,7 @@ def test_large_transpose(self): z = MatMul (x, wt) } """ - model = _create_model(model_text) + model = ir.from_onnx_text(model_text) w = model.graph.initializers["w"] w.shape = ir.Shape([512, 256]) w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32)) diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/no_op_test.py index 4e509e7f3a..2b2a57f32a 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/no_op_test.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import unittest -import onnx.parser import parameterized from onnxscript import ir @@ -11,8 +10,7 @@ class NoOpTest(unittest.TestCase): def _check(self, model_text: str) -> None: - model_proto = onnx.parser.parse_model(model_text) - model = ir.serde.deserialize_model(model_proto) + model = ir.from_onnx_text(model_text) count = no_op.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(model.graph[-1].op_type, "Identity") diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index 3c73498230..2726dc1a4e 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -5,7 +5,6 @@ import unittest import onnx.defs -import onnx.parser from onnxscript import ir, version_converter @@ -43,7 +42,7 @@ def test_upstream_coverage(self): self.assertIn((name, upgrade_version), op_upgrades) def test_version_convert_non_standard_onnx_domain(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) @@ -58,7 +57,6 @@ def test_version_convert_non_standard_onnx_domain(self): } """ ) - model = ir.serde.deserialize_model(model_proto) self.assertEqual(model.graph.node(4).op_type, "GridSample") self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") @@ -76,7 +74,7 @@ def test_version_convert_non_standard_onnx_domain(self): class VersionConverter18to17Test(unittest.TestCase): def test_version_convert_compatible(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) @@ -91,14 +89,13 @@ def test_version_convert_compatible(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 17 version_converter.convert_version(model, target_version=target_version) class VersionConverter18to19Test(unittest.TestCase): def test_version_convert_compatible(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) @@ -113,7 +110,6 @@ def test_version_convert_compatible(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 19 version_converter.convert_version(model, target_version=target_version) @@ -127,7 +123,7 @@ def test_version_convert_compatible(self): class VersionConverter19to20Test(unittest.TestCase): def test_version_convert_compatible(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output) @@ -140,7 +136,6 @@ def test_version_convert_compatible(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 20 version_converter.convert_version(model, target_version=target_version) @@ -155,7 +150,7 @@ def test_version_convert_compatible(self): self.assertEqual(len(model.graph.node(3).inputs), 2) def test_version_convert_gridsample_linear(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) @@ -170,7 +165,6 @@ def test_version_convert_gridsample_linear(self): } """ ) - model = ir.serde.deserialize_model(model_proto) self.assertEqual(model.graph.node(4).op_type, "GridSample") self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") @@ -186,7 +180,7 @@ def test_version_convert_gridsample_linear(self): self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") def test_version_convert_gridsample_cubic(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) @@ -201,7 +195,6 @@ def test_version_convert_gridsample_cubic(self): } """ ) - model = ir.serde.deserialize_model(model_proto) self.assertEqual(model.graph.node(4).op_type, "GridSample") self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic") @@ -217,7 +210,7 @@ def test_version_convert_gridsample_cubic(self): self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic") def test_version_convert_inline(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output) @@ -236,7 +229,6 @@ def test_version_convert_inline(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 20 version_converter.convert_version(model, target_version=target_version) @@ -254,7 +246,7 @@ def test_version_convert_inline(self): class VersionConverter20to21Test(unittest.TestCase): def test_version_groupnorm(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output) @@ -265,7 +257,6 @@ def test_version_groupnorm(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 21 version_converter.convert_version(model, target_version=target_version) @@ -285,7 +276,7 @@ def test_version_groupnorm(self): self.assertEqual(model.graph.node(9).version, 21) def test_version_groupnorm_no_bias(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output) @@ -296,7 +287,6 @@ def test_version_groupnorm_no_bias(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 21 version_converter.convert_version(model, target_version=target_version) @@ -306,7 +296,7 @@ def test_version_groupnorm_no_bias(self): class VersionConverter23to24Test(unittest.TestCase): def test_version_convert_compatible(self): - model_proto = onnx.parser.parse_model( + model = ir.from_onnx_text( """ agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) @@ -321,7 +311,6 @@ def test_version_convert_compatible(self): } """ ) - model = ir.serde.deserialize_model(model_proto) target_version = 24 version_converter.convert_version(model, target_version=target_version)