Skip to content

Commit bab4f28

Browse files
gramalingam and Copilot
authored
Extend graph construction utility (#2842)
* The subgraph construction utility extended to allow specification of input/output names as well (optionally). * The to_ir method renamed to to_ir_type_and_shape in TensorType (for more robust duck-typing/protocol typing) --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 3f5a3c3 commit bab4f28

File tree

5 files changed

+239
-82
lines changed

5 files changed

+239
-82
lines changed

docs/tutorial/builder/graph_builder.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,11 @@ The subgraph automatically inherits the opset version from the parent
339339

340340
### Type annotations for subgraph inputs and outputs
341341

342-
`subgraph()` accepts `input_types` and `output_types` lists that describe
343-
the types and shapes of each input and output. Each element can be either an
344-
`ir.TypeAndShape` object or — more conveniently — an
342+
`subgraph()` accepts `inputs` and `outputs` that describe
343+
the types and shapes of each input and output. They can be provided as a
344+
:class:`list` of type specs (names are auto-generated) **or** as a
345+
:class:`dict` mapping explicit names to type specs. Each type spec can be
346+
either an `ir.TypeAndShape` object or — more conveniently — an
345347
`onnxscript` tensor-type expression:
346348

347349
| Expression | Meaning |
@@ -408,8 +410,8 @@ def cumsum_body(op, state, x_i):
408410

409411
body = builder.subgraph(
410412
cumsum_body,
411-
input_types=[FLOAT[D], FLOAT[D]], # state, x_i
412-
output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
413+
inputs=[FLOAT[D], FLOAT[D]], # state, x_i
414+
outputs=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
413415
name="cumsum_body",
414416
)
415417

@@ -430,7 +432,7 @@ model = ir.Model(graph=graph, ir_version=10)
430432

431433
Key points:
432434

433-
- `builder.subgraph(fn, input_types, output_types)` creates a fresh
435+
- `builder.subgraph(fn, inputs, outputs)` creates a fresh
434436
`ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the
435437
declared input/output types.
436438
- The `fn` receives an `OpBuilder` as its first argument — exactly the same
@@ -450,8 +452,8 @@ def outer_body(op, state, x_i):
450452
# Build a nested subgraph inside the scan body
451453
inner = op.builder.subgraph(
452454
lambda iop, v: iop.Relu(v),
453-
input_types=[FLOAT[D]],
454-
output_types=[FLOAT[D]],
455+
inputs=[FLOAT[D]],
456+
outputs=[FLOAT[D]],
455457
name="relu_body",
456458
)
457459
# ... use inner as a graph attribute of a nested op ...

onnxscript/_internal/builder.py

Lines changed: 143 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from typing import Any, Callable, Sequence, Union
13+
from typing import Any, Callable, Mapping, Sequence, Union
1414

1515
import onnx
1616
import onnx_ir as ir
@@ -74,31 +74,135 @@ def _constant_name(
7474
return f"const_1d_{num}"
7575

7676

77-
# Type accepted as an element of *input_types* / *output_types* by
77+
# Type accepted as an element of *inputs* / *outputs* by
7878
# :meth:`GraphBuilder.subgraph`. Can be an already-resolved
7979
# :class:`ir.TypeAndShape`, or a
8080
# :class:`~onnxscript.onnx_types.TensorType` subclass such as ``FLOAT[1024]``.
8181
TypeSpec = Union[ir.TypeAndShape, Any]
8282

83+
# Acceptable collection forms for *inputs* / *outputs* in
84+
# :meth:`GraphBuilder.subgraph`. A :class:`Sequence` of :data:`TypeSpec`
85+
# auto-names entries (``input_0``, ``input_1``, …), while a :class:`Mapping`
86+
# from :class:`str` to :data:`TypeSpec` uses the keys as explicit names.
87+
InputOutputSpec = Union[Sequence[TypeSpec], Mapping[str, TypeSpec]]
88+
8389

8490
def _resolve_type_spec(spec: TypeSpec) -> ir.TypeAndShape:
8591
"""Convert a *TypeSpec* to an :class:`ir.TypeAndShape`.
8692
87-
Accepts either an :class:`ir.TypeAndShape` directly, or a
88-
:class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]``
89-
or ``FLOAT['M', 'N']``).
93+
Accepts an :class:`ir.TypeAndShape` directly, or any object with a
94+
``to_ir_type_and_shape()`` method (e.g. a
95+
:class:`~onnxscript.onnx_types.TensorType` subclass such as
96+
``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
9097
"""
91-
# Lazy import to avoid a circular dependency: onnxscript.__init__ imports
92-
# onnx_types (line ~106) before builder (line ~132), so by the time any
93-
# call reaches here the module is fully initialised — but a top-level
94-
# import in builder.py could break if builder is ever imported first.
95-
from onnxscript.onnx_types import TensorType # pylint: disable=import-outside-toplevel
96-
9798
if isinstance(spec, ir.TypeAndShape):
9899
return spec
99-
if isinstance(spec, type) and issubclass(spec, TensorType):
100-
return spec.to_ir()
101-
raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.")
100+
if hasattr(spec, "to_ir_type_and_shape"):
101+
result = spec.to_ir_type_and_shape()
102+
if not isinstance(result, ir.TypeAndShape):
103+
raise TypeError(
104+
f"{type(spec)!r}.to_ir_type_and_shape() returned {type(result)!r}, "
105+
f"expected ir.TypeAndShape."
106+
)
107+
return result
108+
raise TypeError(
109+
f"Expected ir.TypeAndShape or an object with a to_ir_type_and_shape() method, "
110+
f"got {type(spec)!r}."
111+
)
112+
113+
114+
def _normalize_io_spec(
115+
spec: InputOutputSpec, default_prefix: str
116+
) -> list[tuple[str, ir.TypeAndShape]]:
117+
"""Normalize an *InputOutputSpec* into a list of ``(name, TypeAndShape)`` pairs.
118+
119+
When *spec* is a :class:`Mapping`, the keys are used as names. When it is
120+
a plain :class:`Sequence`, names are generated as
121+
``{default_prefix}_0``, ``{default_prefix}_1``, etc.
122+
"""
123+
if isinstance(spec, Mapping):
124+
return [(name, _resolve_type_spec(ts)) for name, ts in spec.items()]
125+
return [(f"{default_prefix}_{i}", _resolve_type_spec(ts)) for i, ts in enumerate(spec)]
126+
127+
128+
def build_graph(
129+
trace_function: Callable,
130+
inputs: InputOutputSpec,
131+
outputs: InputOutputSpec,
132+
*,
133+
opset_imports: dict[str, int] | None = None,
134+
name: str = "subgraph",
135+
) -> ir.Graph:
136+
"""Build an :class:`ir.Graph` suitable for use as a graph-valued attribute.
137+
138+
This is a module-level utility that constructs a subgraph by tracing
139+
*trace_function*. It is useful for building body graphs of control-flow ops
140+
such as ``Scan``, ``Loop``, and ``If``.
141+
142+
Example - building a Scan body that adds two sequences element-wise::
143+
144+
body = build_graph(
145+
lambda op, x, y: op.Add(x, y),
146+
inputs={"x": FLOAT[...], "y": FLOAT[...]},
147+
outputs={"sum": FLOAT[...]},
148+
)
149+
150+
Args:
151+
trace_function: A callable with signature
152+
``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
153+
It is called once with freshly created placeholder inputs to record the
154+
graph topology.
155+
inputs: Types (and optionally names) for each graph input. May be a
156+
:class:`Sequence` of :data:`TypeSpec` values (names are auto-generated
157+
as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from
158+
:class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec`
159+
can be an :class:`ir.TypeAndShape` or a
160+
:class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
161+
``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
162+
outputs: Types (and optionally names) for each graph output, in the
163+
same format as *inputs*.
164+
opset_imports: Opset version map for the subgraph (e.g.
165+
``{"": 23}``). Defaults to ``{"": 23}`` when *None*.
166+
name: Name of the resulting :class:`ir.Graph`.
167+
168+
Returns:
169+
An :class:`ir.Graph` whose inputs and outputs are populated and whose
170+
nodes record the operations traced by *trace_function*. This graph can be
171+
passed directly as a graph-valued attribute (e.g. the ``body`` attribute of
172+
a ``Scan`` or ``Loop`` node).
173+
"""
174+
if opset_imports is None:
175+
opset_imports = {"": 23}
176+
resolved_inputs = _normalize_io_spec(inputs, "input")
177+
resolved_outputs = _normalize_io_spec(outputs, "output")
178+
179+
subgraph = ir.Graph(
180+
name=name,
181+
inputs=[],
182+
outputs=[],
183+
nodes=[],
184+
opset_imports=opset_imports,
185+
)
186+
187+
for input_name, ts in resolved_inputs:
188+
subgraph.inputs.append(ir.Value(name=input_name, type=ts.type, shape=ts.shape))
189+
190+
sub_builder = GraphBuilder(subgraph)
191+
trace_outputs = trace_function(sub_builder.op, *subgraph.inputs)
192+
if not isinstance(trace_outputs, Sequence):
193+
trace_outputs = [trace_outputs]
194+
if len(trace_outputs) != len(resolved_outputs):
195+
raise ValueError(
196+
f"trace_function returned {len(trace_outputs)} output(s), "
197+
f"but {len(resolved_outputs)} were declared in outputs."
198+
)
199+
for output, (output_name, ts) in zip(trace_outputs, resolved_outputs):
200+
output.name = output_name
201+
output.type = ts.type
202+
output.merge_shapes(ts.shape)
203+
204+
subgraph.outputs.extend(trace_outputs)
205+
return subgraph
102206

103207

104208
class GraphBuilder:
@@ -332,8 +436,8 @@ def add_node(self, node: ir.Node) -> None:
332436
def subgraph(
333437
self,
334438
trace_function: Callable,
335-
input_types: Sequence[TypeSpec],
336-
output_types: Sequence[TypeSpec],
439+
inputs: InputOutputSpec,
440+
outputs: InputOutputSpec,
337441
*,
338442
name: str = "subgraph",
339443
) -> ir.Graph:
@@ -347,21 +451,33 @@ def subgraph(
347451
348452
body = graph_builder.subgraph(
349453
lambda op, x, y: op.Add(x, y),
350-
input_types=[FLOAT[...], FLOAT[...]],
351-
output_types=[FLOAT[...]],
454+
inputs=[FLOAT[...], FLOAT[...]],
455+
outputs=[FLOAT[...]],
456+
)
457+
458+
Inputs and outputs can also be given as a :class:`dict` to assign
459+
explicit names::
460+
461+
body = graph_builder.subgraph(
462+
lambda op, x, y: op.Add(x, y),
463+
inputs={"x": FLOAT[...], "y": FLOAT[...]},
464+
outputs={"sum": FLOAT[...]},
352465
)
353466
354467
Args:
355468
trace_function: A callable with signature
356469
``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
357470
It is called once with freshly created placeholder inputs to record the
358471
graph topology.
359-
input_types: Types for each graph input. Each element may be an
360-
:class:`ir.TypeAndShape` **or** a
472+
inputs: Types (and optionally names) for each graph input. May be a
473+
:class:`Sequence` of :data:`TypeSpec` values (names are auto-generated
474+
as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from
475+
:class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec`
476+
can be an :class:`ir.TypeAndShape` or a
361477
:class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
362478
``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
363-
output_types: Types for each graph output, in the same format as
364-
*input_types*.
479+
outputs: Types (and optionally names) for each graph output, in the
480+
same format as *inputs*.
365481
name: Name of the resulting :class:`ir.Graph`.
366482
367483
Returns:
@@ -370,37 +486,14 @@ def subgraph(
370486
passed directly as a graph-valued attribute (e.g. the ``body`` attribute of
371487
a ``Scan`` or ``Loop`` node).
372488
"""
373-
opset_version = self._graph.opset_imports[""]
374-
resolved_inputs = [_resolve_type_spec(t) for t in input_types]
375-
resolved_outputs = [_resolve_type_spec(t) for t in output_types]
376-
377-
subgraph = ir.Graph(
489+
return build_graph(
490+
trace_function,
491+
inputs,
492+
outputs,
493+
opset_imports=dict(self._graph.opset_imports),
378494
name=name,
379-
inputs=[],
380-
outputs=[],
381-
nodes=[],
382-
opset_imports={"": opset_version},
383495
)
384496

385-
for i, ts in enumerate(resolved_inputs):
386-
subgraph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape))
387-
388-
sub_builder = GraphBuilder(subgraph)
389-
outputs = trace_function(sub_builder.op, *subgraph.inputs)
390-
if not isinstance(outputs, Sequence):
391-
outputs = [outputs]
392-
if len(outputs) != len(resolved_outputs):
393-
raise ValueError(
394-
f"trace_function returned {len(outputs)} output(s), "
395-
f"but {len(resolved_outputs)} were declared in output_types."
396-
)
397-
for output, ts in zip(outputs, resolved_outputs):
398-
output.type = ts.type
399-
output.merge_shapes(ts.shape)
400-
401-
subgraph.outputs.extend(outputs)
402-
return subgraph
403-
404497
def call_op(
405498
self,
406499
op_type: str,

0 commit comments

Comments
 (0)