Skip to content

Commit 3f197a2

Browse files
authored
Support outer scope value reference in script (#2831)
This extends onnxscript scripts to allow references to ir.Values defined outside the function. This allows the use of script functions within the construction of graphs using a GraphBuilder. Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 1fdfb1b commit 3f197a2

File tree

6 files changed

+64
-18
lines changed

6 files changed

+64
-18
lines changed

onnxscript/_internal/_inliner.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,29 @@
1010

1111

1212
def instantiate(
13-
function: ir.Function,
13+
graph: ir.Graph,
1414
inputs: Sequence[ir.Value | None],
1515
attributes: Mapping[str, ir.Attr],
1616
*,
1717
prefix: str = "",
1818
) -> tuple[list[ir.Node], list[ir.Value | None]]:
19-
"""Instantiate (inline) a function, substituting inputs and attributes.
19+
"""Instantiate (inline) a graph, substituting inputs and attributes.
2020
2121
Args:
22-
function: The function to instantiate.
23-
inputs: Actual input values to bind to the function's formal parameters.
22+
graph: The graph to instantiate.
23+
inputs: Actual input values to bind to the graph's formal parameters.
2424
attributes: Attribute values to substitute for reference attributes.
2525
prefix: Optional prefix to prepend to node and output names.
2626
2727
Returns:
28-
A tuple of (nodes, outputs) where nodes are the cloned function body
29-
and outputs are the values corresponding to the function's outputs.
28+
A tuple of (nodes, outputs) where nodes are the cloned graph body
29+
and outputs are the values corresponding to the graph's outputs.
3030
"""
31-
formal_inputs = function.inputs
31+
formal_inputs = graph.inputs
3232
if len(inputs) > len(formal_inputs):
3333
raise ValueError(
3434
f"Too many inputs: got {len(inputs)}, "
35-
f"but function has {len(formal_inputs)} parameters."
35+
f"but graph has {len(formal_inputs)} parameters."
3636
)
3737
value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs))
3838

@@ -50,7 +50,8 @@ def rename(node: ir.Node) -> None:
5050
metadata_props={},
5151
post_process=rename,
5252
resolve_ref_attrs=True,
53+
allow_outer_scope_values=True,
5354
)
54-
nodes = [cloner.clone_node(n) for n in function]
55-
outputs = [value_map.get(v) for v in function.outputs]
55+
nodes = [cloner.clone_node(n) for n in graph]
56+
outputs = [value_map.get(v) for v in graph.outputs]
5657
return nodes, outputs

onnxscript/_internal/builder.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -450,25 +450,24 @@ def call(
450450
**kwargs,
451451
):
452452
if isinstance(function, ir.Function):
453-
function_ir = function
453+
graph = function.graph
454454
elif isinstance(function, onnxscript.OnnxFunction):
455-
function_proto = function.to_function_proto()
456-
function_ir = ir.serde.deserialize_function(function_proto)
455+
graph = function.graph()
457456
else:
458457
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
459458
output_renaming: dict[str, str] = {}
460459
if _outputs is not None:
461-
if len(_outputs) != len(function_ir.outputs):
460+
if len(_outputs) != len(graph.outputs):
462461
raise ValueError(
463462
f"Number of provided output names {_outputs} does not match "
464-
f"number of function outputs {len(function_ir.outputs)}."
463+
f"number of function outputs {len(graph.outputs)}."
465464
)
466-
for output, name in zip(function_ir.outputs, _outputs):
465+
for output, name in zip(graph.outputs, _outputs):
467466
output_renaming[output.name] = self._qualify_value_name(name)
468467
else:
469-
for output in function_ir.outputs:
468+
for output in graph.outputs:
470469
output_renaming[output.name] = self._qualify_value_name(output.name)
471-
nodes, outputs = _inliner.instantiate(function_ir, args, kwargs)
470+
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
472471
if _prefix:
473472
self.push_module(_prefix)
474473
for node in nodes:

onnxscript/_internal/builder_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import onnx_ir as ir
1111

1212
import onnxscript._internal.builder as builder
13+
import onnxscript.testing
1314
from onnxscript import script
1415
from onnxscript.onnx_types import DOUBLE, FLOAT
1516

@@ -713,6 +714,31 @@ def add_mul(X, Y):
713714
self.assertEqual(nodes[0].op_type, "Add")
714715
self.assertEqual(nodes[1].op_type, "Mul")
715716

717+
def test_call_with_outer_scope_value(self):
718+
"""Test that script supports references to pre-existing values."""
719+
# Create a GraphBuilder first
720+
op, x, y = _create_builder_with_inputs()
721+
product = op.Mul(x, y)
722+
723+
@script()
724+
def add_product(X):
725+
return op.Add(X, product) # Reference to 'product' from outer scope
726+
727+
x_plus = op.call(add_product, x, _outputs=["x_plus"])
728+
y_plus = op.call(add_product, y, _outputs=["y_plus"])
729+
730+
op.builder.graph.outputs.extend([x_plus, y_plus])
731+
732+
# Now, create the same graph directly:
733+
op2, x2, y2 = _create_builder_with_inputs()
734+
product2 = op2.Mul(x2, y2)
735+
x2_plus = op2.Add(x2, product2, _outputs=["x_plus"])
736+
y2_plus = op2.Add(y2, product2, _outputs=["y_plus"])
737+
op2.builder.graph.outputs.extend([x2_plus, y2_plus])
738+
739+
# Verify that the two graphs are structurally equivalent
740+
onnxscript.testing.assert_isomorphic_graph(op.builder.graph, op2.builder.graph)
741+
716742
def test_call_with_prefix_option(self):
717743
"""Test that GraphBuilder.call respects the _prefix option for hierarchical naming."""
718744
# Create a GraphBuilder first

onnxscript/_internal/converter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,9 @@ def _to_onnx_var(
377377
if isinstance(val, values.SymbolValue):
378378
if isinstance(val.value, ir.Value):
379379
return val.value
380+
if isinstance(val, ir.Value):
381+
# An outer-scope ir.Value (e.g., from a closure variable) can be used directly.
382+
return val
380383
# Assume value is a python-value convertible to a tensor
381384
# TODO: check if value is convertible to a TensorProto, so that we can
382385
# produce a better error _message otherwise

onnxscript/_internal/values.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,15 @@ def to_function_proto(self) -> onnx.FunctionProto:
333333
"""Converts the function into :class:`onnx.FunctionProto`."""
334334
return self.function_ir.to_function_proto()
335335

336+
def graph(self) -> ir.Graph:
337+
"""Returns the IR graph representation of this function.
338+
339+
Returns:
340+
The :class:`ir.Graph` representing the computation graph of this function.
341+
NOTE: This is not a copy, and should not be modified by the caller.
342+
"""
343+
return self.function_ir.graph
344+
336345
def to_model_proto(self, **kwargs):
337346
"""Converts the function into :class:`onnx.ModelProto`."""
338347
if self.function_ir.attrs and any(

onnxscript/testing/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def _to_function_proto(f):
320320
return f.to_function_proto()
321321
if isinstance(f, str):
322322
return parser.parse_function(f)
323+
if isinstance(f, ir.Function):
324+
return ir.to_proto(f)
323325
raise TypeError(f"Cannot convert {type(f)} to FunctionProto")
324326

325327

@@ -330,6 +332,8 @@ def _to_graph_proto(g):
330332
return g.to_model_proto().graph
331333
if isinstance(g, str):
332334
return parser.parse_graph(g)
335+
if isinstance(g, ir.Graph):
336+
return ir.to_proto(g)
333337
raise TypeError(f"Cannot convert {type(g)} to ModelProto")
334338

335339

@@ -342,6 +346,10 @@ def _to_function_or_graph(obj):
342346
return obj.graph
343347
if isinstance(obj, onnxscript.OnnxFunction):
344348
return obj.to_function_proto()
349+
if isinstance(obj, ir.Function):
350+
return ir.to_proto(obj)
351+
if isinstance(obj, ir.Graph):
352+
return ir.to_proto(obj)
345353
raise TypeError(f"Cannot convert {type(obj)} to FunctionProto or GraphProto")
346354

347355

0 commit comments

Comments
 (0)