Skip to content

Commit 8c6084c

Browse files
Minor cleanup of onnxscript converter (#2748)
* Add some missing documentation and type annotation. * Cleanup some variable names * Cleanup some ir.Value creation --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 70abdd9 commit 8c6084c

File tree

4 files changed

+84
-71
lines changed

4 files changed

+84
-71
lines changed

onnxscript/_internal/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
144144
error_message = self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
145145
raise ValueError(error_message)
146146

147-
def do_liveness_analysis(self, fun: ast.FunctionDef):
147+
def do_liveness_analysis(self, fun: ast.FunctionDef) -> None:
148148
"""Perform liveness analysis of the given function-ast."""
149149

150150
def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
@@ -212,7 +212,7 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
212212
for s in reversed(fun.body):
213213
live = visit(s, live)
214214

215-
def exposed_uses(self, stmts: Sequence[ast.stmt]):
215+
def exposed_uses(self, stmts: Sequence[ast.stmt]) -> set[str]:
216216
"""Return the set of variables that are used before being defined by given block.
217217
In essence, this identifies the "inputs" to a given code-block.
218218
For example, consider the following code-block:
@@ -284,7 +284,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
284284

285285
return visitBlock(stmts, set())
286286

287-
def outer_scope_variables(self, fun: ast.FunctionDef):
287+
def outer_scope_variables(self, fun: ast.FunctionDef) -> set[str]:
288288
"""Return the set of outer-scope variables used in a nested function.
289289
290290
Args:

onnxscript/_internal/converter.py

Lines changed: 60 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ def _bind(self, name: str, val: LocalSymValue) -> None:
297297
def _lookup(
298298
self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True
299299
) -> SymValue:
300+
"""Maps a python variable name to the corresponding value used during translation.
301+
302+
Typically, a python variable X will correspond to an ONNX value Y. But other special
303+
cases include: constant values or functions (mapped to Graph attributes), etc.
304+
"""
305+
300306
for scope in self._locals:
301307
if name in scope:
302308
return scope[name]
@@ -1158,13 +1164,9 @@ def _translate_if_stmt(self, stmt: ast.If) -> None:
11581164
live_defs = list(live_def_set)
11591165
test = self._translate_expr(stmt.test, "cond")
11601166
lineno = self._source_of(stmt).lineno
1161-
thenGraph = self._translate_block(
1162-
stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt
1163-
)
1167+
thenGraph = self._translate_block(stmt.body, f"thenGraph_{lineno}", live_defs)
11641168
thenAttr = self._make_onnx_attr("then_branch", thenGraph)
1165-
elseGraph = self._translate_block(
1166-
stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt
1167-
)
1169+
elseGraph = self._translate_block(stmt.orelse, f"elseGraph_{lineno}", live_defs)
11681170
elseAttr = self._make_onnx_attr("else_branch", elseGraph)
11691171

11701172
def rename(x):
@@ -1196,7 +1198,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11961198
if isinstance(loop_stmt, ast.For):
11971199
if not isinstance(loop_stmt.target, ast.Name):
11981200
self.fail(loop_stmt, "For loop target must be a single variable.")
1199-
p_loop_var = loop_stmt.target.id
1201+
python_loop_var_name = loop_stmt.target.id
12001202
# iter
12011203
iter = loop_stmt.iter
12021204
assert isinstance(iter, ast.Call), "Loop bound not a call."
@@ -1210,8 +1212,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12101212
self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
12111213
assert not iter.keywords, "Unsupported loop bound."
12121214
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound")
1213-
o_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama)
1214-
i_cond_var = o_cond_var
1215+
onnx_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama)
1216+
i_cond_var = onnx_cond_var
12151217
cond_while = None
12161218
o_loop_condition = None # No condition for a for loop.
12171219
elif isinstance(loop_stmt, ast.While):
@@ -1222,11 +1224,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12221224
"Unexpected condition type {type(loop_stmt)!r} for a while loop, "
12231225
"it should be 'while <condition_name>:'.",
12241226
)
1225-
p_loop_var = "infinite_loop"
1227+
python_loop_var_name = "infinite_loop"
12261228
o_loop_bound = None
12271229
i_cond_var = ir.Value(name=test.id) # TODO(Rama)
12281230
cond_while = ir.Value(name=test.id) # TODO(Rama)
1229-
o_cond_var = None
1231+
onnx_cond_var = None
12301232
o_loop_condition = self._translate_name_expr(test)
12311233
# we need to go through all the instructions to see
12321234
# which instruction defines the condition test.id
@@ -1246,19 +1248,16 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12461248

12471249
# build loop_body
12481250
self._enter_scope("loop_body", loop_stmt)
1249-
o_loop_var = self.generate_unique_name(p_loop_var)
1250-
self._current_fn.append_parameter(
1251-
make_value(
1252-
o_loop_var,
1253-
onnx_types.INT64,
1254-
self._source_of(loop_stmt),
1255-
)
1251+
onnx_loop_var_name = self.generate_unique_name(python_loop_var_name)
1252+
onnx_loop_var = make_value(
1253+
onnx_loop_var_name,
1254+
onnx_types.INT64,
1255+
self._source_of(loop_stmt),
12561256
)
1257+
self._current_fn.append_parameter(onnx_loop_var)
12571258
self._bind(
1258-
p_loop_var,
1259-
values.Dynamic(
1260-
ir.Value(name=o_loop_var), values.DynamicKind.Loop, self._source_of(loop_stmt)
1261-
),
1259+
python_loop_var_name,
1260+
values.Dynamic(onnx_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)),
12621261
)
12631262

12641263
self._current_fn.append_parameter(
@@ -1270,17 +1269,19 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12701269
)
12711270

12721271
for pv in loop_state_vars:
1273-
ov = self.generate_unique_name(pv)
1272+
onnx_var_name = self.generate_unique_name(pv)
12741273
# TODO: retrieve the annotation for variable pv is any is specified.
12751274
# typeinfo = self._eval_constant_expr(pv.annotation)
12761275
typeinfo = None
12771276
self._current_fn.append_parameter(
1278-
make_value(ov, typeinfo, self._source_of(loop_stmt))
1277+
make_value(onnx_var_name, typeinfo, self._source_of(loop_stmt))
12791278
)
12801279
self._bind(
12811280
pv,
12821281
values.Dynamic(
1283-
ir.Value(name=ov), values.DynamicKind.Loop, self._source_of(loop_stmt)
1282+
ir.Value(name=onnx_var_name),
1283+
values.DynamicKind.Loop,
1284+
self._source_of(loop_stmt),
12841285
),
12851286
)
12861287

@@ -1313,7 +1314,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13131314
continue
13141315
self._translate_stmt(s)
13151316

1316-
o_cond_out = self.generate_unique_name("cond_out")
1317+
onnx_cond_out_name = self.generate_unique_name("cond_out")
13171318

13181319
if cond_while is not None:
13191320
# Loop while
@@ -1324,35 +1325,35 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13241325
f"Unable to find condition variable {cond_while.name} in known "
13251326
f"variables {list(current_scope)!r}.",
13261327
)
1327-
o_cond_var = current_scope[cond_while.name].value
1328+
onnx_cond_var = current_scope[cond_while.name].value
13281329

13291330
self.emit(
1330-
[o_cond_out],
1331+
[onnx_cond_out_name],
13311332
values.Op(self.default_opset, operator_name),
1332-
[condition_name or o_cond_var],
1333+
[condition_name or onnx_cond_var],
13331334
[],
13341335
)
13351336

13361337
self._current_fn.outputs.append(
13371338
make_value(
1338-
o_cond_out,
1339+
onnx_cond_out_name,
13391340
onnx_types.BOOL,
13401341
self._source_of(loop_stmt),
13411342
)
13421343
)
13431344
for pv in loop_state_vars:
1344-
ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
1345-
if ov.name not in self._current_fn.assigned_names:
1345+
onnx_var = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
1346+
if onnx_var.name not in self._current_fn.assigned_names:
13461347
# When converting the loop-body into a graph, we need to handle
13471348
# identity assignments of the form "x = y" inside the loop body
13481349
# specially if y represents a value computed outside the loop body.
13491350
# In this case, we create a copy of y, treating the statement as
13501351
# shorthand for "x = op.Identity(y)".
1351-
ov = self._emit_copy(ov, pv)
1352+
onnx_var = self._emit_copy(onnx_var, pv)
13521353
# TODO: retrieve variable type for the annotation if any.
13531354
typeinfo = None
13541355
self._current_fn.outputs.append(
1355-
make_value(ov.name, typeinfo, self._source_of(loop_stmt))
1356+
make_value(onnx_var.name, typeinfo, self._source_of(loop_stmt))
13561357
)
13571358
body = self._exit_scope()
13581359
inputs = [o_loop_bound, o_loop_condition] + [
@@ -1382,49 +1383,44 @@ def _translate_block(
13821383
stmts: Sequence[ast.stmt],
13831384
name: str,
13841385
live_defs: Sequence[str],
1385-
parent_stmt: ast.stmt,
1386-
):
1387-
"""Translation of a statement-block to GraphProto attribute."""
1388-
info_stmt = stmts[0] if len(stmts) > 0 else parent_stmt
1389-
source = self._source_of(info_stmt)
1386+
) -> ir.Graph:
1387+
"""Translation of a then/else statement-block to an ir.Graph."""
13901388
self._enter_scope(name, None)
13911389
for s in stmts:
13921390
self._translate_stmt(s)
1393-
for pvar in live_defs:
1394-
if pvar in self._current_scope():
1395-
pv_val = self._current_scope()[pvar]
1396-
output = self._to_onnx_var(pv_val, pvar)
1391+
for python_var in live_defs:
1392+
if python_var in self._current_scope():
1393+
python_var_value = self._current_scope()[python_var]
1394+
output = self._to_onnx_var(python_var_value, python_var)
13971395
if output.name not in self._current_fn.assigned_names:
1396+
# TODO (Rama): Unclear how this can happen. If python_var is in current_scope,
1397+
# then it should have been assigned a value in the current graph.
1398+
#
13981399
# To return an outer-scope variable, an ONNX Graph has to
13991400
# use an explicit copy via Identity.
1400-
output = self._emit_copy(output, pvar)
1401-
self._current_fn.outputs.append(
1402-
make_value(
1403-
output.name,
1404-
pv_val.typeinfo,
1405-
source,
1406-
)
1407-
)
1401+
output = self._emit_copy(output, python_var)
1402+
self._current_fn.outputs.append(output)
14081403
else:
1409-
pv_val = None
1404+
python_var_value = None
14101405
for scope in self._locals: # TODO: skip _current_scope
1411-
if pvar in scope:
1412-
pv_val = scope[pvar]
1406+
if python_var in scope:
1407+
python_var_value = scope[python_var]
14131408
break
1414-
if pv_val is None:
1409+
if python_var_value is None:
14151410
self.fail(
14161411
stmts[0],
1417-
f"ir.Value {pvar} is not assigned a value along a conditional "
1412+
f"ir.Value {python_var} is not assigned a value along a conditional "
14181413
f"branch, known variables: {list(self._locals)}.",
14191414
)
14201415
# introduce a copy
1421-
ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar), pvar)
1416+
output = self._emit_copy(
1417+
self._to_onnx_var(python_var_value, python_var), python_var
1418+
)
14221419

14231420
# TODO: retrieve the annotation if any.
1424-
typeinfo = None
1425-
self._current_fn.outputs.append(make_value(ovar.name, typeinfo, source))
1426-
graph = self._exit_scope()
1427-
return graph.graph
1421+
self._current_fn.outputs.append(output)
1422+
function_ir = self._exit_scope()
1423+
return function_ir.graph
14281424

14291425
def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
14301426
"""Translate a nested function definition."""
@@ -1465,14 +1461,13 @@ def _translate_function_signature_common(
14651461
self._current_fn.append_parameter(attr)
14661462
self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
14671463
else:
1468-
self._current_fn.append_parameter(
1469-
make_value(x.arg, typeinfo, self._source_of(x))
1470-
)
1464+
onnx_parameter = make_value(x.arg, typeinfo, self._source_of(x))
1465+
self._current_fn.append_parameter(onnx_parameter)
14711466
self._used_vars.add(x.arg)
14721467
self._bind(
14731468
x.arg,
14741469
values.Dynamic(
1475-
ir.Value(name=x.arg), values.DynamicKind.Input, self._source_of(x)
1470+
onnx_parameter, values.DynamicKind.Input, self._source_of(x)
14761471
),
14771472
)
14781473
if fn.returns:

onnxscript/_internal/irbuilder.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,18 @@ def __init__(self, name: str, domain: str = "") -> None:
2424
super().__init__(domain, name, graph=graph, attributes=[])
2525
self.ordered_inputs_and_attrs: list[Union[ir.Value, ir.Attr]] = []
2626

27-
# a dictionary of nested function-definitions
27+
# A dictionary of nested function-definitions: when an onnxscript function outer_f
28+
# is translated, and it contains a nested function inner_f, then the inner function
29+
# is translated and stored here. It will be used in any subsequent concrete execution
30+
# of outer_f. Such nested functions are used in two different ways: it can be converted
31+
# into a GraphProto to be stored as a graph-valued attribute of a node; alternatively,
32+
# in a python-based execution mode, it can be called as a python function. It serves
33+
# to enable a python-based debugging experience for higher-order functions such as Scan
34+
# and SequenceMap.
2835
self.nested_functions: dict[str, IRFunction] = {}
36+
37+
# For nested functions, this dictionary maps outer-scope (python) variable names
38+
# to their corresponding translated values.
2939
self.outer_scope_variables: dict[Any, Any] = {}
3040

3141
@property

onnxscript/_internal/values.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import typing
1414
from enum import IntFlag
1515
from typing import ( # type: ignore[attr-defined]
16+
TYPE_CHECKING,
1617
Any,
1718
Callable,
1819
ClassVar,
@@ -35,6 +36,9 @@
3536
from onnxscript.ir import _schemas
3637
from onnxscript.onnx_types import ONNXType
3738

39+
if TYPE_CHECKING:
40+
from onnxscript._internal.type_annotation import TypeAnnotationValue
41+
3842
_R = TypeVar("_R")
3943
_P = ParamSpec("_P")
4044

@@ -886,9 +890,13 @@ class DynamicKind(IntFlag):
886890

887891
class Dynamic(SymbolValue):
888892
def __init__(
889-
self, onnx_var: ir.Value, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None
893+
self,
894+
onnx_var: ir.Value,
895+
kind: DynamicKind,
896+
info: sourceinfo.SourceInfo,
897+
typeinfo: TypeAnnotationValue | None = None,
890898
) -> None:
891-
"""Initializes Dynamic.
899+
"""Represents an ir.Value with some extra information.
892900
893901
Arguments:
894902
onnx_var: the name of the ONNX variable used to represent this value

0 commit comments

Comments
 (0)