@@ -1206,7 +1206,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12061206 self .fail (loop_stmt , "Unsupported loop bound, it should be 'range(?)'." )
12071207 assert not iter .keywords , "Unsupported loop bound."
12081208 o_loop_bound = self ._translate_expr (iter .args [0 ], "loop_bound" )
1209- onnx_cond_var = ir .Value (name = self .generate_unique_name ("cond_in" )) # TODO(Rama)
1209+ onnx_cond_var = make_value (
1210+ self .generate_unique_name ("cond_in" ),
1211+ onnx_types .BOOL ,
1212+ self ._source_of (loop_stmt ),
1213+ )
12101214 i_cond_var = onnx_cond_var
12111215 cond_while = None
12121216 o_loop_condition = None # No condition for a for loop.
@@ -1220,8 +1224,12 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12201224 )
12211225 python_loop_var_name = "infinite_loop"
12221226 o_loop_bound = None
1223- i_cond_var = ir .Value (name = test .id ) # TODO(Rama)
1224- cond_while = ir .Value (name = test .id ) # TODO(Rama)
1227+ i_cond_var = make_value (
1228+ self .generate_unique_name (test .id ),
1229+ onnx_types .BOOL ,
1230+ self ._source_of (loop_stmt ),
1231+ )
1232+ cond_while = test .id
12251233 onnx_cond_var = None
12261234 o_loop_condition = self ._translate_name_expr (test )
12271235 # we need to go through all the instructions to see
@@ -1254,20 +1262,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
12541262 values .SymbolValue (onnx_loop_var , self ._source_of (loop_stmt )),
12551263 )
12561264
1257- self ._current_fn .append_parameter (
1258- make_value (
1259- i_cond_var .name ,
1260- onnx_types .BOOL ,
1261- self ._source_of (loop_stmt ),
1262- )
1263- )
1265+ self ._current_fn .append_parameter (i_cond_var )
12641266
12651267 for pv in loop_state_vars :
12661268 onnx_var_name = self .generate_unique_name (pv )
1267- # TODO: retrieve the annotation for variable pv is any is specified.
1268- # typeinfo = self._eval_constant_expr(pv.annotation)
1269- typeinfo = None
1270- parameter = make_value (onnx_var_name , typeinfo , self ._source_of (loop_stmt ))
1269+ parameter = make_value (onnx_var_name , None , self ._source_of (loop_stmt ))
12711270 self ._current_fn .append_parameter (parameter )
12721271 self ._bind (
12731272 pv ,
@@ -1306,33 +1305,25 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
13061305 continue
13071306 self ._translate_stmt (s )
13081307
1309- onnx_cond_out_name = self .generate_unique_name ("cond_out" )
1310-
13111308 if cond_while is not None :
13121309 # Loop while
13131310 current_scope = self ._current_scope ()
1314- if cond_while . name not in current_scope :
1311+ if cond_while not in current_scope :
13151312 self .fail (
13161313 loop_stmt ,
1317- f"Unable to find condition variable { cond_while . name } in known "
1314+ f"Unable to find condition variable { cond_while } in known "
13181315 f"variables { list (current_scope )!r} ." ,
13191316 )
1320- onnx_cond_var = current_scope [cond_while . name ].value
1317+ onnx_cond_var = current_scope [cond_while ].value
13211318
1322- self .emit (
1323- [onnx_cond_out_name ],
1319+ cond_out = self .emit1 (
1320+ [self . generate_unique_name ( "cond_out" ) ],
13241321 values .Op (self .default_opset , operator_name ),
13251322 [condition_name or onnx_cond_var ],
13261323 [],
13271324 )
1325+ self ._current_fn .outputs .append (cond_out )
13281326
1329- self ._current_fn .outputs .append (
1330- make_value (
1331- onnx_cond_out_name ,
1332- onnx_types .BOOL ,
1333- self ._source_of (loop_stmt ),
1334- )
1335- )
13361327 for pv in loop_state_vars :
13371328 onnx_var = self ._py_var_to_onnx_var (pv , self ._source_of (loop_stmt ))
13381329 if onnx_var .name not in self ._current_fn .assigned_names :
@@ -1342,11 +1333,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
13421333 # In this case, we create a copy of y, treating the statement as
13431334 # shorthand for "x = op.Identity(y)".
13441335 onnx_var = self ._emit_copy (onnx_var , pv )
1345- # TODO: retrieve variable type for the annotation if any.
1346- typeinfo = None
1347- self ._current_fn .outputs .append (
1348- make_value (onnx_var .name , typeinfo , self ._source_of (loop_stmt ))
1349- )
1336+ self ._current_fn .outputs .append (onnx_var )
13501337 body = self ._exit_scope ()
13511338 inputs = [o_loop_bound , o_loop_condition ] + [
13521339 self ._py_var_to_onnx_var (pv , self ._source_of (loop_stmt )) for pv in loop_state_vars
0 commit comments