Skip to content

Commit 565b8e5

Browse files
authored
Clean up onnxscript/_internal/converter.py (#2759)
Clean up onnxscript/_internal/converter.py and move sourceinfo to internal --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent b6e9878 commit 565b8e5

File tree

6 files changed

+54
-68
lines changed

6 files changed

+54
-68
lines changed

onnxscript/_internal/analysis.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import ast
66
from typing import Any, Optional, Sequence, Set
77

8-
from onnxscript import sourceinfo
9-
from onnxscript._internal import ast_utils
8+
from onnxscript._internal import ast_utils, sourceinfo
109

1110

1211
def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:

onnxscript/_internal/analysis_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from typing import Any
88

99
from onnxscript._internal import analysis, ast_utils
10+
from onnxscript._internal.sourceinfo import formatter
1011
from onnxscript.onnx_opset import opset15 as op
11-
from onnxscript.sourceinfo import formatter
1212

1313

1414
class AnalysisResultsVisitor(ast.NodeVisitor):

onnxscript/_internal/converter.py

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,23 @@
77
from typing import (
88
TYPE_CHECKING,
99
Any,
10-
Dict,
11-
List,
1210
NoReturn,
13-
Optional,
1411
Sequence,
15-
Tuple,
1612
Union,
1713
)
1814

1915
import onnx
2016
import onnx_ir as ir
2117

2218
import onnxscript
23-
from onnxscript import onnx_types, sourceinfo
19+
from onnxscript import onnx_types
2420
from onnxscript._internal import (
2521
analysis,
2622
ast_utils,
2723
autocast,
2824
irbuilder,
2925
param_manipulation,
26+
sourceinfo,
3027
values,
3128
)
3229
from onnxscript._internal import (
@@ -171,10 +168,10 @@ class :class:`onnxscript.irbuilder.IRBuilder` is used
171168

172169
def __init__(
173170
self,
174-
opset: Optional[values.Opset] = None,
175-
global_names: Optional[dict[str, Any]] = None,
176-
source: Optional[str] = None,
177-
default_opset: Optional[values.Opset] = None,
171+
opset: values.Opset | None = None,
172+
global_names: dict[str, Any] | None = None,
173+
source: str | None = None,
174+
default_opset: values.Opset | None = None,
178175
):
179176
self.source = source
180177
if global_names is not None:
@@ -184,11 +181,11 @@ def __init__(
184181
self.default_opset_ = default_opset
185182

186183
# States initialized by `_init_function_translation`
187-
self._outer: List[irbuilder.IRFunction] = []
184+
self._outer: list[irbuilder.IRFunction] = []
188185
self._current_fn: irbuilder.IRFunction = None
189186
self._nextvar: int = 0
190187
self._used_vars: set[str] = set()
191-
self._locals: List[Dict[str, LocalSymValue]] = [{}]
188+
self._locals: list[dict[str, LocalSymValue]] = [{}]
192189
self._analyzer: analysis.AstAnalyzer | None = None
193190
self._castable: set[str] = set()
194191

@@ -225,7 +222,7 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
225222
else:
226223
self.default_opset_ = opset
227224

228-
def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
225+
def _find_onnx_opset(self, node: ast.AST) -> values.Opset | None:
229226
"""Find the (first) ONNX opset used in the function, if any."""
230227
# Search for a Call expression of form "op.OpName(...)"
231228
if isinstance(node, ast.Call):
@@ -245,13 +242,15 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
245242
def _init_function_translation(self) -> None:
246243
"""Initialize self for translating a new (top-level) function."""
247244
self._outer = []
248-
self._current_fn: Optional[irbuilder.IRFunction] = None
245+
self._current_fn: irbuilder.IRFunction | None = None
249246
self._nextvar = 0
250247
self._used_vars = set()
251-
self._locals: List[Dict[str, LocalSymValue]] = [{}]
248+
self._locals: list[dict[str, LocalSymValue]] = [{}]
252249

253250
def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
254-
return sourceinfo.SourceInfo(node, self.source, self._current_fn.name)
251+
return sourceinfo.SourceInfo(
252+
node, code=self.source, function_name=self._current_fn.name
253+
)
255254

256255
def _message(self, node: ast.AST, error_msg: str) -> str:
257256
"""Constructs an error _message containing source information about an ast node."""
@@ -287,7 +286,7 @@ def _exit_scope(self) -> irbuilder.IRFunction:
287286
self._locals.pop(0)
288287
return graph
289288

290-
def _current_scope(self) -> Dict[str, LocalSymValue]:
289+
def _current_scope(self) -> dict[str, LocalSymValue]:
291290
return self._locals[0]
292291

293292
def _bind(self, name: str, val: LocalSymValue) -> None:
@@ -337,7 +336,7 @@ def tensor_name_generator() -> str:
337336
return ir.from_proto(proto)
338337

339338
def _to_onnx_attr_ref(
340-
self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]
339+
self, val: values.AttrRef, info: sourceinfo.SourceInfo | None
341340
) -> ir.Attr:
342341
attrtype = val.value.type
343342
attrname = None
@@ -357,8 +356,8 @@ def _to_onnx_attr_ref(
357356
def _to_onnx_var(
358357
self,
359358
val: values.SymbolValue | PyValue,
360-
target: Optional[PreferredName] = None,
361-
info: Optional[sourceinfo.SourceInfo] = None,
359+
target: PreferredName | None = None,
360+
info: sourceinfo.SourceInfo | None = None,
362361
) -> ir.Value:
363362
if isinstance(val, values.AttrRef):
364363
# promote attribute to value
@@ -397,8 +396,8 @@ def emit(
397396
self,
398397
outputs: Sequence[str],
399398
callee: values.Op | str,
400-
inputs: Sequence[Optional[ir.Value]],
401-
attrs: Optional[Sequence[ir.Attr]] = None,
399+
inputs: Sequence[ir.Value | None],
400+
attrs: Sequence[ir.Attr] | None = None,
402401
) -> Sequence[ir.Value] | ir.Value:
403402
if not isinstance(callee, values.Op):
404403
callee = values.Op(self.default_opset, callee)
@@ -416,6 +415,7 @@ def emit(
416415
if not isinstance(callee, values.Op):
417416
raise TypeError(f"Unexpected type {type(callee)} for callee.")
418417
node.meta.setdefault("callee", callee)
418+
assert self._current_fn is not None
419419
self._current_fn.append_node(node)
420420

421421
return output_values if len(output_values) > 1 else output_values[0]
@@ -429,7 +429,7 @@ def emit1(self, *args, **kwargs) -> ir.Value:
429429
def _emit_const(
430430
self,
431431
pyvalue: PyValue,
432-
suggested_name: Optional[PreferredName],
432+
suggested_name: PreferredName | None,
433433
info: sourceinfo.SourceInfo,
434434
) -> ir.Value:
435435
if suggested_name is None:
@@ -491,10 +491,8 @@ def _eval_constant_expr(self, expr: ast.expr) -> PyValue:
491491
function.)
492492
"""
493493
# TODO: assert (self._is_constant_expr(expr))
494-
# TODO: Refine types
495494
locals: dict[Any, Any] = {}
496-
expression = ast.Expression(expr)
497-
cpl = compile(expression, filename="<ast>", mode="eval")
495+
cpl = compile(ast.Expression(expr), filename="<ast>", mode="eval")
498496
try:
499497
return eval(cpl, self.globals, locals) # pylint: disable=eval-used
500498
except NameError as e:
@@ -506,7 +504,7 @@ def _eval_constant_expr(self, expr: ast.expr) -> PyValue:
506504
)
507505
) from e
508506

509-
def _get_type_annotation(self, annotation: ast.expr) -> Optional[ta.TypeAnnotationValue]:
507+
def _get_type_annotation(self, annotation: ast.expr) -> ta.TypeAnnotationValue | None:
510508
typeinfo = self._eval_constant_expr(annotation)
511509
if not ta.is_valid_type(typeinfo):
512510
self.warn(
@@ -520,8 +518,8 @@ def _translate_attr(
520518
self,
521519
attr_name: str,
522520
expr: ast.AST,
523-
attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None,
524-
) -> Optional[ir.Attr]:
521+
attr_meta: onnx.defs.OpSchema.Attribute | None = None,
522+
) -> ir.Attr | None:
525523
"""Translate an attribute-value specification of the form `attr_name=<expr>`
526524
in a call to an op. expr is an AST. The following cases are supported:
527525
* Expr evaluates to a script-time constant (a python-value) that can be mapped
@@ -597,9 +595,7 @@ def _translate_docstring(self, node: ast.Expr) -> None:
597595
self.fail(node, "Docstring expression must be a constant.")
598596
self._current_fn.doc_string = node.value.value
599597

600-
def _translate_expr(
601-
self, node: ast.AST, target: Optional[PreferredName] = None
602-
) -> ir.Value:
598+
def _translate_expr(self, node: ast.AST, target: PreferredName | None = None) -> ir.Value:
603599
"""Expression-translation generates "IR statements/nodes" that compute the value of
604600
the expression into a target-variable, and returns the variable that is
605601
assigned this value.
@@ -630,7 +626,7 @@ def _translate_expr(
630626
result = self.generate_unique_name(target)
631627
return self.emit1([result], callee, args, attrs)
632628

633-
def _translate_opt_expr(self, node: ast.expr) -> Optional[ir.Value]:
629+
def _translate_opt_expr(self, node: ast.expr) -> ir.Value | None:
634630
"""Translation of an expression where "None" is permitted (eg., for an optional argument).
635631
None is represented as a Constant in Python 3.9+.
636632
"""
@@ -639,7 +635,7 @@ def _translate_opt_expr(self, node: ast.expr) -> Optional[ir.Value]:
639635
return self._translate_expr(node)
640636

641637
def _translate_subscript_expr(
642-
self, node: ast.Subscript, target: Optional[PreferredName]
638+
self, node: ast.Subscript, target: PreferredName | None
643639
) -> ir.Value:
644640
"""List of supported syntaxes is below.
645641
`A` is a tensor or an expression equivalent to a tensor.
@@ -689,7 +685,7 @@ def _translate_subscript_expr(
689685
# TODO: Do this at a graph-scope level.
690686
cached_int_consts: dict[int, ir.Value] = {}
691687

692-
def const_1d(value, name: Optional[str] = None) -> ir.Value:
688+
def const_1d(value, name: str | None = None) -> ir.Value:
693689
nonlocal cached_int_consts
694690
if value not in cached_int_consts:
695691
cached_int_consts[value] = self._emit_const([value], name, info)
@@ -703,8 +699,8 @@ def one_1d() -> ir.Value:
703699
minint = -(1 << 63)
704700

705701
def translate_slice_component(
706-
node_arg, default_value: Optional[int] = None
707-
) -> tuple[ir.Value, Optional[int]]:
702+
node_arg, default_value: int | None = None
703+
) -> tuple[ir.Value, int | None]:
708704
"""Translate optional start/stop/step component of a Slice expression."""
709705
if node_arg is None:
710706
if default_value is None:
@@ -758,9 +754,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value
758754

759755
# As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2),
760756
# known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":")
761-
sliced_indices: List[Tuple[int, ast.expr]] = []
762-
scalar_indices: List[Tuple[int, ast.expr]] = []
763-
non_scalar_indices: List[Tuple[int, ast.expr]] = []
757+
sliced_indices: list[tuple[int, ast.expr]] = []
758+
scalar_indices: list[tuple[int, ast.expr]] = []
759+
non_scalar_indices: list[tuple[int, ast.expr]] = []
764760
for axis, elt in enumerate(indices):
765761
if isinstance(elt, ast.Slice):
766762
# Add to sliced_indices, unless it is "::", which is a no-op.
@@ -872,7 +868,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value
872868

873869
def _translate_call_expr(
874870
self, node: ast.Call
875-
) -> tuple[values.Op, list[Optional[ir.Value]], list[ir.Attr]]:
871+
) -> tuple[values.Op, list[ir.Value | None], list[ir.Attr]]:
876872
"""Translates a call-expression."""
877873
callee = self._translate_callee_expr(node.func)
878874
param_schemas = callee.param_schemas()
@@ -932,14 +928,7 @@ def _translate_unary_op_expr(self, node):
932928
# should intercept this call and replace node
933929
# by node.operand.
934930
# This mechanism does not handle somthing like `(-(-5))`.
935-
if hasattr(node.operand, "value"):
936-
# python 3.8+
937-
val = node.operand.value
938-
else:
939-
raise TypeError(
940-
f"Unable to guess constant value from type {type(node.operand)!r} "
941-
f"and attributes {dir(node.operand)!r}."
942-
)
931+
val = node.operand.value
943932
if op == ast.USub:
944933
cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset)
945934
return self._translate_expr(cst)
@@ -1042,7 +1031,7 @@ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
10421031
return None
10431032
raise ValueError(self._message(node, f"Unsupported statement type '{type(node)!r}'."))
10441033

1045-
def _translate_assign_stmt(self, stmt: Union[ast.Assign, ast.AnnAssign]) -> None:
1034+
def _translate_assign_stmt(self, stmt: ast.Assign | ast.AnnAssign) -> None:
10461035
def assign(lhs: ast.AST, rhs: ast.AST) -> None:
10471036
if isinstance(lhs, ast.Name):
10481037
# Assignments of the form "x = SomeExpression"
@@ -1202,7 +1191,7 @@ def rename(x):
12021191
values.SymbolValue(y, self._source_of(stmt)),
12031192
)
12041193

1205-
def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
1194+
def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12061195
# loop-variable
12071196
if isinstance(loop_stmt, ast.For):
12081197
if not isinstance(loop_stmt.target, ast.Name):

onnxscript/_internal/converter_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ def validate_save(
7777
model = f.to_model_proto(io_types=FLOAT)
7878
if save_text:
7979
with (TEST_OUTPUT_DIR / f"{f.name}.txt").open("w", encoding="utf-8") as fi:
80-
fi.write(onnx.helper.printable_graph(model.graph))
80+
fi.write(onnx.printer.to_text(model.graph))
8181
for fct in model.functions:
8282
fi.write("\n-------------------------\n")
83-
fi.write(onnx.helper.printable_graph(fct))
83+
fi.write(onnx.printer.to_text(fct))
8484
if check_ort and (skip_check_ort is None or f.name not in skip_check_ort):
8585
try:
8686
create_cpu_inference_session(model.SerializeToString())
@@ -92,10 +92,10 @@ def validate_save(
9292
model = onnx.shape_inference.infer_shapes(model)
9393
if save_text:
9494
with open(os.path.join(TEST_OUTPUT_DIR, f"{f.name}.shape.txt"), "w") as fi:
95-
fi.write(onnx.helper.printable_graph(model.graph))
95+
fi.write(onnx.printer.to_text(model.graph))
9696
for fct in model.functions:
97-
f.write("\n-------------------------\n")
98-
f.write(onnx.helper.printable_graph(fct))
97+
fi.write("\n-------------------------\n")
98+
fi.write(onnx.printer.to_text(fct))
9999
try:
100100
onnx.checker.check_model(model)
101101
except onnx.checker.ValidationError as e:
Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
# -------------------------------------------------------------------------
2-
# Copyright (c) Microsoft Corporation. All rights reserved.
1+
# Copyright (c) Microsoft Corporation.
32
# Licensed under the MIT License.
4-
# -------------------------------------------------------------------------
53

64
"""Source code information used for diagnostic messages."""
75

86
from __future__ import annotations
97

108
import ast
11-
from typing import Callable, Optional
9+
from typing import Callable
1210

1311

1412
class SourceInfo:
@@ -17,8 +15,9 @@ class SourceInfo:
1715
def __init__(
1816
self,
1917
ast_node: ast.AST,
20-
code: Optional[str] = None,
21-
function_name: Optional[str] = None,
18+
*,
19+
code: str | None = None,
20+
function_name: str | None = None,
2221
):
2322
self.ast_node = ast_node
2423
self.code = code
@@ -52,8 +51,8 @@ def __str__(self) -> str:
5251
Formatter = Callable[[ast.AST, str], str]
5352

5453

55-
def formatter(source_code: Optional[str]) -> Formatter:
54+
def formatter(source_code: str | None) -> Formatter:
5655
def format(node: ast.AST, message: str) -> str:
57-
return SourceInfo(node, source_code).msg(message)
56+
return SourceInfo(node, code=source_code).msg(message)
5857

5958
return format

onnxscript/_internal/values.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
import onnx_ir as ir
2828
from typing_extensions import ParamSpec
2929

30-
from onnxscript import sourceinfo
31-
from onnxscript._internal import ast_utils, deprecation, irbuilder, type_annotation
30+
from onnxscript._internal import ast_utils, deprecation, irbuilder, sourceinfo, type_annotation
3231
from onnxscript._internal import converter as converter_module
3332
from onnxscript.ir import _schemas
3433
from onnxscript.onnx_types import ONNXType

0 commit comments

Comments
 (0)