Skip to content

Commit 95a6883

Browse files
authored
Merge branch 'main' into copilot/fix-squeeze-layer-axes
2 parents 58755be + 4291ff2 commit 95a6883

33 files changed

Lines changed: 1140 additions & 463 deletions

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ jobs:
8383
token: ${{ secrets.CODECOV_TOKEN }}
8484
- name: Upload torchlib error reports
8585
if: always()
86-
uses: actions/upload-artifact@v6
86+
uses: actions/upload-artifact@v7
8787
with:
8888
name: Error reports (${{ matrix.name }}-${{ matrix.os }})
8989
path: error_reports

.lintrunner.toml

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -98,42 +98,6 @@ init_command = [
9898
]
9999
is_formatter = true
100100

101-
[[linter]]
102-
code = 'PYLINT'
103-
include_patterns = [
104-
'**/*.py',
105-
]
106-
exclude_patterns = [
107-
'docs/**',
108-
'examples/**',
109-
'onnxscript/_internal/converter_test.py',
110-
'onnxscript/optimizer/**', # FIXME
111-
'onnxscript/rewriter/**', # FIXME
112-
'tests/functions/**',
113-
'tests/models/**',
114-
'tests/onnx_backend_test_code/**',
115-
]
116-
command = [
117-
'python',
118-
'-m',
119-
'lintrunner_adapters',
120-
'run',
121-
'pylint_linter',
122-
'--rcfile=pyproject_pylint.toml',
123-
'--show-disable',
124-
'--',
125-
'@{{PATHSFILE}}'
126-
]
127-
init_command = [
128-
'python',
129-
'-m',
130-
'lintrunner_adapters',
131-
'run',
132-
'pip_init',
133-
'--dry-run={{DRYRUN}}',
134-
'--requirement=requirements/lintrunner/requirements.txt',
135-
]
136-
137101
[[linter]]
138102
code = 'EDITORCONFIG-CHECKER'
139103
include_patterns = ['**']

docs/tutorial/builder/graph_builder.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,11 @@ The subgraph automatically inherits the opset version from the parent
339339

340340
### Type annotations for subgraph inputs and outputs
341341

342-
`subgraph()` accepts `input_types` and `output_types` lists that describe
343-
the types and shapes of each input and output. Each element can be either an
344-
`ir.TypeAndShape` object or — more conveniently — an
342+
`subgraph()` accepts `inputs` and `outputs` that describe
343+
the types and shapes of each input and output. They can be provided as a
344+
:class:`list` of type specs (names are auto-generated) **or** as a
345+
:class:`dict` mapping explicit names to type specs. Each type spec can be
346+
either an `ir.TypeAndShape` object or — more conveniently — an
345347
`onnxscript` tensor-type expression:
346348

347349
| Expression | Meaning |
@@ -408,8 +410,8 @@ def cumsum_body(op, state, x_i):
408410

409411
body = builder.subgraph(
410412
cumsum_body,
411-
input_types=[FLOAT[D], FLOAT[D]], # state, x_i
412-
output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
413+
inputs=[FLOAT[D], FLOAT[D]], # state, x_i
414+
outputs=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
413415
name="cumsum_body",
414416
)
415417

@@ -430,7 +432,7 @@ model = ir.Model(graph=graph, ir_version=10)
430432

431433
Key points:
432434

433-
- `builder.subgraph(fn, input_types, output_types)` creates a fresh
435+
- `builder.subgraph(fn, inputs, outputs)` creates a fresh
434436
`ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the
435437
declared input/output types.
436438
- The `fn` receives an `OpBuilder` as its first argument — exactly the same
@@ -450,8 +452,8 @@ def outer_body(op, state, x_i):
450452
# Build a nested subgraph inside the scan body
451453
inner = op.builder.subgraph(
452454
lambda iop, v: iop.Relu(v),
453-
input_types=[FLOAT[D]],
454-
output_types=[FLOAT[D]],
455+
inputs=[FLOAT[D]],
456+
outputs=[FLOAT[D]],
455457
name="relu_body",
456458
)
457459
# ... use inner as a graph attribute of a nested op ...

docs/tutorial/rewriter/node_value_checkers.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,16 @@ This means you should be careful when designing patterns with multiple alternati
179179

180180
## Error Handling
181181

182-
Checkers can return either:
183-
- `True`: Check passed, continue matching
184-
- `False`: Check failed, pattern does not match
185-
- `MatchResult`: More detailed result with potential failure reasons
182+
Both check functions (including condition functions and node/value-level checkers) and
183+
rewrite functions support the same conventions for indicating failure:
186184

187-
If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.
185+
- **`MatchResult` with `.fail()`** *(recommended)*: Return `MatchResult().fail("reason", source)` to indicate failure with a descriptive reason and optional source node/value. This provides the most useful diagnostic information for debugging.
186+
- **Raise `MatchFailureError`** *(recommended)*: Import it as `from onnxscript.rewriter.rewriter import MatchFailureError` and raise `MatchFailureError("reason", source1, source2, ...)` to indicate failure associated with one or more `ir.Node` or `ir.Value` objects. Each source should be passed as a separate positional argument (do not pass a list as a single argument). This is especially convenient in utility functions called from a check or rewrite, since it avoids having to explicitly propagate failure status through the call chain.
187+
- **Return `None` or `False`**: These indicate failure without providing a reason. They are supported but not recommended, since a failure reason is valuable for debugging why a rule did not apply.
188+
189+
Including a descriptive failure reason is strongly encouraged. The rewriter's tracing infrastructure
190+
uses these reasons to report why rules failed to match, which is essential for diagnosing
191+
issues when developing or debugging rewrite rules.
192+
193+
For **check functions**, success is indicated by returning `True` or a truthy `MatchResult`.
194+
For **rewrite functions**, success is indicated by returning one or more `ir.Value` results.

onnxscript/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"OnnxFunction",
1414
"TracedOnnxFunction",
1515
"GraphBuilder",
16+
"OpBuilder",
1617
"proto2python",
1718
"external_tensor",
1819
"BFLOAT16",
@@ -129,7 +130,7 @@
129130
# isort: on
130131

131132
from . import ir, nn, optimizer, rewriter, version_converter
132-
from ._internal.builder import GraphBuilder
133+
from ._internal.builder import GraphBuilder, OpBuilder
133134
from ._internal.utils import external_tensor
134135
from ._internal.values import OnnxFunction, TracedOnnxFunction
135136

onnxscript/_internal/_inliner.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,29 @@
1010

1111

1212
def instantiate(
13-
function: ir.Function,
13+
graph: ir.Graph,
1414
inputs: Sequence[ir.Value | None],
1515
attributes: Mapping[str, ir.Attr],
1616
*,
1717
prefix: str = "",
1818
) -> tuple[list[ir.Node], list[ir.Value | None]]:
19-
"""Instantiate (inline) a function, substituting inputs and attributes.
19+
"""Instantiate (inline) a graph, substituting inputs and attributes.
2020
2121
Args:
22-
function: The function to instantiate.
23-
inputs: Actual input values to bind to the function's formal parameters.
22+
graph: The graph to instantiate.
23+
inputs: Actual input values to bind to the graph's formal parameters.
2424
attributes: Attribute values to substitute for reference attributes.
2525
prefix: Optional prefix to prepend to node and output names.
2626
2727
Returns:
28-
A tuple of (nodes, outputs) where nodes are the cloned function body
29-
and outputs are the values corresponding to the function's outputs.
28+
A tuple of (nodes, outputs) where nodes are the cloned graph body
29+
and outputs are the values corresponding to the graph's outputs.
3030
"""
31-
formal_inputs = function.inputs
31+
formal_inputs = graph.inputs
3232
if len(inputs) > len(formal_inputs):
3333
raise ValueError(
3434
f"Too many inputs: got {len(inputs)}, "
35-
f"but function has {len(formal_inputs)} parameters."
35+
f"but graph has {len(formal_inputs)} parameters."
3636
)
3737
value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs))
3838

@@ -50,7 +50,8 @@ def rename(node: ir.Node) -> None:
5050
metadata_props={},
5151
post_process=rename,
5252
resolve_ref_attrs=True,
53+
allow_outer_scope_values=True,
5354
)
54-
nodes = [cloner.clone_node(n) for n in function]
55-
outputs = [value_map.get(v) for v in function.outputs]
55+
nodes = [cloner.clone_node(n) for n in graph]
56+
outputs = [value_map.get(v) for v in graph.outputs]
5657
return nodes, outputs

onnxscript/_internal/analysis.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
153153
return live
154154

155155
def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
156-
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
156+
def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
157157
for s in reversed(block):
158158
live_out = visit(s, live_out)
159159
return live_out
@@ -167,28 +167,28 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
167167
if isinstance(stmt, ast.If):
168168
constant_cond = self.constant_if_condition(stmt)
169169
if constant_cond is None:
170-
live1 = visitBlock(stmt.body, live_out)
171-
live2 = visitBlock(stmt.orelse, live_out)
170+
live1 = visit_block(stmt.body, live_out)
171+
live2 = visit_block(stmt.orelse, live_out)
172172
return live1 | live2 | _used_vars(stmt.test)
173173
elif constant_cond:
174-
return visitBlock(stmt.body, live_out)
174+
return visit_block(stmt.body, live_out)
175175
else:
176-
return visitBlock(stmt.orelse, live_out)
176+
return visit_block(stmt.orelse, live_out)
177177
if isinstance(stmt, ast.For):
178178
p_loop_var = _get_loop_var(stmt, self._formatter)
179179
prev = None
180180
curr = live_out
181181
while curr != prev:
182182
prev = curr
183-
curr = visitBlock(stmt.body, prev).difference({p_loop_var})
183+
curr = visit_block(stmt.body, prev).difference({p_loop_var})
184184
return curr
185185
if isinstance(stmt, ast.While):
186186
cond_vars = _used_vars(stmt.test)
187187
prev = None
188188
curr = live_out | cond_vars
189189
while curr != prev:
190190
prev = curr
191-
curr = visitBlock(stmt.body, prev) | cond_vars
191+
curr = visit_block(stmt.body, prev) | cond_vars
192192
return curr
193193
if isinstance(stmt, ast.Break):
194194
# The following is sufficient for the current restricted usage, where
@@ -228,7 +228,7 @@ def exposed_uses(self, stmts: Sequence[ast.stmt]) -> set[str]:
228228
(in the first statement). Hence x is included in the exposed_uses.
229229
"""
230230

231-
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
231+
def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
232232
for stmt in reversed(block):
233233
live_out = visit(stmt, live_out)
234234
return live_out
@@ -243,13 +243,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
243243
if isinstance(stmt, ast.If):
244244
constant_cond = self.constant_if_condition(stmt)
245245
if constant_cond is None:
246-
live1 = visitBlock(stmt.body, live_out)
247-
live2 = visitBlock(stmt.orelse, live_out)
246+
live1 = visit_block(stmt.body, live_out)
247+
live2 = visit_block(stmt.orelse, live_out)
248248
return (live1 | live2) | _used_vars(stmt.test)
249249
elif constant_cond:
250-
return visitBlock(stmt.body, live_out)
250+
return visit_block(stmt.body, live_out)
251251
else:
252-
return visitBlock(stmt.orelse, live_out)
252+
return visit_block(stmt.orelse, live_out)
253253
if ast_utils.is_print_call(stmt):
254254
return live_out
255255
if ast_utils.is_doc_string(stmt):
@@ -259,13 +259,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
259259
# for loops that execute at least once.
260260
loop_var_set = {_get_loop_var(stmt, self._formatter)}
261261
used_after_loop = live_out.difference(loop_var_set)
262-
used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set)
262+
used_inside_loop = visit_block(stmt.body, set()).difference(loop_var_set)
263263
used_in_loop_header = _used_vars(stmt.iter)
264264
return used_inside_loop | used_in_loop_header | used_after_loop
265265
if isinstance(stmt, ast.While):
266266
# Analysis assumes loop may execute zero times. Results can be improved
267267
# for loops that execute at least once.
268-
used_inside_loop = visitBlock(stmt.body, set())
268+
used_inside_loop = visit_block(stmt.body, set())
269269
used_in_loop_header = _used_vars(stmt.test)
270270
return used_inside_loop | used_in_loop_header | live_out
271271
if isinstance(stmt, ast.Break):
@@ -281,7 +281,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
281281
self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
282282
)
283283

284-
return visitBlock(stmts, set())
284+
return visit_block(stmts, set())
285285

286286
def outer_scope_variables(self, fun: ast.FunctionDef) -> set[str]:
287287
"""Return the set of outer-scope variables used in a nested function.

onnxscript/_internal/autocast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,15 @@ def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]:
187187
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
188188
castable, while X can serve as the target-type.
189189
"""
190-
return None if x is None or converter_.is_castable(x.name) else x
190+
return None if x is None or converter_._is_castable(x.name) else x # pylint: disable=protected-access
191191

192192
def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
193193
if x is None:
194194
return None
195-
if converter_.is_castable(x.name) and y is not None:
195+
if converter_._is_castable(x.name) and y is not None: # pylint: disable=protected-access
196196
# Polymorphic constant x is cast to the type of y:
197-
x_cast = converter_.generate_unique_name(f"{x.name}_cast")
198-
return converter_.emit1([x_cast], "CastLike", [x, y])
197+
x_cast = converter_._generate_unique_name(f"{x.name}_cast") # pylint: disable=protected-access
198+
return converter_._emit1([x_cast], "CastLike", [x, y]) # pylint: disable=protected-access
199199
return x
200200

201201
return cast_inputs(get_type_info, cast_like, op_signature, args)

0 commit comments

Comments
 (0)