1010
1111from __future__ import annotations
1212
13- from typing import Any , Callable , Sequence , Union
13+ from typing import Any , Callable , Mapping , Sequence , Union
1414
1515import onnx
1616import onnx_ir as ir
@@ -74,31 +74,135 @@ def _constant_name(
7474 return f"const_1d_{ num } "
7575
7676
77- # Type accepted as an element of *input_types * / *output_types * by
77+ # Type accepted as an element of *inputs * / *outputs * by
7878# :meth:`GraphBuilder.subgraph`. Can be an already-resolved
7979# :class:`ir.TypeAndShape`, or a
8080# :class:`~onnxscript.onnx_types.TensorType` subclass such as ``FLOAT[1024]``.
8181TypeSpec = Union [ir .TypeAndShape , Any ]
8282
83+ # Acceptable collection forms for *inputs* / *outputs* in
84+ # :meth:`GraphBuilder.subgraph`. A :class:`Sequence` of :data:`TypeSpec`
85+ # auto-names entries (``input_0``, ``input_1``, …), while a :class:`Mapping`
86+ # from :class:`str` to :data:`TypeSpec` uses the keys as explicit names.
87+ InputOutputSpec = Union [Sequence [TypeSpec ], Mapping [str , TypeSpec ]]
88+
8389
8490def _resolve_type_spec (spec : TypeSpec ) -> ir .TypeAndShape :
8591 """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`.
8692
87- Accepts either an :class:`ir.TypeAndShape` directly, or a
88- :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]``
89- or ``FLOAT['M', 'N']``).
93+ Accepts an :class:`ir.TypeAndShape` directly, or any object with a
94+ ``to_ir_type_and_shape()`` method (e.g. a
95+ :class:`~onnxscript.onnx_types.TensorType` subclass such as
96+ ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
9097 """
91- # Lazy import to avoid a circular dependency: onnxscript.__init__ imports
92- # onnx_types (line ~106) before builder (line ~132), so by the time any
93- # call reaches here the module is fully initialised — but a top-level
94- # import in builder.py could break if builder is ever imported first.
95- from onnxscript .onnx_types import TensorType # pylint: disable=import-outside-toplevel
96-
9798 if isinstance (spec , ir .TypeAndShape ):
9899 return spec
99- if isinstance (spec , type ) and issubclass (spec , TensorType ):
100- return spec .to_ir ()
101- raise TypeError (f"Expected ir.TypeAndShape or a TensorType subclass, got { type (spec )!r} ." )
100+ if hasattr (spec , "to_ir_type_and_shape" ):
101+ result = spec .to_ir_type_and_shape ()
102+ if not isinstance (result , ir .TypeAndShape ):
103+ raise TypeError (
104+ f"{ type (spec )!r} .to_ir_type_and_shape() returned { type (result )!r} , "
105+ f"expected ir.TypeAndShape."
106+ )
107+ return result
108+ raise TypeError (
109+ f"Expected ir.TypeAndShape or an object with a to_ir_type_and_shape() method, "
110+ f"got { type (spec )!r} ."
111+ )
112+
113+
114+ def _normalize_io_spec (
115+ spec : InputOutputSpec , default_prefix : str
116+ ) -> list [tuple [str , ir .TypeAndShape ]]:
117+ """Normalize an *InputOutputSpec* into a list of ``(name, TypeAndShape)`` pairs.
118+
119+ When *spec* is a :class:`Mapping`, the keys are used as names. When it is
120+ a plain :class:`Sequence`, names are generated as
121+ ``{default_prefix}_0``, ``{default_prefix}_1``, etc.
122+ """
123+ if isinstance (spec , Mapping ):
124+ return [(name , _resolve_type_spec (ts )) for name , ts in spec .items ()]
125+ return [(f"{ default_prefix } _{ i } " , _resolve_type_spec (ts )) for i , ts in enumerate (spec )]
126+
127+
128+ def build_graph (
129+ trace_function : Callable ,
130+ inputs : InputOutputSpec ,
131+ outputs : InputOutputSpec ,
132+ * ,
133+ opset_imports : dict [str , int ] | None = None ,
134+ name : str = "subgraph" ,
135+ ) -> ir .Graph :
136+ """Build an :class:`ir.Graph` suitable for use as a graph-valued attribute.
137+
138+ This is a module-level utility that constructs a subgraph by tracing
139+ *trace_function*. It is useful for building body graphs of control-flow ops
140+ such as ``Scan``, ``Loop``, and ``If``.
141+
142+ Example - building a Scan body that adds two sequences element-wise::
143+
144+ body = build_graph(
145+ lambda op, x, y: op.Add(x, y),
146+ inputs={"x": FLOAT[...], "y": FLOAT[...]},
147+ outputs={"sum": FLOAT[...]},
148+ )
149+
150+ Args:
151+ trace_function: A callable with signature
152+ ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
153+ It is called once with freshly created placeholder inputs to record the
154+ graph topology.
155+ inputs: Types (and optionally names) for each graph input. May be a
156+ :class:`Sequence` of :data:`TypeSpec` values (names are auto-generated
157+ as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from
158+ :class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec`
159+ can be an :class:`ir.TypeAndShape` or a
160+ :class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
161+ ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
162+ outputs: Types (and optionally names) for each graph output, in the
163+ same format as *inputs*.
164+ opset_imports: Opset version map for the subgraph (e.g.
165+ ``{"": 23}``). Defaults to ``{"": 23}`` when *None*.
166+ name: Name of the resulting :class:`ir.Graph`.
167+
168+ Returns:
169+ An :class:`ir.Graph` whose inputs and outputs are populated and whose
170+ nodes record the operations traced by *trace_function*. This graph can be
171+ passed directly as a graph-valued attribute (e.g. the ``body`` attribute of
172+ a ``Scan`` or ``Loop`` node).
173+ """
174+ if opset_imports is None :
175+ opset_imports = {"" : 23 }
176+ resolved_inputs = _normalize_io_spec (inputs , "input" )
177+ resolved_outputs = _normalize_io_spec (outputs , "output" )
178+
179+ subgraph = ir .Graph (
180+ name = name ,
181+ inputs = [],
182+ outputs = [],
183+ nodes = [],
184+ opset_imports = opset_imports ,
185+ )
186+
187+ for input_name , ts in resolved_inputs :
188+ subgraph .inputs .append (ir .Value (name = input_name , type = ts .type , shape = ts .shape ))
189+
190+ sub_builder = GraphBuilder (subgraph )
191+ trace_outputs = trace_function (sub_builder .op , * subgraph .inputs )
192+ if not isinstance (trace_outputs , Sequence ):
193+ trace_outputs = [trace_outputs ]
194+ if len (trace_outputs ) != len (resolved_outputs ):
195+ raise ValueError (
196+ f"trace_function returned { len (trace_outputs )} output(s), "
197+ f"but { len (resolved_outputs )} were declared in outputs."
198+ )
199+ for output , (output_name , ts ) in zip (trace_outputs , resolved_outputs ):
200+ output .name = output_name
201+ output .type = ts .type
202+ output .merge_shapes (ts .shape )
203+
204+ subgraph .outputs .extend (trace_outputs )
205+ return subgraph
102206
103207
104208class GraphBuilder :
@@ -332,8 +436,8 @@ def add_node(self, node: ir.Node) -> None:
332436 def subgraph (
333437 self ,
334438 trace_function : Callable ,
335- input_types : Sequence [ TypeSpec ] ,
336- output_types : Sequence [ TypeSpec ] ,
439+ inputs : InputOutputSpec ,
440+ outputs : InputOutputSpec ,
337441 * ,
338442 name : str = "subgraph" ,
339443 ) -> ir .Graph :
@@ -347,21 +451,33 @@ def subgraph(
347451
348452 body = graph_builder.subgraph(
349453 lambda op, x, y: op.Add(x, y),
350- input_types=[FLOAT[...], FLOAT[...]],
351- output_types=[FLOAT[...]],
454+ inputs=[FLOAT[...], FLOAT[...]],
455+ outputs=[FLOAT[...]],
456+ )
457+
458+ Inputs and outputs can also be given as a :class:`dict` to assign
459+ explicit names::
460+
461+ body = graph_builder.subgraph(
462+ lambda op, x, y: op.Add(x, y),
463+ inputs={"x": FLOAT[...], "y": FLOAT[...]},
464+ outputs={"sum": FLOAT[...]},
352465 )
353466
354467 Args:
355468 trace_function: A callable with signature
356469 ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
357470 It is called once with freshly created placeholder inputs to record the
358471 graph topology.
359- input_types: Types for each graph input. Each element may be an
360- :class:`ir.TypeAndShape` **or** a
472+ inputs: Types (and optionally names) for each graph input. May be a
473+ :class:`Sequence` of :data:`TypeSpec` values (names are auto-generated
474+ as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from
475+ :class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec`
476+ can be an :class:`ir.TypeAndShape` or a
361477 :class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
362478 ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
363- output_types : Types for each graph output, in the same format as
364- *input_types *.
479+ outputs : Types (and optionally names) for each graph output, in the
480+ same format as *inputs *.
365481 name: Name of the resulting :class:`ir.Graph`.
366482
367483 Returns:
@@ -370,37 +486,14 @@ def subgraph(
370486 passed directly as a graph-valued attribute (e.g. the ``body`` attribute of
371487 a ``Scan`` or ``Loop`` node).
372488 """
373- opset_version = self . _graph . opset_imports [ "" ]
374- resolved_inputs = [ _resolve_type_spec ( t ) for t in input_types ]
375- resolved_outputs = [ _resolve_type_spec ( t ) for t in output_types ]
376-
377- subgraph = ir . Graph (
489+ return build_graph (
490+ trace_function ,
491+ inputs ,
492+ outputs ,
493+ opset_imports = dict ( self . _graph . opset_imports ),
378494 name = name ,
379- inputs = [],
380- outputs = [],
381- nodes = [],
382- opset_imports = {"" : opset_version },
383495 )
384496
385- for i , ts in enumerate (resolved_inputs ):
386- subgraph .inputs .append (ir .Value (name = f"input_{ i } " , type = ts .type , shape = ts .shape ))
387-
388- sub_builder = GraphBuilder (subgraph )
389- outputs = trace_function (sub_builder .op , * subgraph .inputs )
390- if not isinstance (outputs , Sequence ):
391- outputs = [outputs ]
392- if len (outputs ) != len (resolved_outputs ):
393- raise ValueError (
394- f"trace_function returned { len (outputs )} output(s), "
395- f"but { len (resolved_outputs )} were declared in output_types."
396- )
397- for output , ts in zip (outputs , resolved_outputs ):
398- output .type = ts .type
399- output .merge_shapes (ts .shape )
400-
401- subgraph .outputs .extend (outputs )
402- return subgraph
403-
404497 def call_op (
405498 self ,
406499 op_type : str ,
0 commit comments