Skip to content

Commit 3690885

Browse files
justinchubyCopilot
andauthored
Fix call to ast.Expression (#2758)
Fix #2114 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e80ebe8 commit 3690885

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

onnxscript/_internal/converter.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def emit(
398398
outputs: Sequence[str],
399399
callee: values.Op | str,
400400
inputs: Sequence[Optional[ir.Value]],
401-
attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None,
401+
attrs: Optional[Sequence[ir.Attr]] = None,
402402
) -> Sequence[ir.Value] | ir.Value:
403403
if not isinstance(callee, values.Op):
404404
callee = values.Op(self.default_opset, callee)
@@ -480,7 +480,7 @@ def _is_constant_expr(self, node: ast.AST) -> None:
480480
return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node))
481481
return False
482482

483-
def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
483+
def _eval_constant_expr(self, expr: ast.expr) -> PyValue:
484484
"""Evaluates a sub-expression that is assumed to represent a constant value.
485485
The expression can refer only to global names (inherited from the scope
486486
where the script is evaluated) and cannot refer to local names defined
@@ -493,8 +493,8 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
493493
# TODO: assert (self._is_constant_expr(expr))
494494
# TODO: Refine types
495495
locals: dict[Any, Any] = {}
496-
expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)
497-
cpl = compile(expr, filename="<ast>", mode="eval")
496+
expression = ast.Expression(expr)
497+
cpl = compile(expression, filename="<ast>", mode="eval")
498498
try:
499499
return eval(cpl, self.globals, locals) # pylint: disable=eval-used
500500
except NameError as e:
@@ -506,7 +506,7 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
506506
)
507507
) from e
508508

509-
def _get_type_annotation(self, annotation: ast.Expr) -> Optional[ta.TypeAnnotationValue]:
509+
def _get_type_annotation(self, annotation: ast.expr) -> Optional[ta.TypeAnnotationValue]:
510510
typeinfo = self._eval_constant_expr(annotation)
511511
if not ta.is_valid_type(typeinfo):
512512
self.warn(
@@ -521,7 +521,7 @@ def _translate_attr(
521521
attr_name: str,
522522
expr: ast.AST,
523523
attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None,
524-
) -> Optional[irbuilder.IRAttributeValue]:
524+
) -> Optional[ir.Attr]:
525525
"""Translate an attribute-value specification of the form `attr_name=<expr>`
526526
in a call to an op. expr is an AST. The following cases are supported:
527527
* Expr evaluates to a script-time constant (a python-value) that can be mapped
@@ -593,13 +593,9 @@ def _translate_attr(
593593
return attr
594594

595595
def _translate_docstring(self, node: ast.Expr) -> None:
596-
if hasattr(node.value, "value"):
597-
# python 3.8+
598-
self._current_fn.doc_string = node.value.value
599-
else:
600-
raise TypeError(
601-
f"Unexpected type {type(node)!r} for node. Unsupoorted version of python."
602-
)
596+
if not isinstance(node.value, ast.Constant):
597+
self.fail(node, "Docstring expression must be a constant.")
598+
self._current_fn.doc_string = node.value.value
603599

604600
def _translate_expr(
605601
self, node: ast.AST, target: Optional[PreferredName] = None
@@ -876,7 +872,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value
876872

877873
def _translate_call_expr(
878874
self, node: ast.Call
879-
) -> tuple[values.Op, list[Optional[ir.Value]], list[irbuilder.IRAttributeValue]]:
875+
) -> tuple[values.Op, list[Optional[ir.Value]], list[ir.Attr]]:
880876
"""Translates a call-expression."""
881877
callee = self._translate_callee_expr(node.func)
882878
param_schemas = callee.param_schemas()

0 commit comments

Comments
 (0)