Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/ir/ir_api/core.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ir.load
ir.save
ir.from_proto
ir.from_onnx_text
ir.to_proto
ir.tensor
ir.node
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"TensorProtoTensor",
# Conversion functions
"from_proto",
"from_onnx_text",
"to_proto",
# Convenience constructors
"tensor",
Expand Down Expand Up @@ -144,7 +145,7 @@
TypeProtocol,
ValueProtocol,
)
from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto
from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto


def __set_module() -> None:
Expand Down
7 changes: 2 additions & 5 deletions onnxscript/ir/passes/common/inliner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Callable, Sequence

import onnx
from onnx import parser

from onnxscript import ir
from onnxscript.ir.passes.common import inliner
Expand Down Expand Up @@ -44,14 +43,12 @@ def _check(
self, input_model: str, expected_model: str, renameable: Sequence[str] | None = None
) -> None:
name_check = _name_checker(renameable)
model_proto = parser.parse_model(input_model)
model_ir = ir.serde.deserialize_model(model_proto)
model_ir = ir.from_onnx_text(input_model)
inliner.InlinePass()(model_ir)
proto = ir.serde.serialize_model(model_ir)
text = onnx.printer.to_text(proto)
print(text)
expected_proto = parser.parse_model(expected_model)
expected_ir = ir.serde.deserialize_model(expected_proto)
expected_ir = ir.from_onnx_text(expected_model)
self.assertEqual(len(model_ir.graph), len(expected_ir.graph))
for node, expected_node in zip(model_ir.graph, expected_ir.graph):
# TODO: handle node renaming
Expand Down
10 changes: 10 additions & 0 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"TensorProtoTensor",
# Deserialization
"from_proto",
"from_onnx_text",
"deserialize_attribute",
"deserialize_dimension",
"deserialize_function",
Expand Down Expand Up @@ -190,6 +191,15 @@ def from_proto(proto: object) -> object:
)


def from_onnx_text(model_text: str, /) -> _core.Model:
"""Convert the ONNX textual representation to an IR model.

Comment thread
justinchuby marked this conversation as resolved.
Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
"""
proto = onnx.parser.parse_model(model_text)
return deserialize_model(proto)


@typing.overload
def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap]
@typing.overload
Expand Down
10 changes: 2 additions & 8 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,10 @@
from onnxscript.optimizer import _constant_folding


def _create_model(model_text: str) -> ir.Model:
"""Create a model from the given text."""
model = onnx.parser.parse_model(model_text)
return ir.serde.deserialize_model(model)


class FoldConstantsTest(unittest.TestCase):
def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
if isinstance(model, str):
model = _create_model(model)
model = ir.from_onnx_text(model)
_constant_folding.fold_constants(
model, onnx_shape_inference=onnx_shape_inference, **kwargs
)
Expand Down Expand Up @@ -552,7 +546,7 @@ def test_large_transpose(self):
z = MatMul (x, wt)
}
"""
model = _create_model(model_text)
model = ir.from_onnx_text(model_text)
w = model.graph.initializers["w"]
w.shape = ir.Shape([512, 256])
w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
Expand Down
4 changes: 1 addition & 3 deletions onnxscript/rewriter/no_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.
import unittest

import onnx.parser
import parameterized

from onnxscript import ir
Expand All @@ -11,8 +10,7 @@

class NoOpTest(unittest.TestCase):
def _check(self, model_text: str) -> None:
model_proto = onnx.parser.parse_model(model_text)
model = ir.serde.deserialize_model(model_proto)
model = ir.from_onnx_text(model_text)
count = no_op.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(model.graph[-1].op_type, "Identity")
Expand Down
31 changes: 10 additions & 21 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import unittest

import onnx.defs
import onnx.parser

from onnxscript import ir, version_converter

Expand Down Expand Up @@ -43,7 +42,7 @@ def test_upstream_coverage(self):
self.assertIn((name, upgrade_version), op_upgrades)

def test_version_convert_non_standard_onnx_domain(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "local" : 1]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
Expand All @@ -58,7 +57,6 @@ def test_version_convert_non_standard_onnx_domain(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")

Expand All @@ -76,7 +74,7 @@ def test_version_convert_non_standard_onnx_domain(self):

class VersionConverter18to17Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
Expand All @@ -91,14 +89,13 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 17
version_converter.convert_version(model, target_version=target_version)


class VersionConverter18to19Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
Expand All @@ -113,7 +110,6 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 19
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -127,7 +123,7 @@ def test_version_convert_compatible(self):

class VersionConverter19to20Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
Expand All @@ -140,7 +136,6 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 20
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -155,7 +150,7 @@ def test_version_convert_compatible(self):
self.assertEqual(len(model.graph.node(3).inputs), 2)

def test_version_convert_gridsample_linear(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
Expand All @@ -170,7 +165,6 @@ def test_version_convert_gridsample_linear(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")

Expand All @@ -186,7 +180,7 @@ def test_version_convert_gridsample_linear(self):
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")

def test_version_convert_gridsample_cubic(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
Expand All @@ -201,7 +195,6 @@ def test_version_convert_gridsample_cubic(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic")

Expand All @@ -217,7 +210,7 @@ def test_version_convert_gridsample_cubic(self):
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")

def test_version_convert_inline(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 8, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
Expand All @@ -236,7 +229,6 @@ def test_version_convert_inline(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 20
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -254,7 +246,7 @@ def test_version_convert_inline(self):

class VersionConverter20to21Test(unittest.TestCase):
def test_version_groupnorm(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output)
Expand All @@ -265,7 +257,6 @@ def test_version_groupnorm(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 21
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -285,7 +276,7 @@ def test_version_groupnorm(self):
self.assertEqual(model.graph.node(9).version, 21)

def test_version_groupnorm_no_bias(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output)
Expand All @@ -296,7 +287,6 @@ def test_version_groupnorm_no_bias(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 21
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -306,7 +296,7 @@ def test_version_groupnorm_no_bias(self):

class VersionConverter23to24Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 23]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
Expand All @@ -321,7 +311,6 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 24
version_converter.convert_version(model, target_version=target_version)

Expand Down
Loading