@@ -398,7 +398,7 @@ def emit(
398398 outputs : Sequence [str ],
399399 callee : values .Op | str ,
400400 inputs : Sequence [Optional [ir .Value ]],
401- attrs : Optional [Sequence [irbuilder . IRAttributeValue ]] = None ,
401+ attrs : Optional [Sequence [ir . Attr ]] = None ,
402402 ) -> Sequence [ir .Value ] | ir .Value :
403403 if not isinstance (callee , values .Op ):
404404 callee = values .Op (self .default_opset , callee )
@@ -480,7 +480,7 @@ def _is_constant_expr(self, node: ast.AST) -> None:
480480 return all (self ._is_constant_expr (c ) for c in ast .iter_child_nodes (node ))
481481 return False
482482
483- def _eval_constant_expr (self , expr : ast .AST ) -> PyValue :
483+ def _eval_constant_expr (self , expr : ast .expr ) -> PyValue :
484484 """Evaluates a sub-expression that is assumed to represent a constant value.
485485 The expression can refer only to global names (inherited from the scope
486486 where the script is evaluated) and cannot refer to local names defined
@@ -493,8 +493,8 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
493493 # TODO: assert (self._is_constant_expr(expr))
494494 # TODO: Refine types
495495 locals : dict [Any , Any ] = {}
496- expr = ast .Expression (expr , lineno = expr . lineno , col_offset = expr . col_offset )
497- cpl = compile (expr , filename = "<ast>" , mode = "eval" )
496+ expression = ast .Expression (expr )
497+ cpl = compile (expression , filename = "<ast>" , mode = "eval" )
498498 try :
499499 return eval (cpl , self .globals , locals ) # pylint: disable=eval-used
500500 except NameError as e :
@@ -506,7 +506,7 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
506506 )
507507 ) from e
508508
509- def _get_type_annotation (self , annotation : ast .Expr ) -> Optional [ta .TypeAnnotationValue ]:
509+ def _get_type_annotation (self , annotation : ast .expr ) -> Optional [ta .TypeAnnotationValue ]:
510510 typeinfo = self ._eval_constant_expr (annotation )
511511 if not ta .is_valid_type (typeinfo ):
512512 self .warn (
@@ -521,7 +521,7 @@ def _translate_attr(
521521 attr_name : str ,
522522 expr : ast .AST ,
523523 attr_meta : Optional [onnx .defs .OpSchema .Attribute ] = None ,
524- ) -> Optional [irbuilder . IRAttributeValue ]:
524+ ) -> Optional [ir . Attr ]:
525525 """Translate an attribute-value specification of the form `attr_name=<expr>`
526526 in a call to an op. expr is an AST. The following cases are supported:
527527 * Expr evaluates to a script-time constant (a python-value) that can be mapped
@@ -593,13 +593,9 @@ def _translate_attr(
593593 return attr
594594
595595 def _translate_docstring (self , node : ast .Expr ) -> None :
596- if hasattr (node .value , "value" ):
597- # python 3.8+
598- self ._current_fn .doc_string = node .value .value
599- else :
600- raise TypeError (
601- f"Unexpected type { type (node )!r} for node. Unsupoorted version of python."
602- )
596+ if not isinstance (node .value , ast .Constant ):
597+ self .fail (node , "Docstring expression must be a constant." )
598+ self ._current_fn .doc_string = node .value .value
603599
604600 def _translate_expr (
605601 self , node : ast .AST , target : Optional [PreferredName ] = None
@@ -876,7 +872,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value
876872
877873 def _translate_call_expr (
878874 self , node : ast .Call
879- ) -> tuple [values .Op , list [Optional [ir .Value ]], list [irbuilder . IRAttributeValue ]]:
875+ ) -> tuple [values .Op , list [Optional [ir .Value ]], list [ir . Attr ]]:
880876 """Translates a call-expression."""
881877 callee = self ._translate_callee_expr (node .func )
882878 param_schemas = callee .param_schemas ()
0 commit comments