Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion onnxscript/_framework_apis/torch_2_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"save_model_with_external_data",
"torchlib_opset",
]
import logging
from typing import TYPE_CHECKING

from onnxscript import ir, optimizer, version_converter
Expand All @@ -25,6 +26,9 @@
from onnxscript.onnx_opset._impl.opset18 import Opset18


logger = logging.getLogger(__name__)


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""
optimizer.optimize_ir(model)
Expand All @@ -34,8 +38,9 @@
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
"""Convert the model to the specified ONNX opset version."""
if target_version < 18:
logger.warning("Conversion to opset < 18 is not supported.")

Check warning on line 41 in onnxscript/_framework_apis/torch_2_6.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_6.py#L41

Added line #L41 was not covered by tests
return model
version_converter.convert_version(model, target_version)
version_converter.convert_version(model, target_version, fallback=True)

Check warning on line 43 in onnxscript/_framework_apis/torch_2_6.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_framework_apis/torch_2_6.py#L43

Added line #L43 was not covered by tests
return model


Expand Down
154 changes: 146 additions & 8 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,157 @@
from __future__ import annotations

__all__ = [
# Functions
"ConvertVersionPass",
"convert_version",
]

import onnxscript.optimizer
import logging

import onnx

from onnxscript import ir
from onnxscript.ir.passes.common import _c_api_utils
from onnxscript.ir.passes.common import inliner as _inliner
from onnxscript.ir.passes.common import unused_removal as _unused_removal
from onnxscript.version_converter import _version_converter

logger = logging.getLogger(__name__)


class ConvertVersionPass(ir.passes.InPlacePass):
"""Convert the model to the specified ONNX opset version.

This pass leverages the onnxscript version converter to convert the model. If
the conversion is not supported, it falls back to the onnx C API to convert
the model. This pass is in-place.

The pass is an no-op if the c-api fails.

Attributes:
target_version: The target ONNX opset version to convert the model to.
fallback: Whether to fallback to the onnx version converter if the
target version is not supported. Default is False.
"""

def __init__(self, target_version: int, fallback: bool = False) -> None:
super().__init__()
self.target_version = target_version
self.fallback = fallback
self.convert_pass = ir.passes.Sequential(
_inliner.InlinePass(),
_ConvertVersionPassRequiresInline(
target_version=target_version,
fallback=fallback,
),
_unused_removal.RemoveUnusedNodesPass(),
_unused_removal.RemoveUnusedFunctionsPass(),
_unused_removal.RemoveUnusedOpsetsPass(),
)

def call(self, model: ir.Model) -> ir.passes.PassResult:
return self.convert_pass(model)


class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
"""Convert the model to the specified ONNX opset version.

This pass leverages the onnxscript version converter to convert the model. If
the conversion is not supported, it falls back to the onnx C API to convert
the model. This pass is in-place.

The pass is an no-op if the c-api fails.

Attributes:
target_version: The target ONNX opset version to convert the model to.
fallback: Whether to fallback to the onnx version converter if the
target version is not supported. Default is False.
"""

def __init__(self, target_version: int, fallback: bool = False) -> None:
super().__init__()
self.target_version = target_version
self.fallback = fallback

def call(self, model: ir.Model) -> ir.passes.PassResult:
if model.functions:
raise ValueError(

Check warning on line 79 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L79

Added line #L79 was not covered by tests
"The model contains functions. The version conversion pass does not support "
"functions. Please use `onnxscript.ir.passes.common.inliner.InlinePass` to inline the "
f"functions before applying this pass ({self.__class__.__name__})."
)
if "" in model.graph.opset_imports:
onnx_opset_version = model.graph.opset_imports[""]
if onnx_opset_version == self.target_version:
# No need to convert the version
return ir.passes.PassResult(model, False)

Check warning on line 88 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L88

Added line #L88 was not covered by tests

# When fallback is disabled, always use the onnxscript version converter;
# When fallback is enabled, use the onnxscript version converter
# if the target version is supported. Otherwise, use the onnx C API
# to convert the model.
if not self.fallback or _version_converter.version_supported(
model, self.target_version
):
_version_converter.convert_version(
model,
target_version=self.target_version,
)
return ir.passes.PassResult(model, True)

if not self.fallback:
logger.warning(

Check warning on line 104 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L104

Added line #L104 was not covered by tests
"The model version conversion is not supported by the onnxscript version converter "
"and fallback is disabled. The model was not modified"
" (target version: %d). "
"Set fallback=True to enable fallback to the onnx c-api version converter.",
self.target_version,
)
return ir.passes.PassResult(model, False)

Check warning on line 111 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L111

Added line #L111 was not covered by tests

# If the onnxscript version converter does not support the conversion,
# we can use the onnx C API to convert the model
def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:

Check warning on line 115 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L115

Added line #L115 was not covered by tests
"""Partial function to check the model."""
return onnx.version_converter.convert_version(

Check warning on line 117 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L117

Added line #L117 was not covered by tests
proto, target_version=self.target_version
)

try:
converted_proto = _c_api_utils.call_onnx_api(

Check warning on line 122 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L121-L122

Added lines #L121 - L122 were not covered by tests
func=_partial_convert_version, model=model
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning(

Check warning on line 126 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L125-L126

Added lines #L125 - L126 were not covered by tests
"Failed to convert the model to the target version %d using the ONNX C API. "
"The model was not modified",
self.target_version,
exc_info=e,
)
return ir.passes.PassResult(model, False)

Check warning on line 132 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L132

Added line #L132 was not covered by tests

converted_model = ir.from_proto(converted_proto)

Check warning on line 134 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L134

Added line #L134 was not covered by tests

# Recover the initializers in the converted model
for input in converted_model.graph.inputs:
if input.name in model.graph.initializers:
input.const_value = model.graph.initializers[input.name].const_value
converted_model.graph.register_initializer(input)
user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)]
converted_model.graph.inputs.clear()
converted_model.graph.inputs.extend(user_inputs)

Check warning on line 143 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L139-L143

Added lines #L139 - L143 were not covered by tests

# Return the converted graph to the original model to keep the pass in-place
model.graph = converted_model.graph
return ir.passes.PassResult(model, True)

Check warning on line 147 in onnxscript/version_converter/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/__init__.py#L146-L147

Added lines #L146 - L147 were not covered by tests


def convert_version(model: ir.Model, target_version: int) -> None:
"""Convert the model to the specified ONNX opset version."""
def convert_version(model: ir.Model, target_version: int, fallback=False) -> None:
Comment thread
justinchuby marked this conversation as resolved.
"""Convert the model to the specified ONNX opset version.

# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
# Hence, we inline all the functions.
onnxscript.optimizer.inline(model)
_version_converter.convert_version(model, target_version)
Args:
model: The model to convert.
target_version: The target ONNX opset version.
fallback: Whether to fallback to the onnx version converter if the
target version is not supported. Default is False.
"""
ConvertVersionPass(target_version=target_version, fallback=fallback)(model)
19 changes: 17 additions & 2 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
logger = logging.getLogger(__name__)


CURRENT_MAX_ONNX_OPSET = 23
SUPPORTED_MAX_ONNX_OPSET = 23
SUPPORTED_MIN_ONNX_OPSET = 18


class VersionConverterError(RuntimeError):
Expand All @@ -38,6 +39,20 @@
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]


def version_supported(model: ir.Model, target_version: int) -> bool:
"""Check if the target version is supported by the current version."""
if "" in model.graph.opset_imports:
current_version = model.graph.opset_imports[""]

Check warning on line 45 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L45

Added line #L45 was not covered by tests
else:
return True
return (

Check warning on line 48 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L47-L48

Added lines #L47 - L48 were not covered by tests
SUPPORTED_MIN_ONNX_OPSET
<= current_version
<= target_version
<= SUPPORTED_MAX_ONNX_OPSET
Comment thread
justinchuby marked this conversation as resolved.
)


class AdapterRegistry:
"""A class that maintains a registry of adapters for ops."""

Expand Down Expand Up @@ -262,7 +277,7 @@
return None

def visit_graph(self, graph: ir.Graph) -> None:
if self.target_version > CURRENT_MAX_ONNX_OPSET:
if self.target_version > SUPPORTED_MAX_ONNX_OPSET:
logger.warning(
"Conversion to target opset: %s not currently supported.",
self.target_version,
Expand Down
4 changes: 1 addition & 3 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@

import unittest

import onnx.checker
import onnx.defs
import onnx.parser
import onnx.shape_inference

from onnxscript import ir, version_converter


class ApapterCoverageTest(unittest.TestCase):
class AdapterCoverageTest(unittest.TestCase):
def get_all_unique_schema_versions(self) -> dict[str, list]:
"""Collect all unique versions of ONNX standard domain ops"""
op_version_dict = {}
Expand Down
24 changes: 24 additions & 0 deletions tests/version_converter/version_conversion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import pathlib
import unittest

from onnxscript import ir, version_converter

model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata"


class ModelTest(unittest.TestCase):
def test_model_runs_and_matches_accuracy_after_conversion_fallback_true(self):
model_path = model_folder_path / "e2e_models/torchscript_model/torchscript_model.onnx"
model = ir.load(model_path)

# Down convert the model with the onnx version converter
version_converter.convert_version(model, target_version=16, fallback=True)
self.assertEqual(model.opset_imports[""], 16)


if __name__ == "__main__":
unittest.main()
Loading