Skip to content

Commit f99fa7d

Browse files
gramalingamgithub-advanced-security[bot]justinchubyCopilot
authored
Allow GraphBuilder to call script functions (#2820)
* This extension allows calls to scripted functions within traced (GraphBuilder) calls. * As part of this extension, allow the use of OpBuilder within scripted function * Also fix the default value naming strategy to include node number to ensure uniqueness (in common cases). TODO later: Consider integration of standard execution of script functions with execution of OpBuilder calls, as well as other extensions within script-mode to ensure both modes are seamless and uniform --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8b79816 commit f99fa7d

File tree

5 files changed

+444
-12
lines changed

5 files changed

+444
-12
lines changed

docs/tutorial/builder/graph_builder.md

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,110 @@ def build_linear(op, x, weight, bias_value):
406406

407407
This pattern keeps function signatures simple while preserving access to the
408408
full builder API when needed.
409+
## Calling Script Functions from OpBuilder
410+
411+
The `OpBuilder` provides a `call()` method to inline `@script`-decorated ONNX functions directly into the builder's graph. This enables composition of both imperative (builder) and declarative (`@script`) code within a single graph.
412+
413+
### Basic function inlining
414+
415+
Define an ONNX script function and then call it through `op.call()`:
416+
417+
```python
418+
from onnxscript import script, opset23 as op23
419+
420+
# Define a reusable script function
421+
@script(default_opset=op23)
422+
def mul_add_relu(X, Y):
423+
tmp = X * Y
424+
tmp = tmp + X
425+
return op23.Relu(tmp)
426+
427+
# Now build a graph using OpBuilder
428+
graph = ir.Graph(
429+
name="my_graph",
430+
inputs=[],
431+
outputs=[],
432+
nodes=[],
433+
opset_imports={"": 23},
434+
)
435+
x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4]))
436+
y = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4]))
437+
graph.inputs.extend([x, y])
438+
439+
builder = onnxscript.GraphBuilder(graph)
440+
op = builder.op
441+
442+
# Call the script function — it gets inlined into the graph
443+
result = op.call(mul_add_relu, x, y)
444+
graph.outputs.append(result)
445+
```
446+
447+
The function body (three nodes: Mul, Add, Relu) is inlined directly into the graph.
448+
449+
### Renaming outputs with `_outputs`
450+
451+
By default, inlined function outputs keep their original names, qualified by the
452+
current naming context. You can rename them explicitly with `_outputs`:
453+
454+
```python
455+
@script(default_opset=op23)
456+
def add_mul(X, Y):
457+
a = X + Y
458+
b = X * Y
459+
return a, b
460+
461+
# Inline with custom output names
462+
result_sum, result_prod = op.call(
463+
add_mul, x, y,
464+
_outputs=["custom_sum", "custom_product"]
465+
)
466+
```
467+
468+
### Adding hierarchical context with `_prefix`
469+
470+
Use `_prefix` to add a naming context to all nodes and intermediate values created
471+
by the inlined function:
472+
473+
```python
474+
result = op.call(
475+
mul_add_relu, x, y,
476+
_prefix="layer1"
477+
)
478+
# Node names will be "layer1.Mul_n...", "layer1.Add_n...", "layer1.Relu_n..."
479+
# Intermediate value names will also start with "layer1."
480+
```
481+
482+
You can combine both options:
483+
484+
```python
485+
result_a, result_b = op.call(
486+
add_mul, x, y,
487+
_outputs=["sum_out", "prod_out"],
488+
_prefix="math_ops"
489+
)
490+
# Final outputs: "sum_out", "prod_out" (renamed before prefix context)
491+
# Intermediate values: "math_ops.Add_n...", "math_ops.Mul_n..." (with prefix)
492+
```
493+
494+
### Using OpBuilder as the default_opset
495+
496+
`OpBuilder` can be passed directly as the `default_opset` when decorating a script
497+
function. This enables scripted functions to use the same opset version as the
498+
builder they will be inlined into:
499+
500+
```python
501+
builder = onnxscript.GraphBuilder(graph)
502+
op = builder.op
503+
504+
# Define the function *after* creating the builder, using op as default_opset
505+
@script(default_opset=op)
506+
def my_func(X, Y):
507+
t = X + Y
508+
return op.Relu(t) # Uses the op directly
509+
510+
# Inline it
511+
result = op.call(my_func, x, y)
512+
```
513+
514+
This pattern ensures consistency: the script function operates in the same domain
515+
and opset version as the builder.

onnxscript/_internal/_inliner.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from __future__ import annotations
5+
6+
from typing import Mapping, Sequence
7+
8+
import onnx_ir as ir
9+
from onnx_ir._cloner import Cloner
10+
11+
12+
def instantiate(
13+
function: ir.Function,
14+
inputs: Sequence[ir.Value | None],
15+
attributes: Mapping[str, ir.Attr],
16+
*,
17+
prefix: str = "",
18+
) -> tuple[list[ir.Node], list[ir.Value | None]]:
19+
"""Instantiate (inline) a function, substituting inputs and attributes.
20+
21+
Args:
22+
function: The function to instantiate.
23+
inputs: Actual input values to bind to the function's formal parameters.
24+
attributes: Attribute values to substitute for reference attributes.
25+
prefix: Optional prefix to prepend to node and output names.
26+
27+
Returns:
28+
A tuple of (nodes, outputs) where nodes are the cloned function body
29+
and outputs are the values corresponding to the function's outputs.
30+
"""
31+
formal_inputs = function.inputs
32+
if len(inputs) > len(formal_inputs):
33+
raise ValueError(
34+
f"Too many inputs: got {len(inputs)}, "
35+
f"but function has {len(formal_inputs)} parameters."
36+
)
37+
value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs))
38+
39+
def rename(node: ir.Node) -> None:
40+
if prefix:
41+
if node.name:
42+
node.name = prefix + node.name
43+
for output in node.outputs:
44+
if output is not None and output.name:
45+
output.name = prefix + output.name
46+
47+
cloner = Cloner(
48+
attr_map=attributes,
49+
value_map=value_map,
50+
metadata_props={},
51+
post_process=rename,
52+
resolve_ref_attrs=True,
53+
)
54+
nodes = [cloner.clone_node(n) for n in function]
55+
outputs = [value_map.get(v) for v in function.outputs]
56+
return nodes, outputs

onnxscript/_internal/builder.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3+
"""Graph builder for constructing ONNX IR graphs imperatively.
4+
5+
This module provides imperative builders for constructing ONNX IR graphs with automatic
6+
constant promotion, type casting, and shape inference. The GraphBuilder class enables
7+
programmatic construction of graphs with proper scoping, constant management, and node
8+
creation. The OpBuilder class provides dynamic op dispatching via attribute access.
9+
"""
310

411
from __future__ import annotations
512

@@ -10,6 +17,7 @@
1017

1118
import onnxscript._internal._inference as inference
1219
import onnxscript.optimizer
20+
from onnxscript._internal import _inliner
1321

1422
# A permissible value for an op input, which can be converted to an ir.Value.
1523
VALUE_LIKE = Union[
@@ -334,6 +342,49 @@ def call_op(
334342

335343
return node.outputs if len(node.outputs) > 1 else node.outputs[0]
336344

345+
def call(
346+
self,
347+
function,
348+
*args,
349+
_outputs: Sequence[str] | None = None,
350+
_prefix: str = "",
351+
**kwargs,
352+
):
353+
if isinstance(function, ir.Function):
354+
function_ir = function
355+
elif isinstance(function, onnxscript.OnnxFunction):
356+
function_proto = function.to_function_proto()
357+
function_ir = ir.serde.deserialize_function(function_proto)
358+
else:
359+
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
360+
output_renaming: dict[str, str] = {}
361+
if _outputs is not None:
362+
if len(_outputs) != len(function_ir.outputs):
363+
raise ValueError(
364+
f"Number of provided output names {_outputs} does not match "
365+
f"number of function outputs {len(function_ir.outputs)}."
366+
)
367+
for output, name in zip(function_ir.outputs, _outputs):
368+
output_renaming[output.name] = self._qualify_value_name(name)
369+
else:
370+
for output in function_ir.outputs:
371+
output_renaming[output.name] = self._qualify_value_name(output.name)
372+
nodes, outputs = _inliner.instantiate(function_ir, args, kwargs)
373+
if _prefix:
374+
self.push_module(_prefix)
375+
for node in nodes:
376+
node.name = self._qualify_node_name(node.name)
377+
for output in node.outputs:
378+
if output.name:
379+
if output.name in output_renaming:
380+
output.name = output_renaming[output.name]
381+
else:
382+
output.name = self._qualify_value_name(output.name)
383+
self.add_node(node)
384+
if _prefix:
385+
self.pop_module()
386+
return outputs if len(outputs) > 1 else outputs[0]
387+
337388
def push_module(self, module: str, class_name: str = "") -> None:
338389
"""Push a new module scope onto the stack.
339390
@@ -414,6 +465,14 @@ def __init__(
414465
def builder(self) -> GraphBuilder:
415466
return self._builder
416467

468+
@property
469+
def domain(self) -> str:
470+
return self._domain
471+
472+
@property
473+
def version(self) -> int | None:
474+
return self._version
475+
417476
def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
418477
if "_domain" not in kwargs:
419478
kwargs["_domain"] = self._domain
@@ -426,3 +485,28 @@ def __getattr__(self, op_type: str) -> Callable:
426485

427486
def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
428487
return self._builder.initializer(tensor, name)
488+
489+
def call(
490+
self,
491+
function,
492+
*args,
493+
_outputs: Sequence[str] | None = None,
494+
_prefix: str = "",
495+
**kwargs,
496+
):
497+
"""Call a function and inline it into the graph.
498+
499+
Args:
500+
function: The function to call (ir.Function or onnxscript.OnnxFunction).
501+
*args: Positional arguments to pass to the function.
502+
_outputs: Optional sequence of output names. If provided, must match the
503+
number of function outputs.
504+
_prefix: Optional prefix for module scoping (e.g., "layers.0").
505+
**kwargs: Keyword arguments to pass to the function.
506+
507+
Returns:
508+
The output value(s) from the function call.
509+
"""
510+
return self._builder.call(
511+
function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs
512+
)

0 commit comments

Comments
 (0)