Skip to content

Commit 48bfc65

Browse files
committed
Fix case
1 parent 9d6b1c8 commit 48bfc65

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

onnxscript/ir/passes/common/version_converter.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,17 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
4242
self.inliner = _inliner.InlinePass()
4343

4444
def call(self, model: ir.Model) -> ir.passes.PassResult:
45-
# Normalize the opset import
46-
if "ai.onnx" in model.graph.opset_imports:
47-
model.graph.opset_imports[""] = model.graph.opset_imports["ai.onnx"]
48-
del model.graph.opset_imports["ai.onnx"]
49-
50-
model_opset_version = model.graph.opset_imports[""]
51-
if model_opset_version == self.target_version:
52-
# No need to convert the version
53-
return ir.passes.PassResult(model, False)
45+
if "" in model.graph.opset_imports:
46+
onnx_opset_version = model.graph.opset_imports[""]
47+
if onnx_opset_version == self.target_version:
48+
# No need to convert the version
49+
return ir.passes.PassResult(model, False)
5450

5551
# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
5652
# Hence, we inline all the functions.
5753
self.inliner(model)
5854

59-
if _version_converter.version_supported(model_opset_version, self.target_version):
55+
if _version_converter.version_supported(model, self.target_version):
6056
_version_converter.convert_version(
6157
model,
6258
target_version=self.target_version,
@@ -67,9 +63,8 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
6763
logger.info(
6864
"The model version conversion is not supported by the onnxscript version converter "
6965
"and fallback is disabled. The model was not modified"
70-
" (current version: %d, target version: %d). "
66+
" (target version: %d). "
7167
"Set fallback=True to enable fallback to the onnx c-api version converter.",
72-
model_opset_version,
7368
self.target_version,
7469
)
7570
return ir.passes.PassResult(model, False)

onnxscript/version_converter/_version_converter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,20 @@ class Replacement:
3939
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]
4040

4141

42-
def version_supported(current_version: int, target_version: int) -> bool:
42+
def version_supported(model: ir.Model, target_version: int) -> bool:
4343
"""Check if the target version is supported by the current version."""
44+
if "" in model.graph.opset_imports:
45+
current_version = model.graph.opset_imports[""]
46+
else:
47+
return True
4448
return (
4549
SUPPORTED_MIN_ONNX_OPSET
4650
<= current_version
4751
<= target_version
4852
<= SUPPORTED_MAX_ONNX_OPSET
4953
)
5054

55+
5156
class AdapterRegistry:
5257
"""A class that maintains a registry of adapters for ops."""
5358

onnxscript/version_converter/_version_converter_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44

55
import unittest
66

7-
import onnx.checker
87
import onnx.defs
98
import onnx.parser
10-
import onnx.shape_inference
119

1210
from onnxscript import ir, version_converter
1311

1412

15-
class ApapterCoverageTest(unittest.TestCase):
13+
class AdapterCoverageTest(unittest.TestCase):
1614
def get_all_unique_schema_versions(self) -> dict[str, list]:
1715
"""Collect all unique versions of ONNX standard domain ops"""
1816
op_version_dict = {}

0 commit comments

Comments
 (0)