Skip to content

Commit 6800f0b

Browse files
authored
Merge branch 'main' into copilot/create-fusion-rule-remove-expand-node
2 parents a47d985 + 5391619 commit 6800f0b

File tree

13 files changed

+389
-90
lines changed

13 files changed

+389
-90
lines changed

.lintrunner.toml

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -98,42 +98,6 @@ init_command = [
9898
]
9999
is_formatter = true
100100

101-
[[linter]]
102-
code = 'PYLINT'
103-
include_patterns = [
104-
'**/*.py',
105-
]
106-
exclude_patterns = [
107-
'docs/**',
108-
'examples/**',
109-
'onnxscript/_internal/converter_test.py',
110-
'onnxscript/optimizer/**', # FIXME
111-
'onnxscript/rewriter/**', # FIXME
112-
'tests/functions/**',
113-
'tests/models/**',
114-
'tests/onnx_backend_test_code/**',
115-
]
116-
command = [
117-
'python',
118-
'-m',
119-
'lintrunner_adapters',
120-
'run',
121-
'pylint_linter',
122-
'--rcfile=pyproject_pylint.toml',
123-
'--show-disable',
124-
'--',
125-
'@{{PATHSFILE}}'
126-
]
127-
init_command = [
128-
'python',
129-
'-m',
130-
'lintrunner_adapters',
131-
'run',
132-
'pip_init',
133-
'--dry-run={{DRYRUN}}',
134-
'--requirement=requirements/lintrunner/requirements.txt',
135-
]
136-
137101
[[linter]]
138102
code = 'EDITORCONFIG-CHECKER'
139103
include_patterns = ['**']

docs/tutorial/rewriter/node_value_checkers.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,16 @@ This means you should be careful when designing patterns with multiple alternati
179179

180180
## Error Handling
181181

182-
Checkers can return either:
183-
- `True`: Check passed, continue matching
184-
- `False`: Check failed, pattern does not match
185-
- `MatchResult`: More detailed result with potential failure reasons
182+
Both check functions (including condition functions and node/value-level checkers) and
183+
rewrite functions support the same conventions for indicating failure:
186184

187-
If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.
185+
- **`MatchResult` with `.fail()`** *(recommended)*: Return `MatchResult().fail("reason", source)` to indicate failure with a descriptive reason and optional source node/value. This provides the most useful diagnostic information for debugging.
186+
- **Raise `MatchFailureError`** *(recommended)*: Import it as `from onnxscript.rewriter.rewriter import MatchFailureError` and raise `MatchFailureError("reason", source1, source2, ...)` to indicate failure associated with one or more `ir.Node` or `ir.Value` objects. Each source should be passed as a separate positional argument (do not pass a list as a single argument). This is especially convenient in utility functions called from a check or rewrite, since it avoids having to explicitly propagate failure status through the call chain.
187+
- **Return `None` or `False`**: These indicate failure without providing a reason. They are supported but not recommended, since a failure reason is valuable for debugging why a rule did not apply.
188+
189+
Including a descriptive failure reason is strongly encouraged. The rewriter's tracing infrastructure
190+
uses these reasons to report why rules failed to match, which is essential for diagnosing
191+
issues when developing or debugging rewrite rules.
192+
193+
For **check functions**, success is indicated by returning `True` or a truthy `MatchResult`.
194+
For **rewrite functions**, success is indicated by returning one or more `ir.Value` results.

onnxscript/_internal/builder.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Sequence[float],
3232
Sequence[bool],
3333
Sequence[str],
34+
None,
3435
]
3536

3637
# Mapping from Python scalar types to their default ONNX DataType,
@@ -132,6 +133,7 @@ def build_graph(
132133
*,
133134
opset_imports: dict[str, int] | None = None,
134135
name: str = "subgraph",
136+
parent: GraphBuilder | None = None,
135137
) -> ir.Graph:
136138
"""Build an :class:`ir.Graph` suitable for use as a graph-valued attribute.
137139
@@ -164,6 +166,10 @@ def build_graph(
164166
opset_imports: Opset version map for the subgraph (e.g.
165167
``{"": 23}``). Defaults to ``{"": 23}`` when *None*.
166168
name: Name of the resulting :class:`ir.Graph`.
169+
parent: Optional parent :class:`GraphBuilder`. When provided, the
170+
sub-builder's ``_root`` points to the root builder of the parent,
171+
so that :meth:`Parameter._realize` registers initializers in the
172+
root (main) graph rather than the subgraph.
167173
168174
Returns:
169175
An :class:`ir.Graph` whose inputs and outputs are populated and whose
@@ -187,7 +193,9 @@ def build_graph(
187193
for input_name, ts in resolved_inputs:
188194
subgraph.inputs.append(ir.Value(name=input_name, type=ts.type, shape=ts.shape))
189195

190-
sub_builder = GraphBuilder(subgraph)
196+
sub_builder = GraphBuilder(subgraph, parent=parent)
197+
if parent is not None:
198+
sub_builder._scope_stack = list(parent._scope_stack)
191199
trace_outputs = trace_function(sub_builder.op, *subgraph.inputs)
192200
if not isinstance(trace_outputs, Sequence):
193201
trace_outputs = [trace_outputs]
@@ -208,8 +216,10 @@ def build_graph(
208216
class GraphBuilder:
209217
"""Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference."""
210218

211-
def __init__(self, graph: ir.Graph) -> None:
219+
def __init__(self, graph: ir.Graph, parent: GraphBuilder | None = None) -> None:
212220
self._graph = graph
221+
self._parent = parent
222+
self._root: GraphBuilder = parent._root if parent is not None else self
213223

214224
# Get the opset version for "" (default domain) from the graph
215225
if "" not in graph.opset_imports:
@@ -237,6 +247,16 @@ def opset(self, domain: str, version: int = 1) -> OpBuilder:
237247
def op(self) -> OpBuilder:
238248
return self._op_builder
239249

250+
@property
251+
def parent(self) -> GraphBuilder | None:
252+
"""The parent builder, or None for a top-level builder."""
253+
return self._parent
254+
255+
@property
256+
def root(self) -> GraphBuilder:
257+
"""The root (top-level) builder in the parent chain."""
258+
return self._root
259+
240260
@property
241261
def graph(self) -> ir.Graph:
242262
return self._graph
@@ -258,7 +278,7 @@ def initializer(
258278

259279
def _input_to_ir_value(
260280
self, value: VALUE_LIKE, like_type: ir.Value | None = None
261-
) -> ir.Value:
281+
) -> ir.Value | None:
262282
"""Convert a permissible input (for a call to an op) into an ir.Value.
263283
264284
Permissible values include ir.Value as well as python constants that can be converted
@@ -267,6 +287,8 @@ def _input_to_ir_value(
267287
"""
268288
if isinstance(value, ir.Value):
269289
return value
290+
if value is None:
291+
return value
270292
dtype = (
271293
like_type.type.dtype
272294
if like_type is not None and like_type.type is not None
@@ -356,7 +378,7 @@ def _get_schema(
356378
def _partition_inputs_attributes(
357379
self,
358380
schema: onnx.defs.OpSchema | None,
359-
inputs: Sequence[ir.Value | ir.TensorProtocol],
381+
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
360382
kwargs: dict[str, Any],
361383
) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
362384
if schema is None:
@@ -499,12 +521,13 @@ def subgraph(
499521
outputs,
500522
opset_imports=dict(self._graph.opset_imports),
501523
name=name,
524+
parent=self,
502525
)
503526

504527
def call_op(
505528
self,
506529
op_type: str,
507-
inputs: Sequence[ir.Value | ir.TensorProtocol],
530+
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
508531
kwargs: dict[str, Any],
509532
):
510533
"""Create an ONNX node and add it to the graph, returning its output value(s)."""

onnxscript/_internal/builder_test.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,39 @@ def add_mul(X, Y):
848848

849849
self.assertIn("does not match", str(cm.exception))
850850

851+
def test_none_input_is_passed_through(self):
852+
"""Test that None inputs are preserved as None in the node's inputs."""
853+
op, x, y = _create_builder_with_inputs()
854+
855+
# Gemm's third input (C) is optional; passing None should work
856+
result = op.Gemm(x, y, None, alpha=1.0)
857+
858+
nodes = list(op.builder.graph)
859+
self.assertEqual(len(nodes), 1)
860+
node = nodes[0]
861+
self.assertEqual(node.op_type, "Gemm")
862+
# The third input should be None (optional, omitted)
863+
self.assertEqual(len(list(node.inputs)), 3)
864+
self.assertIs(node.inputs[0], x)
865+
self.assertIs(node.inputs[1], y)
866+
self.assertIsNone(node.inputs[2])
867+
self.assertIsNotNone(result)
868+
869+
def test_none_input_with_custom_domain(self):
870+
"""Test that None inputs work with custom domain ops."""
871+
op, x, y = _create_builder_with_inputs()
872+
873+
result = op.CustomOp(x, None, y, _domain="com.custom")
874+
875+
nodes = list(op.builder.graph)
876+
self.assertEqual(len(nodes), 1)
877+
node = nodes[0]
878+
self.assertEqual(node.op_type, "CustomOp")
879+
self.assertIs(node.inputs[0], x)
880+
self.assertIsNone(node.inputs[1])
881+
self.assertIs(node.inputs[2], y)
882+
self.assertIsNotNone(result)
883+
851884

852885
class BuildSubgraphTest(unittest.TestCase):
853886
"""Tests for GraphBuilder.subgraph()."""
@@ -1028,6 +1061,88 @@ def test_build_graph_custom_name(self):
10281061
)
10291062
self.assertEqual(graph.name, "loop_body")
10301063

1064+
def test_build_graph_with_parent(self):
1065+
"""build_graph with parent sets root on the sub-builder."""
1066+
parent_graph = ir.Graph(
1067+
name="main",
1068+
inputs=[],
1069+
outputs=[],
1070+
nodes=[],
1071+
opset_imports={"": 23},
1072+
)
1073+
parent_builder = builder.GraphBuilder(parent_graph)
1074+
1075+
def body(op, x):
1076+
self.assertIs(op.builder.parent, parent_builder)
1077+
self.assertIs(op.builder.root, parent_builder)
1078+
return op.Identity(x)
1079+
1080+
builder.build_graph(
1081+
body,
1082+
inputs=[FLOAT[3]],
1083+
outputs=[FLOAT[3]],
1084+
parent=parent_builder,
1085+
)
1086+
1087+
def test_subgraph_sets_parent_and_root(self):
1088+
"""GraphBuilder.subgraph() sets parent=self on the sub-builder."""
1089+
parent_graph = ir.Graph(
1090+
name="main",
1091+
inputs=[],
1092+
outputs=[],
1093+
nodes=[],
1094+
opset_imports={"": 23},
1095+
)
1096+
parent_builder = builder.GraphBuilder(parent_graph)
1097+
1098+
def body(op, x):
1099+
self.assertIs(op.builder.parent, parent_builder)
1100+
self.assertIs(op.builder.root, parent_builder)
1101+
return op.Identity(x)
1102+
1103+
parent_builder.subgraph(body, inputs=[FLOAT[3]], outputs=[FLOAT[3]])
1104+
1105+
def test_build_graph_inherits_parent_scope_stack(self):
1106+
"""build_graph copies the parent's scope stack so nodes in the subgraph carry scoped names."""
1107+
parent_graph = ir.Graph(
1108+
name="main",
1109+
inputs=[],
1110+
outputs=[],
1111+
nodes=[],
1112+
opset_imports={"": 23},
1113+
)
1114+
parent_builder = builder.GraphBuilder(parent_graph)
1115+
parent_builder.push_module("encoder", "Encoder")
1116+
parent_builder.push_module("layers.0", "TransformerBlock")
1117+
1118+
subgraph = builder.build_graph(
1119+
lambda op, x: op.Relu(x),
1120+
inputs={"x": FLOAT[3, 4]},
1121+
outputs={"y": FLOAT[3, 4]},
1122+
parent=parent_builder,
1123+
)
1124+
1125+
# The single node created inside the subgraph should carry the
1126+
# parent's scope prefix in its name and metadata.
1127+
node = subgraph.node(0)
1128+
self.assertIn("encoder", node.name)
1129+
self.assertIn("layers.0", node.name)
1130+
self.assertIn("encoder", node.metadata_props["namespace"])
1131+
self.assertIn("TransformerBlock", node.metadata_props["namespace"])
1132+
1133+
def test_root_graph_builder_is_its_own_root(self):
1134+
"""A top-level GraphBuilder has root == self."""
1135+
graph = ir.Graph(
1136+
name="main",
1137+
inputs=[],
1138+
outputs=[],
1139+
nodes=[],
1140+
opset_imports={"": 23},
1141+
)
1142+
gb = builder.GraphBuilder(graph)
1143+
self.assertIs(gb.root, gb)
1144+
self.assertIsNone(gb.parent)
1145+
10311146

10321147
class PartitionInputsAttributesTest(unittest.TestCase):
10331148
"""Tests for GraphBuilder._partition_inputs_attributes."""

onnxscript/nn/_module_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,77 @@ def test_realize_qualifies_name(self):
7272
self.assertEqual(value.name, "layer1.bias")
7373
self.assertIn("layer1.bias", graph.initializers)
7474

75+
def test_realize_in_subgraph_registers_in_root(self):
76+
"""Parameter realized inside a subgraph builder is stored in the root graph."""
77+
from onnxscript._internal.builder import GraphBuilder
78+
from onnxscript.onnx_types import FLOAT
79+
80+
root_graph = ir.Graph(
81+
name="main",
82+
inputs=[],
83+
outputs=[],
84+
nodes=[],
85+
opset_imports={"": 23},
86+
)
87+
root_builder = GraphBuilder(root_graph)
88+
89+
p = Parameter([3, 4], name="weight")
90+
91+
def body_fn(op, x):
92+
# Realize param inside a sub-builder context
93+
p._realize(op.builder) # pylint: disable=protected-access
94+
return op.Add(x, x)
95+
96+
_sub_graph = root_builder.subgraph(
97+
body_fn,
98+
inputs=[FLOAT[3, 4]],
99+
outputs=[FLOAT[3, 4]],
100+
)
101+
# Parameter should be in the ROOT graph's initializers, not the subgraph's
102+
self.assertIn("weight", root_graph.initializers)
103+
self.assertIs(root_graph.initializers["weight"], p)
104+
# The subgraph should NOT have the initializer
105+
self.assertNotIn("weight", _sub_graph.initializers)
106+
107+
def test_realize_in_nested_subgraph_registers_in_root(self):
108+
"""Parameter realized in a doubly-nested subgraph goes to the root graph."""
109+
from onnxscript._internal.builder import GraphBuilder, build_graph
110+
from onnxscript.onnx_types import FLOAT
111+
112+
root_graph = ir.Graph(
113+
name="main",
114+
inputs=[],
115+
outputs=[],
116+
nodes=[],
117+
opset_imports={"": 23},
118+
)
119+
root_builder = GraphBuilder(root_graph)
120+
121+
p = Parameter([3], name="bias")
122+
123+
def inner_fn(op, x):
124+
p._realize(op.builder) # pylint: disable=protected-access
125+
return op.Identity(x)
126+
127+
def outer_fn(op, x):
128+
# Build a nested subgraph
129+
build_graph(
130+
inner_fn,
131+
inputs=[FLOAT[3]],
132+
outputs=[FLOAT[3]],
133+
parent=op.builder,
134+
)
135+
return op.Identity(x)
136+
137+
root_builder.subgraph(
138+
outer_fn,
139+
inputs=[FLOAT[3]],
140+
outputs=[FLOAT[3]],
141+
)
142+
# Even through two levels of nesting, param ends up in root
143+
self.assertIn("bias", root_graph.initializers)
144+
self.assertIs(root_graph.initializers["bias"], p)
145+
75146

76147
class ModuleBasicTest(unittest.TestCase):
77148
def test_parameter_auto_registration(self):

0 commit comments

Comments
 (0)