Skip to content

Commit ef2bc22

Browse files
gramalingam and Copilot authored
Add GraphBuilder.subgraph() and TensorType.to_ir() for control-flow ops (#2824)
Summary: Adds utilities to simplify building ONNX control-flow ops (Scan, Loop, If) that use graph-valued attributes.

New: TensorType.to_ir()
Converts tensor-type annotations to ir.TypeAndShape, enabling the convenient FLOAT[1024] / FLOAT['M', 'N'] / FLOAT[...] notation wherever a type spec is needed.

New: GraphBuilder.subgraph()
Builds an ir.Graph suitable for use as a graph-valued attribute. The body is defined by a trace function fn(op, *inputs) — the same imperative style as the outer graph. The opset version is inherited from the parent GraphBuilder.

```py
def cumsum_body(op, state, x_i):
    new_state = op.Add(state, x_i)
    return new_state, new_state

body = builder.subgraph(
    cumsum_body,
    input_types=[FLOAT[4], FLOAT[4]],
    output_types=[FLOAT[4], FLOAT[4]],
    name="cumsum_body",
)
final_state, partial_sums = op.Scan(
    init_state, sequence, body=body, num_scan_inputs=1, _outputs=2,
)
```

Files changed:
- onnxscript/onnx_types.py — Add TensorType.to_ir() classmethod
- onnxscript/_internal/builder.py — Add GraphBuilder.subgraph() method + TypeSpec / _resolve_type_spec helpers
- onnxscript/onnx_types_test.py — New test file for TensorType.to_ir()
- onnxscript/_internal/builder_test.py — Tests for GraphBuilder.subgraph()
- docs/tutorial/builder/graph_builder.md — New Building Subgraphs for Control-Flow Ops section with Scan example

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent f6a7de1 commit ef2bc22

File tree

5 files changed

+433
-0
lines changed

5 files changed

+433
-0
lines changed

docs/tutorial/builder/graph_builder.md

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,139 @@ print(w.name) # "encoder.W"
326326
builder.pop_module()
327327
```
328328

329+
## Building Subgraphs for Control-Flow Ops
330+
331+
ONNX control-flow operators such as `Scan`, `Loop`, and `If` accept one or more
332+
**graph-valued attributes** — graphs that define the body executed at each
333+
iteration (or branch). `GraphBuilder.subgraph()` builds these inner graphs in
334+
exactly the same imperative style as the outer graph, and the resulting
335+
`ir.Graph` can be passed directly as an attribute.
336+
337+
The subgraph automatically inherits the opset version from the parent
338+
`GraphBuilder`, so there is no need to specify it separately.
339+
340+
### Type annotations for subgraph inputs and outputs
341+
342+
`subgraph()` accepts `input_types` and `output_types` lists that describe
343+
the types and shapes of each input and output. Each element can be either an
344+
`ir.TypeAndShape` object or — more conveniently — an
345+
`onnxscript` tensor-type expression:
346+
347+
| Expression | Meaning |
348+
|----------------------|-----------------------------------------|
349+
| `FLOAT` | Rank-0 scalar float tensor |
350+
| `FLOAT[...]` | Float tensor of unknown rank |
351+
| `FLOAT[1024]` | 1-D float tensor with 1024 elements |
352+
| `FLOAT[3, 4]` | 2-D float tensor of shape (3, 4) |
353+
| `FLOAT['M', 'N']` | 2-D float tensor with symbolic dims |
354+
355+
These types come from `onnxscript.onnx_types` (also importable from
356+
`onnxscript` directly):
357+
358+
```python
359+
from onnxscript.onnx_types import FLOAT, INT64
360+
```
361+
362+
### Example: cumulative sum with Scan
363+
364+
The `Scan` op iterates over a sequence axis, threading a state vector through
365+
each step. Here is how to build a cumulative-sum model with `subgraph()`:
366+
367+
```python
368+
import onnx_ir as ir
369+
import onnxscript
370+
from onnxscript.onnx_types import FLOAT
371+
372+
D = 4 # feature dimension
373+
N = 10 # sequence length
374+
375+
# --- Parent graph -----------------------------------------------------------
376+
graph = ir.Graph(
377+
name="cumsum_model",
378+
inputs=[],
379+
outputs=[],
380+
nodes=[],
381+
opset_imports={"": 23},
382+
)
383+
384+
# Initial accumulator (shape [D]) and input sequence (shape [N, D])
385+
init_state = ir.Value(
386+
name="init_state",
387+
type=ir.TensorType(ir.DataType.FLOAT),
388+
shape=ir.Shape([D]),
389+
)
390+
sequence = ir.Value(
391+
name="sequence",
392+
type=ir.TensorType(ir.DataType.FLOAT),
393+
shape=ir.Shape([N, D]),
394+
)
395+
graph.inputs.extend([init_state, sequence])
396+
397+
builder = onnxscript.GraphBuilder(graph)
398+
op = builder.op
399+
400+
# --- Scan body --------------------------------------------------------------
401+
# The body receives one state slice (the running sum) and one scan slice
402+
# (the current element of the sequence). It adds them and returns the new
403+
# state both as the updated state and as a scan output.
404+
405+
def cumsum_body(op, state, x_i):
406+
new_state = op.Add(state, x_i)
407+
return new_state, new_state # (updated_state, scan_output_for_this_step)
408+
409+
body = builder.subgraph(
410+
cumsum_body,
411+
input_types=[FLOAT[D], FLOAT[D]], # state, x_i
412+
output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
413+
name="cumsum_body",
414+
)
415+
416+
# --- Scan node --------------------------------------------------------------
417+
# Inputs: init_state (1 state variable), sequence (1 scan input)
418+
# Outputs: final_state, all_partial_sums (shape [N, D])
419+
final_state, partial_sums = op.Scan(
420+
init_state,
421+
sequence,
422+
body=body,
423+
num_scan_inputs=1,
424+
_outputs=2,
425+
)
426+
graph.outputs.extend([final_state, partial_sums])
427+
428+
model = ir.Model(graph=graph, ir_version=10)
429+
```
430+
431+
Key points:
432+
433+
- `builder.subgraph(fn, input_types, output_types)` creates a fresh
434+
`ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the
435+
declared input/output types.
436+
- The `fn` receives an `OpBuilder` as its first argument — exactly the same
437+
API as the outer graph — so you can use the full builder feature set inside
438+
a body (constants, module scopes, nested subgraphs, etc.).
439+
- The returned `ir.Graph` is passed as the `body` keyword attribute of `Scan`.
440+
- `_outputs=2` tells the builder that `Scan` returns two output values.
441+
442+
### Nested subgraphs
443+
444+
Because the `fn` receives an `OpBuilder`, and `OpBuilder` exposes
445+
`op.builder`, you can reach the inner `GraphBuilder` and call `subgraph()`
446+
recursively for doubly-nested control flow (e.g. a `Scan` inside a `Loop`):
447+
448+
```python
449+
def outer_body(op, state, x_i):
450+
# Build a nested subgraph inside the scan body
451+
inner = op.builder.subgraph(
452+
lambda iop, v: iop.Relu(v),
453+
input_types=[FLOAT[D]],
454+
output_types=[FLOAT[D]],
455+
name="relu_body",
456+
)
457+
# ... use inner as a graph attribute of a nested op ...
458+
new_state = op.Add(state, x_i)
459+
return new_state, new_state
460+
```
461+
329462
## Putting It All Together
330463

331464
Here is a complete example that builds a small model with two layers:

onnxscript/_internal/builder.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,33 @@ def _constant_name(
7474
return f"const_1d_{num}"
7575

7676

77+
# Element type of the *input_types* / *output_types* sequences taken by
# :meth:`GraphBuilder.subgraph`.  A spec is either an already-resolved
# :class:`ir.TypeAndShape` or a :class:`~onnxscript.onnx_types.TensorType`
# subclass such as ``FLOAT[1024]``.
# NOTE(review): the ``Any`` arm presumably stands in for the TensorType
# subclasses, which are not imported at module level here; as a result static
# checkers accept any value, and invalid specs are only rejected at runtime by
# :func:`_resolve_type_spec`.
TypeSpec = Union[ir.TypeAndShape, Any]
82+
83+
84+
def _resolve_type_spec(spec: TypeSpec) -> ir.TypeAndShape:
    """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`.

    Args:
        spec: Either an :class:`ir.TypeAndShape`, which is returned unchanged,
            or a :class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
            ``FLOAT[1024]`` or ``FLOAT['M', 'N']``), which is converted by
            calling its ``to_ir()`` method.

    Returns:
        The resolved :class:`ir.TypeAndShape`.

    Raises:
        TypeError: If *spec* is neither of the accepted forms.
    """
    # Imported lazily to avoid a circular dependency: ``onnxscript.__init__``
    # imports ``onnx_types`` before ``builder``, so the module is fully
    # initialised by the time any call reaches here, whereas a top-level import
    # in builder.py could break if builder is ever imported first.
    from onnxscript.onnx_types import TensorType  # pylint: disable=import-outside-toplevel

    if isinstance(spec, ir.TypeAndShape):
        return spec
    if isinstance(spec, type) and issubclass(spec, TensorType):
        return spec.to_ir()
    # Report the offending value as well as its type: when *spec* is itself a
    # class (e.g. ``int``), ``type(spec)`` alone would only say
    # ``<class 'type'>``, which does not identify what the caller passed.
    raise TypeError(
        f"Expected ir.TypeAndShape or a TensorType subclass, got {spec!r} "
        f"of type {type(spec)!r}."
    )
102+
103+
77104
class GraphBuilder:
78105
"""Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference."""
79106

@@ -302,6 +329,78 @@ def add_node(self, node: ir.Node) -> None:
302329
onnxscript.optimizer.basic_constant_propagation([node])
303330
inference.infer_outputs(node)
304331

332+
def subgraph(
    self,
    trace_function: Callable,
    input_types: Sequence[TypeSpec],
    output_types: Sequence[TypeSpec],
    *,
    name: str = "subgraph",
) -> ir.Graph:
    """Trace *trace_function* into a new :class:`ir.Graph` usable as a graph attribute.

    The new graph is given the default-domain opset version of this
    :class:`GraphBuilder`'s graph.  Typical use is building the body graphs
    of control-flow ops such as ``Scan``, ``Loop``, and ``If``.

    Example - a Scan body that adds two sequences element-wise::

        body = graph_builder.subgraph(
            lambda op, x, y: op.Add(x, y),
            input_types=[FLOAT[...], FLOAT[...]],
            output_types=[FLOAT[...]],
        )

    Args:
        trace_function: A callable with signature
            ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
            It is invoked exactly once, with the new graph's placeholder
            inputs, to record the body's operations.
        input_types: One spec per graph input.  Each is either an
            :class:`ir.TypeAndShape` or a
            :class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
            ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
        output_types: One spec per graph output, in the same format as
            *input_types*.
        name: Name given to the resulting :class:`ir.Graph`.

    Returns:
        The traced :class:`ir.Graph`, with inputs named ``input_0``,
        ``input_1``, ... and outputs set to the values returned by
        *trace_function*.  It can be passed directly as a graph-valued
        attribute (e.g. the ``body`` attribute of a ``Scan`` or ``Loop`` node).

    Raises:
        TypeError: If an element of *input_types* or *output_types* is not a
            recognised type spec.
        ValueError: If the number of values returned by *trace_function*
            differs from ``len(output_types)``.
    """
    parent_opset = self._graph.opset_imports[""]
    # Resolve every spec before building anything so that a bad spec fails
    # before the trace function runs.
    input_specs = [_resolve_type_spec(spec) for spec in input_types]
    output_specs = [_resolve_type_spec(spec) for spec in output_types]

    body = ir.Graph(
        name=name,
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": parent_opset},
    )
    body.inputs.extend(
        [
            ir.Value(name=f"input_{index}", type=spec.type, shape=spec.shape)
            for index, spec in enumerate(input_specs)
        ]
    )

    # A dedicated builder records the traced ops into the body graph.
    tracer = GraphBuilder(body)
    traced = trace_function(tracer.op, *body.inputs)
    # A single returned value is treated as a one-element output list.
    results = traced if isinstance(traced, Sequence) else [traced]
    if len(results) != len(output_specs):
        raise ValueError(
            f"trace_function returned {len(results)} output(s), "
            f"but {len(output_specs)} were declared in output_types."
        )

    # Stamp the declared type on each output and merge in the declared shape.
    for value, spec in zip(results, output_specs):
        value.type = spec.type
        value.merge_shapes(spec.shape)

    body.outputs.extend(results)
    return body
403+
305404
def call_op(
306405
self,
307406
op_type: str,

onnxscript/_internal/builder_test.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import onnxscript._internal.builder as builder
1313
from onnxscript import script
14+
from onnxscript.onnx_types import DOUBLE, FLOAT
1415

1516
_default_opset_version = 23
1617

@@ -819,5 +820,123 @@ def add_mul(X, Y):
819820
self.assertIn("does not match", str(cm.exception))
820821

821822

823+
class BuildSubgraphTest(unittest.TestCase):
    """Exercises GraphBuilder.subgraph() with small traced bodies."""

    def _make_builder(self, opset_version: int = 23) -> builder.GraphBuilder:
        """Wrap an empty parent graph pinned to *opset_version* in a GraphBuilder."""
        parent = ir.Graph(
            name="parent",
            inputs=[],
            outputs=[],
            nodes=[],
            opset_imports={"": opset_version},
        )
        return builder.GraphBuilder(parent)

    def test_basic_subgraph(self):
        """A two-input Add body yields an ir.Graph with 2 inputs, 1 output, 1 node."""
        body = self._make_builder().subgraph(
            lambda op, x, y: op.Add(x, y),
            input_types=[FLOAT[3, 4], FLOAT[3, 4]],
            output_types=[FLOAT[3, 4]],
        )
        self.assertIsInstance(body, ir.Graph)
        self.assertEqual(len(body.inputs), 2)
        self.assertEqual(len(body.outputs), 1)
        traced_ops = [node.op_type for node in body]
        self.assertEqual(traced_ops, ["Add"])

    def test_subgraph_inherits_opset_version(self):
        """The subgraph's default-domain opset equals the parent builder's."""

        def _identity(op, x):
            return op.Identity(x)

        parent_builder = self._make_builder(opset_version=17)
        body = parent_builder.subgraph(
            _identity,
            input_types=[FLOAT[...]],
            output_types=[FLOAT[...]],
        )
        self.assertEqual(body.opset_imports[""], 17)

    def test_subgraph_with_ir_type_and_shape(self):
        """Already-resolved ir.TypeAndShape specs are accepted as-is."""
        spec = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([2, 3]))
        body = self._make_builder().subgraph(
            lambda op, x, y: op.Mul(x, y),
            input_types=[spec, spec],
            output_types=[spec],
        )
        self.assertIsInstance(body, ir.Graph)
        nodes = list(body)
        self.assertEqual(len(nodes), 1)
        self.assertEqual(nodes[0].op_type, "Mul")

    def test_subgraph_multiple_outputs(self):
        """A trace function returning a tuple produces one graph output per element."""

        def _sum_and_product(op, x, y):
            total = op.Add(x, y)
            product = op.Mul(x, y)
            return total, product

        any_rank = FLOAT[...]
        body = self._make_builder().subgraph(
            _sum_and_product,
            input_types=[any_rank, any_rank],
            output_types=[any_rank, any_rank],
        )
        self.assertEqual(len(body.outputs), 2)

    def test_subgraph_output_count_mismatch_raises(self):
        """Declaring more outputs than the trace function returns is a ValueError."""
        parent_builder = self._make_builder()
        with self.assertRaises(ValueError):
            parent_builder.subgraph(
                lambda op, x, y: op.Add(x, y),  # returns 1 value
                input_types=[FLOAT[...], FLOAT[...]],
                output_types=[FLOAT[...], FLOAT[...]],  # declares 2
            )

    def test_subgraph_custom_name(self):
        """The *name* keyword becomes the name of the returned ir.Graph."""
        body = self._make_builder().subgraph(
            lambda op, x: op.Identity(x),
            input_types=[DOUBLE[...]],
            output_types=[DOUBLE[...]],
            name="scan_body",
        )
        self.assertEqual(body.name, "scan_body")

    def test_invalid_type_spec_raises(self):
        """A spec that is neither TypeAndShape nor a TensorType subclass is a TypeError."""
        bad_spec = "not_a_type_spec"
        parent_builder = self._make_builder()
        with self.assertRaises(TypeError):
            parent_builder.subgraph(
                lambda op, x: op.Identity(x),
                input_types=[bad_spec],
                output_types=[bad_spec],
            )
939+
940+
822941
if __name__ == "__main__":
823942
unittest.main()

0 commit comments

Comments
 (0)