77from typing import (
88 TYPE_CHECKING ,
99 Any ,
10- Dict ,
11- List ,
1210 NoReturn ,
13- Optional ,
1411 Sequence ,
15- Tuple ,
1612 Union ,
1713)
1814
1915import onnx
2016import onnx_ir as ir
2117
2218import onnxscript
23- from onnxscript import onnx_types , sourceinfo
19+ from onnxscript import onnx_types
2420from onnxscript ._internal import (
2521 analysis ,
2622 ast_utils ,
2723 autocast ,
2824 irbuilder ,
2925 param_manipulation ,
26+ sourceinfo ,
3027 values ,
3128)
3229from onnxscript ._internal import (
@@ -171,10 +168,10 @@ class :class:`onnxscript.irbuilder.IRBuilder` is used
171168
172169 def __init__ (
173170 self ,
174- opset : Optional [ values .Opset ] = None ,
175- global_names : Optional [ dict [str , Any ]] = None ,
176- source : Optional [ str ] = None ,
177- default_opset : Optional [ values .Opset ] = None ,
171+ opset : values .Opset | None = None ,
172+ global_names : dict [str , Any ] | None = None ,
173+ source : str | None = None ,
174+ default_opset : values .Opset | None = None ,
178175 ):
179176 self .source = source
180177 if global_names is not None :
@@ -184,11 +181,11 @@ def __init__(
184181 self .default_opset_ = default_opset
185182
186183 # States initialized by `_init_function_translation`
187- self ._outer : List [irbuilder .IRFunction ] = []
184+ self ._outer : list [irbuilder .IRFunction ] = []
188185 self ._current_fn : irbuilder .IRFunction = None
189186 self ._nextvar : int = 0
190187 self ._used_vars : set [str ] = set ()
191- self ._locals : List [ Dict [str , LocalSymValue ]] = [{}]
188+ self ._locals : list [ dict [str , LocalSymValue ]] = [{}]
192189 self ._analyzer : analysis .AstAnalyzer | None = None
193190 self ._castable : set [str ] = set ()
194191
@@ -225,7 +222,7 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
225222 else :
226223 self .default_opset_ = opset
227224
228- def _find_onnx_opset (self , node : ast .AST ) -> Optional [ values .Opset ] :
225+ def _find_onnx_opset (self , node : ast .AST ) -> values .Opset | None :
229226 """Find the (first) ONNX opset used in the function, if any."""
230227 # Search for a Call expression of form "op.OpName(...)"
231228 if isinstance (node , ast .Call ):
@@ -245,13 +242,15 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
245242 def _init_function_translation (self ) -> None :
246243 """Initialize self for translating a new (top-level) function."""
247244 self ._outer = []
248- self ._current_fn : Optional [ irbuilder .IRFunction ] = None
245+ self ._current_fn : irbuilder .IRFunction | None = None
249246 self ._nextvar = 0
250247 self ._used_vars = set ()
251- self ._locals : List [ Dict [str , LocalSymValue ]] = [{}]
248+ self ._locals : list [ dict [str , LocalSymValue ]] = [{}]
252249
253250 def _source_of (self , node : ast .AST ) -> sourceinfo .SourceInfo :
254- return sourceinfo .SourceInfo (node , self .source , self ._current_fn .name )
251+ return sourceinfo .SourceInfo (
252+ node , code = self .source , function_name = self ._current_fn .name
253+ )
255254
256255 def _message (self , node : ast .AST , error_msg : str ) -> str :
257256 """Constructs an error _message containing source information about an ast node."""
@@ -287,7 +286,7 @@ def _exit_scope(self) -> irbuilder.IRFunction:
287286 self ._locals .pop (0 )
288287 return graph
289288
290- def _current_scope (self ) -> Dict [str , LocalSymValue ]:
289+ def _current_scope (self ) -> dict [str , LocalSymValue ]:
291290 return self ._locals [0 ]
292291
293292 def _bind (self , name : str , val : LocalSymValue ) -> None :
@@ -337,7 +336,7 @@ def tensor_name_generator() -> str:
337336 return ir .from_proto (proto )
338337
339338 def _to_onnx_attr_ref (
340- self , val : values .AttrRef , info : Optional [ sourceinfo .SourceInfo ]
339+ self , val : values .AttrRef , info : sourceinfo .SourceInfo | None
341340 ) -> ir .Attr :
342341 attrtype = val .value .type
343342 attrname = None
@@ -357,8 +356,8 @@ def _to_onnx_attr_ref(
357356 def _to_onnx_var (
358357 self ,
359358 val : values .SymbolValue | PyValue ,
360- target : Optional [ PreferredName ] = None ,
361- info : Optional [ sourceinfo .SourceInfo ] = None ,
359+ target : PreferredName | None = None ,
360+ info : sourceinfo .SourceInfo | None = None ,
362361 ) -> ir .Value :
363362 if isinstance (val , values .AttrRef ):
364363 # promote attribute to value
@@ -397,8 +396,8 @@ def emit(
397396 self ,
398397 outputs : Sequence [str ],
399398 callee : values .Op | str ,
400- inputs : Sequence [Optional [ ir .Value ] ],
401- attrs : Optional [ Sequence [ir .Attr ]] = None ,
399+ inputs : Sequence [ir .Value | None ],
400+ attrs : Sequence [ir .Attr ] | None = None ,
402401 ) -> Sequence [ir .Value ] | ir .Value :
403402 if not isinstance (callee , values .Op ):
404403 callee = values .Op (self .default_opset , callee )
@@ -416,6 +415,7 @@ def emit(
416415 if not isinstance (callee , values .Op ):
417416 raise TypeError (f"Unexpected type { type (callee )} for callee." )
418417 node .meta .setdefault ("callee" , callee )
418+ assert self ._current_fn is not None
419419 self ._current_fn .append_node (node )
420420
421421 return output_values if len (output_values ) > 1 else output_values [0 ]
@@ -429,7 +429,7 @@ def emit1(self, *args, **kwargs) -> ir.Value:
429429 def _emit_const (
430430 self ,
431431 pyvalue : PyValue ,
432- suggested_name : Optional [ PreferredName ] ,
432+ suggested_name : PreferredName | None ,
433433 info : sourceinfo .SourceInfo ,
434434 ) -> ir .Value :
435435 if suggested_name is None :
@@ -491,10 +491,8 @@ def _eval_constant_expr(self, expr: ast.expr) -> PyValue:
491491 function.)
492492 """
493493 # TODO: assert (self._is_constant_expr(expr))
494- # TODO: Refine types
495494 locals : dict [Any , Any ] = {}
496- expression = ast .Expression (expr )
497- cpl = compile (expression , filename = "<ast>" , mode = "eval" )
495+ cpl = compile (ast .Expression (expr ), filename = "<ast>" , mode = "eval" )
498496 try :
499497 return eval (cpl , self .globals , locals ) # pylint: disable=eval-used
500498 except NameError as e :
@@ -506,7 +504,7 @@ def _eval_constant_expr(self, expr: ast.expr) -> PyValue:
506504 )
507505 ) from e
508506
509- def _get_type_annotation (self , annotation : ast .expr ) -> Optional [ ta .TypeAnnotationValue ] :
507+ def _get_type_annotation (self , annotation : ast .expr ) -> ta .TypeAnnotationValue | None :
510508 typeinfo = self ._eval_constant_expr (annotation )
511509 if not ta .is_valid_type (typeinfo ):
512510 self .warn (
@@ -520,8 +518,8 @@ def _translate_attr(
520518 self ,
521519 attr_name : str ,
522520 expr : ast .AST ,
523- attr_meta : Optional [ onnx .defs .OpSchema .Attribute ] = None ,
524- ) -> Optional [ ir .Attr ] :
521+ attr_meta : onnx .defs .OpSchema .Attribute | None = None ,
522+ ) -> ir .Attr | None :
525523 """Translate an attribute-value specification of the form `attr_name=<expr>`
526524 in a call to an op. expr is an AST. The following cases are supported:
527525 * Expr evaluates to a script-time constant (a python-value) that can be mapped
@@ -597,9 +595,7 @@ def _translate_docstring(self, node: ast.Expr) -> None:
597595 self .fail (node , "Docstring expression must be a constant." )
598596 self ._current_fn .doc_string = node .value .value
599597
600- def _translate_expr (
601- self , node : ast .AST , target : Optional [PreferredName ] = None
602- ) -> ir .Value :
598+ def _translate_expr (self , node : ast .AST , target : PreferredName | None = None ) -> ir .Value :
603599 """Expression-translation generates "IR statements/nodes" that compute the value of
604600 the expression into a target-variable, and returns the variable that is
605601 assigned this value.
@@ -630,7 +626,7 @@ def _translate_expr(
630626 result = self .generate_unique_name (target )
631627 return self .emit1 ([result ], callee , args , attrs )
632628
633- def _translate_opt_expr (self , node : ast .expr ) -> Optional [ ir .Value ] :
629+ def _translate_opt_expr (self , node : ast .expr ) -> ir .Value | None :
634630 """Translation of an expression where "None" is permitted (eg., for an optional argument).
635631 None is represented as a Constant in Python 3.9+.
636632 """
@@ -639,7 +635,7 @@ def _translate_opt_expr(self, node: ast.expr) -> Optional[ir.Value]:
639635 return self ._translate_expr (node )
640636
641637 def _translate_subscript_expr (
642- self , node : ast .Subscript , target : Optional [ PreferredName ]
638+ self , node : ast .Subscript , target : PreferredName | None
643639 ) -> ir .Value :
644640 """List of supported syntaxes is below.
645641 `A` is a tensor or an expression equivalent to a tensor.
@@ -689,7 +685,7 @@ def _translate_subscript_expr(
689685 # TODO: Do this at a graph-scope level.
690686 cached_int_consts : dict [int , ir .Value ] = {}
691687
692- def const_1d (value , name : Optional [ str ] = None ) -> ir .Value :
688+ def const_1d (value , name : str | None = None ) -> ir .Value :
693689 nonlocal cached_int_consts
694690 if value not in cached_int_consts :
695691 cached_int_consts [value ] = self ._emit_const ([value ], name , info )
@@ -703,8 +699,8 @@ def one_1d() -> ir.Value:
703699 minint = - (1 << 63 )
704700
705701 def translate_slice_component (
706- node_arg , default_value : Optional [ int ] = None
707- ) -> tuple [ir .Value , Optional [ int ] ]:
702+ node_arg , default_value : int | None = None
703+ ) -> tuple [ir .Value , int | None ]:
708704 """Translate optional start/stop/step component of a Slice expression."""
709705 if node_arg is None :
710706 if default_value is None :
@@ -758,9 +754,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value
758754
759755 # As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2),
760756 # known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":")
761- sliced_indices : List [ Tuple [int , ast .expr ]] = []
762- scalar_indices : List [ Tuple [int , ast .expr ]] = []
763- non_scalar_indices : List [ Tuple [int , ast .expr ]] = []
757+ sliced_indices : list [ tuple [int , ast .expr ]] = []
758+ scalar_indices : list [ tuple [int , ast .expr ]] = []
759+ non_scalar_indices : list [ tuple [int , ast .expr ]] = []
764760 for axis , elt in enumerate (indices ):
765761 if isinstance (elt , ast .Slice ):
766762 # Add to sliced_indices, unless it is "::", which is a no-op.
@@ -872,7 +868,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value
872868
873869 def _translate_call_expr (
874870 self , node : ast .Call
875- ) -> tuple [values .Op , list [Optional [ ir .Value ] ], list [ir .Attr ]]:
871+ ) -> tuple [values .Op , list [ir .Value | None ], list [ir .Attr ]]:
876872 """Translates a call-expression."""
877873 callee = self ._translate_callee_expr (node .func )
878874 param_schemas = callee .param_schemas ()
@@ -932,14 +928,7 @@ def _translate_unary_op_expr(self, node):
932928 # should intercept this call and replace node
933929 # by node.operand.
934930 # This mechanism does not handle somthing like `(-(-5))`.
935- if hasattr (node .operand , "value" ):
936- # python 3.8+
937- val = node .operand .value
938- else :
939- raise TypeError (
940- f"Unable to guess constant value from type { type (node .operand )!r} "
941- f"and attributes { dir (node .operand )!r} ."
942- )
931+ val = node .operand .value
943932 if op == ast .USub :
944933 cst = ast .Constant (- val , lineno = node .lineno , col_offset = node .col_offset )
945934 return self ._translate_expr (cst )
@@ -1042,7 +1031,7 @@ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
10421031 return None
10431032 raise ValueError (self ._message (node , f"Unsupported statement type '{ type (node )!r} '." ))
10441033
1045- def _translate_assign_stmt (self , stmt : Union [ ast .Assign , ast .AnnAssign ] ) -> None :
1034+ def _translate_assign_stmt (self , stmt : ast .Assign | ast .AnnAssign ) -> None :
10461035 def assign (lhs : ast .AST , rhs : ast .AST ) -> None :
10471036 if isinstance (lhs , ast .Name ):
10481037 # Assignments of the form "x = SomeExpression"
@@ -1202,7 +1191,7 @@ def rename(x):
12021191 values .SymbolValue (y , self ._source_of (stmt )),
12031192 )
12041193
1205- def _translate_loop_stmt (self , loop_stmt : Union [ ast .For , ast .While ] ) -> None :
1194+ def _translate_loop_stmt (self , loop_stmt : ast .For | ast .While ) -> None :
12061195 # loop-variable
12071196 if isinstance (loop_stmt , ast .For ):
12081197 if not isinstance (loop_stmt .target , ast .Name ):
0 commit comments