Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions onnxscript/ir/passes/common/version_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Version conversion passes."""

from __future__ import annotations

__all__ = [
"ConvertVersionPass",
]

import logging

import onnx

from onnxscript import ir
from onnxscript.ir.passes.common import _c_api_utils
from onnxscript.ir.passes.common import inliner as _inliner
from onnxscript.version_converter import _version_converter

logger = logging.getLogger(__name__)


class ConvertVersionPass(ir.passes.InPlacePass):
"""Convert the model to the specified ONNX opset version.

This pass leverages the onnxscript version converter to convert the model. If
the conversion is not supported, it falls back to the onnx C API to convert
the model. This pass is in-place.

The pass is an no-op if the c-api fails.

Attributes:
target_version: The target ONNX opset version to convert the model to.
fallback: Whether to fallback to the onnx version converter if the
target version is not supported. Default is True.
"""

def __init__(self, target_version: int, fallback: bool = False) -> None:
super().__init__()
self.target_version = target_version
self.fallback = fallback
self.inliner = _inliner.InlinePass()

def call(self, model: ir.Model) -> ir.passes.PassResult:
# Normalize the opset import
if "ai.onnx" in model.graph.opset_imports:
model.graph.opset_imports[""] = model.graph.opset_imports["ai.onnx"]
del model.graph.opset_imports["ai.onnx"]

model_opset_version = model.graph.opset_imports[""]
if model_opset_version == self.target_version:
# No need to convert the version
return ir.passes.PassResult(model, False)

# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
# Hence, we inline all the functions.
self.inliner(model)

if _version_converter.version_supported(model_opset_version, self.target_version):
_version_converter.convert_version(
model,
target_version=self.target_version,
)
return ir.passes.PassResult(model, True)

if not self.fallback:
logger.info(
"The model version conversion is not supported by the onnxscript version converter "
"and fallback is disabled. The model was not modified"
" (current version: %d, target version: %d). "
"Set fallback=True to enable fallback to the onnx c-api version converter.",
model_opset_version,
self.target_version,
)
return ir.passes.PassResult(model, False)

# If the onnxscript version converter does not support the conversion,
# we can use the onnx C API to convert the model
def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:
"""Partial function to check the model."""
return onnx.version_converter.convert_version(
proto, target_version=self.target_version
)

try:
converted_model = _c_api_utils.call_onnx_api(
func=_partial_convert_version, model=model
)
except Exception as e:
Comment thread Fixed
logger.warning(
"Failed to convert the model to the target version %d using the ONNX C API. "
"The model was not modified",
self.target_version,
exc_info=e,
)
return ir.passes.PassResult(model, False)

converted_model = ir.from_proto(converted_model)
Comment thread Fixed

# Recover the initializers in the converted model
for input in converted_model.graph.inputs:
if input.name in model.graph.initializers:
input.const_value = model.graph.initializers[input.name].const_value
converted_model.graph.register_initializer(input)
user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)]
converted_model.graph.inputs.clear()
converted_model.graph.inputs.extend(user_inputs)

# Return the converted graph to the original model to keep the pass in-place
model.graph = converted_model.graph
return ir.passes.PassResult(model, True)
18 changes: 12 additions & 6 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
"convert_version",
]

import onnxscript.optimizer
from onnxscript import ir
from onnxscript.version_converter import _version_converter
from onnxscript.ir.passes.common import version_converter as _version_converter_pass


def convert_version(model: ir.Model, target_version: int) -> None:
"""Convert the model to the specified ONNX opset version."""
def convert_version(model: ir.Model, target_version: int, fallback=False) -> None:
Comment thread
justinchuby marked this conversation as resolved.
"""Convert the model to the specified ONNX opset version.

Args:
model: The model to convert.
target_version: The target ONNX opset version.
fallback: Whether to fallback to the onnx version converter if the
target version is not supported. Default is True.
"""
# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
# Hence, we inline all the functions.
onnxscript.optimizer.inline(model)
_version_converter.convert_version(model, target_version)
_version_converter_pass.ConvertVersionPass(
target_version=target_version, fallback=fallback
)(model)
15 changes: 13 additions & 2 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
logger = logging.getLogger(__name__)


CURRENT_MAX_ONNX_OPSET = 23
SUPPORTED_MAX_ONNX_OPSET = 23
SUPPORTED_MIN_ONNX_OPSET = 18


class VersionConverterError(RuntimeError):
Expand All @@ -38,6 +39,16 @@ class Replacement:
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]


def version_supported(current_version: int, target_version: int) -> bool:
"""Check if the target version is supported by the current version."""
return (
SUPPORTED_MIN_ONNX_OPSET
<= current_version
<= target_version
<= SUPPORTED_MIN_ONNX_OPSET
)

Comment thread
justinchuby marked this conversation as resolved.
Outdated

class AdapterRegistry:
"""A class that maintains a registry of adapters for ops."""

Expand Down Expand Up @@ -262,7 +273,7 @@ def visit_node(
return None

def visit_graph(self, graph: ir.Graph) -> None:
if self.target_version > CURRENT_MAX_ONNX_OPSET:
if self.target_version > SUPPORTED_MAX_ONNX_OPSET:
logger.warning(
"Conversion to target opset: %s not currently supported.",
self.target_version,
Expand Down
Loading