Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
self.target_version = target_version
self.fallback = fallback
self.convert_pass = ir.passes.Sequential(
common_passes.InlinePass(),
_ConvertVersionPassRequiresInline(
target_version=target_version,
fallback=fallback,
Expand Down Expand Up @@ -73,12 +72,6 @@ def __init__(self, target_version: int, fallback: bool) -> None:
self.fallback = fallback

def call(self, model: ir.Model) -> ir.passes.PassResult:
if model.functions:
raise ValueError(
Comment thread
titaiwangms marked this conversation as resolved.
"The model contains functions. The version conversion pass does not support "
"functions. Please use `common_passes.InlinePass` to inline the "
f"functions before applying this pass ({self.__class__.__name__})."
)
if "" in model.graph.opset_imports:
onnx_opset_version = model.graph.opset_imports[""]
if onnx_opset_version == self.target_version:
Expand Down
25 changes: 25 additions & 0 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,34 @@ def visit_graph(self, graph: ir.Graph) -> None:
e,
)

def visit_function(self, function: ir.Function) -> None:
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
"""Visit a function and convert nodes to the target opset version."""
for node in function:
if node.domain != "":
continue
node_version = node.version or self._default_onnx_opset
if node_version is None:
raise VersionConverterError(f"Node {node} has no version.")
if self._target_version < node_version:
raise VersionConverterError(
f"Target opset: {self._target_version} less than node version: {node.version}, "
"downstream version conversion not currently handled."
)
for from_version in range(node_version, self._target_version):
try:
self.visit_node(node, function, from_version, up_conversion=True)
except VersionConverterError as e:
logger.warning(
"Skipping version conversion for node %s due to exception: %s",
node.op_type,
e,
)
Comment thread
justinchuby marked this conversation as resolved.

def visit_model(self, model: ir.Model) -> None:
self._default_onnx_opset = _get_onnx_opset_version(model)
self.visit_graph(model.graph)
for function in model.functions.values():
self.visit_function(function)
_set_onnx_opset_version(model, self._target_version)


Expand Down
42 changes: 42 additions & 0 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,48 @@ def test_version_convert_inline(self):
self.assertEqual(model.graph.node(6).version, 20)
self.assertEqual(len(model.graph.node(6).inputs), 3)

def test_version_convert_function_nodes(self):
"""Test that version converter processes nodes inside model functions."""
model = ir.from_onnx_text(
"""
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
{
output = pkg.custom.dft_func (input_x)
}

<domain: "pkg.custom", opset_import: [ "" : 18]>
dft_func (x) => (result) {
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
reshape_x = Reshape (x, shape_a)
dft = DFT <axis = 2, onesided = 1> (reshape_x)
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
result = Reshape (dft, shape_c)
}
"""
)
# Verify the function exists with correct initial state
self.assertEqual(len(model.functions), 1)
func = model.functions[("pkg.custom", "dft_func", "")]
self.assertEqual(len(func), 5) # 5 nodes in the function

target_version = 20
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

# Verify that nodes inside the function were version-converted
func = model.functions[("pkg.custom", "dft_func", "")]
self.assertEqual(func[0].op_type, "Constant")
self.assertEqual(func[0].version, 20)
self.assertEqual(func[1].op_type, "Reshape")
self.assertEqual(func[1].version, 20)
# After DFT adapter, a new Constant node is inserted for dft_length
self.assertEqual(func[2].op_type, "Constant")
self.assertEqual(func[2].version, 20)
self.assertEqual(func[3].op_type, "DFT")
self.assertEqual(func[3].version, 20)
self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input


class VersionConverter20to21Test(unittest.TestCase):
def test_version_groupnorm(self):
Expand Down
Loading