Skip to content

Commit 1080d6a

Browse files
authored
Fix loop body value creation (#2777)
Create and use ir.Value in loop subgraph inputs and bodies to avoid unintentional duplication of ir.Values (which led to invalid IR graphs). --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 7ec5d25 commit 1080d6a

File tree

1 file changed

+20
-33
lines changed

1 file changed

+20
-33
lines changed

onnxscript/_internal/converter.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12061206
self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
12071207
assert not iter.keywords, "Unsupported loop bound."
12081208
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound")
1209-
onnx_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama)
1209+
onnx_cond_var = make_value(
1210+
self.generate_unique_name("cond_in"),
1211+
onnx_types.BOOL,
1212+
self._source_of(loop_stmt),
1213+
)
12101214
i_cond_var = onnx_cond_var
12111215
cond_while = None
12121216
o_loop_condition = None # No condition for a for loop.
@@ -1220,8 +1224,12 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12201224
)
12211225
python_loop_var_name = "infinite_loop"
12221226
o_loop_bound = None
1223-
i_cond_var = ir.Value(name=test.id) # TODO(Rama)
1224-
cond_while = ir.Value(name=test.id) # TODO(Rama)
1227+
i_cond_var = make_value(
1228+
self.generate_unique_name(test.id),
1229+
onnx_types.BOOL,
1230+
self._source_of(loop_stmt),
1231+
)
1232+
cond_while = test.id
12251233
onnx_cond_var = None
12261234
o_loop_condition = self._translate_name_expr(test)
12271235
# we need to go through all the instructions to see
@@ -1254,20 +1262,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12541262
values.SymbolValue(onnx_loop_var, self._source_of(loop_stmt)),
12551263
)
12561264

1257-
self._current_fn.append_parameter(
1258-
make_value(
1259-
i_cond_var.name,
1260-
onnx_types.BOOL,
1261-
self._source_of(loop_stmt),
1262-
)
1263-
)
1265+
self._current_fn.append_parameter(i_cond_var)
12641266

12651267
for pv in loop_state_vars:
12661268
onnx_var_name = self.generate_unique_name(pv)
1267-
# TODO: retrieve the annotation for variable pv is any is specified.
1268-
# typeinfo = self._eval_constant_expr(pv.annotation)
1269-
typeinfo = None
1270-
parameter = make_value(onnx_var_name, typeinfo, self._source_of(loop_stmt))
1269+
parameter = make_value(onnx_var_name, None, self._source_of(loop_stmt))
12711270
self._current_fn.append_parameter(parameter)
12721271
self._bind(
12731272
pv,
@@ -1306,33 +1305,25 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
13061305
continue
13071306
self._translate_stmt(s)
13081307

1309-
onnx_cond_out_name = self.generate_unique_name("cond_out")
1310-
13111308
if cond_while is not None:
13121309
# Loop while
13131310
current_scope = self._current_scope()
1314-
if cond_while.name not in current_scope:
1311+
if cond_while not in current_scope:
13151312
self.fail(
13161313
loop_stmt,
1317-
f"Unable to find condition variable {cond_while.name} in known "
1314+
f"Unable to find condition variable {cond_while} in known "
13181315
f"variables {list(current_scope)!r}.",
13191316
)
1320-
onnx_cond_var = current_scope[cond_while.name].value
1317+
onnx_cond_var = current_scope[cond_while].value
13211318

1322-
self.emit(
1323-
[onnx_cond_out_name],
1319+
cond_out = self.emit1(
1320+
[self.generate_unique_name("cond_out")],
13241321
values.Op(self.default_opset, operator_name),
13251322
[condition_name or onnx_cond_var],
13261323
[],
13271324
)
1325+
self._current_fn.outputs.append(cond_out)
13281326

1329-
self._current_fn.outputs.append(
1330-
make_value(
1331-
onnx_cond_out_name,
1332-
onnx_types.BOOL,
1333-
self._source_of(loop_stmt),
1334-
)
1335-
)
13361327
for pv in loop_state_vars:
13371328
onnx_var = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
13381329
if onnx_var.name not in self._current_fn.assigned_names:
@@ -1342,11 +1333,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
13421333
# In this case, we create a copy of y, treating the statement as
13431334
# shorthand for "x = op.Identity(y)".
13441335
onnx_var = self._emit_copy(onnx_var, pv)
1345-
# TODO: retrieve variable type for the annotation if any.
1346-
typeinfo = None
1347-
self._current_fn.outputs.append(
1348-
make_value(onnx_var.name, typeinfo, self._source_of(loop_stmt))
1349-
)
1336+
self._current_fn.outputs.append(onnx_var)
13501337
body = self._exit_scope()
13511338
inputs = [o_loop_bound, o_loop_condition] + [
13521339
self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars

0 commit comments

Comments
 (0)