Skip to content

Commit 5ae26d9

Browse files
authored
Fix version converter regression (#2799)
This pull request improves the ONNX version converter by enabling direct opset version conversion for models containing functions, and ensures that version conversion now correctly applies to function nodes and their subgraphs. The changes also update the internal APIs to handle both models and functions, and expand test coverage for these scenarios. Enhancements to version conversion logic: * Removed the requirement to inline functions before version conversion, allowing `_ConvertVersionPass` to process models with functions directly (`onnxscript/version_converter/__init__.py`). [[1]](diffhunk://#diff-e2693e79e7e645175e987e44ba3432063a21003bf4aa50119c1a786925f6e042L41-R41) [[2]](diffhunk://#diff-e2693e79e7e645175e987e44ba3432063a21003bf4aa50119c1a786925f6e042L55-R54) [[3]](diffhunk://#diff-e2693e79e7e645175e987e44ba3432063a21003bf4aa50119c1a786925f6e042L76-L81) * Updated the internal API (`_set_onnx_opset_version`) to support both models and functions, so opset imports are set correctly for functions as well as models (`onnxscript/version_converter/_version_converter.py`). Support for function and subgraph conversion: * Modified traversal logic to process nodes inside functions and subgraphs within control flow nodes, ensuring that all relevant nodes are version-converted (`onnxscript/version_converter/_version_converter.py`). [[1]](diffhunk://#diff-b6c70f90bafaee79b30e43c90bc0fd5192fb3de7ccc4cf9d48a209798dd775faL277-R280) [[2]](diffhunk://#diff-b6c70f90bafaee79b30e43c90bc0fd5192fb3de7ccc4cf9d48a209798dd775faL306-R307) [[3]](diffhunk://#diff-b6c70f90bafaee79b30e43c90bc0fd5192fb3de7ccc4cf9d48a209798dd775faL324-R324) [[4]](diffhunk://#diff-b6c70f90bafaee79b30e43c90bc0fd5192fb3de7ccc4cf9d48a209798dd775faL334-R337) Expanded test coverage: * Added tests to verify that nodes inside functions and subgraphs within control flow nodes (e.g., `If` branches) are correctly version-converted, and that opset imports are updated for functions (`onnxscript/version_converter/_version_converter_test.py`).
1 parent 0d460ff commit 5ae26d9

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

onnxscript/version_converter/_version_converter.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def _get_onnx_opset_version(model: ir.Model) -> int | None:
3434
return model_version1 or model_version2
3535

3636

37-
def _set_onnx_opset_version(model: ir.Model, version: int) -> None:
38-
"""Set the ONNX opset version imported by the model."""
39-
if "ai.onnx" in model.opset_imports:
40-
del model.opset_imports["ai.onnx"]
41-
model.opset_imports[""] = version
37+
def _set_onnx_opset_version(model_or_function: ir.Model | ir.Function, version: int) -> None:
38+
"""Set the ONNX opset version imported by the model or function."""
39+
if "ai.onnx" in model_or_function.opset_imports:
40+
del model_or_function.opset_imports["ai.onnx"]
41+
model_or_function.opset_imports[""] = version
4242

4343

4444
class VersionConverterError(RuntimeError):
@@ -334,6 +334,7 @@ def visit_model(self, model: ir.Model) -> None:
334334
self.visit_graph_or_function(model.graph)
335335
for function in model.functions.values():
336336
self.visit_graph_or_function(function)
337+
_set_onnx_opset_version(function, self._target_version)
337338
_set_onnx_opset_version(model, self._target_version)
338339

339340

onnxscript/version_converter/_version_converter_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,11 @@ def test_version_convert_function_nodes(self):
237237
version_converter.convert_version(model, target_version=target_version)
238238
self.assertEqual(model.opset_imports[""], target_version)
239239

240-
# Verify that nodes inside the function were version-converted
240+
# Verify that the function's opset_imports are updated
241241
func = model.functions[("pkg.custom", "dft_func", "")]
242+
self.assertEqual(func.opset_imports[""], target_version)
243+
244+
# Verify that nodes inside the function were version-converted
242245
self.assertEqual(func[0].op_type, "Constant")
243246
self.assertEqual(func[0].version, 20)
244247
self.assertEqual(func[1].op_type, "Reshape")
@@ -293,8 +296,12 @@ def test_version_convert_function_with_control_flow_subgraph(self):
293296
version_converter.convert_version(model, target_version=target_version)
294297
self.assertEqual(model.opset_imports[""], target_version)
295298

296-
# Verify nodes inside the function's If node subgraphs were version-converted
299+
# Verify that the function's opset_imports are updated
297300
func = model.functions[("pkg.custom", "conditional_dft", "")]
301+
self.assertEqual(func.opset_imports[""], target_version)
302+
303+
# Verify nodes inside the function's If node subgraphs were version-converted
304+
# Verify nodes inside the function's If node subgraphs were version-converted
298305
if_node = func[0]
299306
self.assertEqual(if_node.op_type, "If")
300307
self.assertEqual(if_node.version, 20)

0 commit comments

Comments
 (0)