Skip to content

Commit 2c28041

Browse files
committed
Enhance GraphBuilder to support outer-scope values in inlined function bodies
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 91affd4 commit 2c28041

File tree

2 files changed

+3
-34
lines changed

2 files changed

+3
-34
lines changed

onnxscript/_internal/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,8 @@ def call_inline(
894894
if isinstance(function, ir.Function):
895895
graph = function.graph
896896
elif isinstance(function, onnxscript.OnnxFunction):
897-
graph = function.graph()
897+
# TODO(justinchuby): Reason about support for outer-scope values in inlined function bodies.
898+
graph = function.graph().clone(allow_outer_scope_values=True)
898899
else:
899900
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
900901
if _outputs is not None:

onnxscript/_internal/builder_test.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ def add_mul(X, Y):
876876
self.assertEqual(nodes[0].op_type, "Add")
877877
self.assertEqual(nodes[1].op_type, "Mul")
878878

879-
def test_call_with_outer_scope_value(self):
879+
def test_call_inline_with_outer_scope_value(self):
880880
"""Test that script supports references to pre-existing values."""
881881
# Create a GraphBuilder first
882882
op, x, y = _create_builder_with_inputs()
@@ -1180,38 +1180,6 @@ def simple_add(X, Y):
11801180
# Function should be registered
11811181
self.assertEqual(len(op.builder.functions), 1)
11821182

1183-
def test_call_inline_produces_more_nodes_than_call(self):
1184-
"""Verify that call() produces exactly 1 function-call node while call_inline()
1185-
expands the function body into individual op nodes. This is the core behavioral
1186-
difference between the two APIs.
1187-
"""
1188-
# Inline version
1189-
op1, x1, y1 = _create_builder_with_inputs()
1190-
1191-
@script(default_opset=op1)
1192-
def mul_add(X, Y):
1193-
tmp = X * Y
1194-
return op1.Add(tmp, X)
1195-
1196-
op1.call_inline(mul_add, x1, y1)
1197-
inline_nodes = list(op1.builder.graph)
1198-
1199-
# Non-inline version
1200-
op2, x2, y2 = _create_builder_with_inputs()
1201-
1202-
@script(default_opset=op2)
1203-
def mul_add2(X, Y):
1204-
tmp = X * Y
1205-
return op2.Add(tmp, X)
1206-
1207-
op2.call(mul_add2, x2, y2)
1208-
non_inline_nodes = list(op2.builder.graph)
1209-
1210-
# Inlining should produce 2 nodes (Mul, Add), non-inlining should produce 1
1211-
self.assertEqual(len(inline_nodes), 2)
1212-
self.assertEqual(len(non_inline_nodes), 1)
1213-
self.assertEqual(non_inline_nodes[0].op_type, "mul_add2")
1214-
12151183

12161184
class BuildSubgraphTest(unittest.TestCase):
12171185
"""Tests for GraphBuilder.subgraph()."""

0 commit comments

Comments
 (0)