Skip to content

Commit be63ea7

Browse files
committed
Revert "Remove function value error in version converter (microsoft#2791)"
This reverts commit 74a5f34.
1 parent 0d23d32 commit be63ea7

3 files changed

Lines changed: 37 additions & 106 deletions

File tree

onnxscript/version_converter/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
3838
self.target_version = target_version
3939
self.fallback = fallback
4040
self.convert_pass = ir.passes.Sequential(
41-
_ConvertVersionPass(
41+
common_passes.InlinePass(),
42+
_ConvertVersionPassRequiresInline(
4243
target_version=target_version,
4344
fallback=fallback,
4445
),
@@ -51,7 +52,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
5152
return self.convert_pass(model)
5253

5354

54-
class _ConvertVersionPass(ir.passes.InPlacePass):
55+
class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
5556
"""Convert the model to the specified ONNX opset version.
5657
5758
This pass leverages the onnxscript version converter to convert the model. If
@@ -72,6 +73,12 @@ def __init__(self, target_version: int, fallback: bool) -> None:
7273
self.fallback = fallback
7374

7475
def call(self, model: ir.Model) -> ir.passes.PassResult:
76+
if model.functions:
77+
raise ValueError(
78+
"The model contains functions. The version conversion pass does not support "
79+
"functions. Please use `common_passes.InlinePass` to inline the "
80+
f"functions before applying this pass ({self.__class__.__name__})."
81+
)
7582
if "" in model.graph.opset_imports:
7683
onnx_opset_version = model.graph.opset_imports[""]
7784
if onnx_opset_version == self.target_version:

onnxscript/version_converter/_version_converter.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,10 @@ def visit_attribute(self, attr: ir.Attr) -> None:
274274
if attr.is_ref():
275275
return
276276
if attr.type == ir.AttributeType.GRAPH:
277-
self.visit_graph_or_function(attr.as_graph())
277+
self.visit_graph(attr.as_graph())
278278
elif attr.type == ir.AttributeType.GRAPHS:
279279
for graph in attr.as_graphs():
280-
self.visit_graph_or_function(graph)
280+
self.visit_graph(graph)
281281

282282
def visit_node(
283283
self,
@@ -303,8 +303,8 @@ def visit_node(
303303
self._default_metadata_merger.copy_merged_metadata([node], replacement.new_nodes)
304304
self.replace_node(node, replacement, root)
305305

306-
def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) -> None:
307-
for node in graph_or_function:
306+
def visit_graph(self, graph: ir.Graph) -> None:
307+
for node in graph:
308308
if node.domain != "":
309309
continue
310310
node_version = node.version or self._default_onnx_opset
@@ -321,7 +321,7 @@ def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) ->
321321
)
322322
for from_version in range(node_version, self._target_version):
323323
try:
324-
self.visit_node(node, graph_or_function, from_version, up_conversion=True)
324+
self.visit_node(node, graph, from_version, up_conversion=True)
325325
except VersionConverterError as e:
326326
logger.warning(
327327
"Skipping version conversion for node %s due to exception: %s",
@@ -331,9 +331,7 @@ def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) ->
331331

332332
def visit_model(self, model: ir.Model) -> None:
333333
self._default_onnx_opset = _get_onnx_opset_version(model)
334-
self.visit_graph_or_function(model.graph)
335-
for function in model.functions.values():
336-
self.visit_graph_or_function(function)
334+
self.visit_graph(model.graph)
337335
_set_onnx_opset_version(model, self._target_version)
338336

339337

onnxscript/version_converter/_version_converter_test.py

Lines changed: 22 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -208,114 +208,40 @@ def test_version_convert_gridsample_cubic(self):
208208
self.assertEqual(model.graph.node(4).version, 20)
209209
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")
210210

211-
def test_version_convert_function_nodes(self):
212-
"""Test that version converter processes nodes inside model functions."""
211+
def test_version_convert_inline(self):
213212
model = ir.from_onnx_text(
214213
"""
215-
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
216-
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
217-
{
218-
output = pkg.custom.dft_func (input_x)
219-
}
220-
221-
<domain: "pkg.custom", opset_import: [ "" : 18]>
222-
dft_func (x) => (result) {
223-
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
224-
reshape_x = Reshape (x, shape_a)
225-
dft = DFT <axis = 2, onesided = 1> (reshape_x)
226-
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
227-
result = Reshape (dft, shape_c)
228-
}
229-
"""
230-
)
231-
# Verify the function exists with correct initial state
232-
self.assertEqual(len(model.functions), 1)
233-
func = model.functions[("pkg.custom", "dft_func", "")]
234-
self.assertEqual(len(func), 5) # 5 nodes in the function
235-
236-
target_version = 20
237-
version_converter.convert_version(model, target_version=target_version)
238-
self.assertEqual(model.opset_imports[""], target_version)
239-
240-
# Verify that nodes inside the function were version-converted
241-
func = model.functions[("pkg.custom", "dft_func", "")]
242-
self.assertEqual(func[0].op_type, "Constant")
243-
self.assertEqual(func[0].version, 20)
244-
self.assertEqual(func[1].op_type, "Reshape")
245-
self.assertEqual(func[1].version, 20)
246-
# After DFT adapter, a new Constant node is inserted for dft_length
247-
self.assertEqual(func[2].op_type, "Constant")
248-
self.assertEqual(func[2].version, 20)
249-
self.assertEqual(func[3].op_type, "DFT")
250-
self.assertEqual(func[3].version, 20)
251-
self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input
252-
253-
def test_version_convert_function_with_control_flow_subgraph(self):
254-
"""Test that version converter processes subgraphs inside control flow nodes in functions."""
255-
model = ir.from_onnx_text(
256-
"""
257-
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
258-
agraph (float[4, 512, 512] input_x, bool cond) => (float[4, 257, 64, 2] output)
214+
<ir_version: 8, opset_import: [ "" : 18]>
215+
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
259216
{
260-
output = pkg.custom.conditional_dft (input_x, cond)
217+
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
218+
reshape_x = Reshape (input_x, shape_a)
219+
shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
220+
reshape_y = Reshape (input_x, shape_b)
221+
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
222+
output = foo(gridsample)
261223
}
262224
263-
<domain: "pkg.custom", opset_import: [ "" : 18]>
264-
conditional_dft (x, cond) => (result) {
265-
result = If (cond) <then_branch: graph = then_graph () => (out) {
266-
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
267-
reshape_x = Reshape (x, shape_a)
268-
dft = DFT <axis = 2, onesided = 1> (reshape_x)
269-
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
270-
out = Reshape (dft, shape_c)
271-
}, else_branch: graph = else_graph () => (out) {
272-
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
273-
out = Reshape (x, shape_c)
274-
}>
225+
<opset_import: [ "" : 18]>
226+
foo (x) => (dft) {
227+
dft = DFT <axis = 2, onesided = 1> (x)
275228
}
276229
"""
277230
)
278-
# Verify the function exists with correct initial state
279-
self.assertEqual(len(model.functions), 1)
280-
func = model.functions[("pkg.custom", "conditional_dft", "")]
281-
self.assertEqual(len(func), 1) # 1 node (If) in the function
282-
283-
# Verify the If node has subgraphs
284-
if_node = func[0]
285-
self.assertEqual(if_node.op_type, "If")
286-
then_branch = if_node.attributes["then_branch"].as_graph()
287-
else_branch = if_node.attributes["else_branch"].as_graph()
288-
self.assertEqual(len(then_branch), 5) # 5 nodes in then_branch
289-
self.assertEqual(len(else_branch), 2) # 2 nodes in else_branch
290-
291231
target_version = 20
292-
# Use internal API to test function version conversion without inlining
293232
version_converter.convert_version(model, target_version=target_version)
294233
self.assertEqual(model.opset_imports[""], target_version)
295234

296-
# Verify nodes inside the function's If node subgraphs were version-converted
297-
func = model.functions[("pkg.custom", "conditional_dft", "")]
298-
if_node = func[0]
299-
self.assertEqual(if_node.op_type, "If")
300-
self.assertEqual(if_node.version, 20)
301-
302-
# Check then_branch subgraph nodes
303-
then_branch = if_node.attributes["then_branch"].as_graph()
304-
# After DFT adapter, a new Constant node is inserted for dft_length
305-
self.assertEqual(len(then_branch), 6) # 5 + 1 new Constant for DFT
306-
dft_node = None
307-
for node in then_branch:
308-
self.assertEqual(node.version, 20)
309-
if node.op_type == "DFT":
310-
dft_node = node
311-
self.assertIsNotNone(dft_node)
312-
self.assertEqual(len(dft_node.inputs), 3) # DFT 19->20 adds dft_length input
313-
314-
# Check else_branch subgraph nodes
315-
else_branch = if_node.attributes["else_branch"].as_graph()
316-
self.assertEqual(len(else_branch), 2)
317-
for node in else_branch:
318-
self.assertEqual(node.version, 20)
235+
self.assertEqual(model.graph.node(0).op_type, "Constant")
236+
self.assertEqual(model.graph.node(0).version, 20)
237+
self.assertEqual(model.graph.node(1).op_type, "Reshape")
238+
self.assertEqual(model.graph.node(1).version, 20)
239+
self.assertEqual(model.graph.node(4).op_type, "GridSample")
240+
self.assertEqual(model.graph.node(4).version, 20)
241+
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")
242+
self.assertEqual(model.graph.node(6).op_type, "DFT")
243+
self.assertEqual(model.graph.node(6).version, 20)
244+
self.assertEqual(len(model.graph.node(6).inputs), 3)
319245

320246

321247
class VersionConverter20to21Test(unittest.TestCase):

0 commit comments

Comments
 (0)