@@ -297,6 +297,12 @@ def _bind(self, name: str, val: LocalSymValue) -> None:
297297 def _lookup (
298298 self , name : str , info : sourceinfo .SourceInfo , raise_exception : bool = True
299299 ) -> SymValue :
300+ """Maps a python variable name to the corresponding value used during translation.
301+
302+ Typically, a python variable X will correspond to an ONNX value Y. But other special
303+ cases include: constant values or functions (mapped to Graph attributes), etc.
304+ """
305+
300306 for scope in self ._locals :
301307 if name in scope :
302308 return scope [name ]
@@ -1158,13 +1164,9 @@ def _translate_if_stmt(self, stmt: ast.If) -> None:
11581164 live_defs = list (live_def_set )
11591165 test = self ._translate_expr (stmt .test , "cond" )
11601166 lineno = self ._source_of (stmt ).lineno
1161- thenGraph = self ._translate_block (
1162- stmt .body , f"thenGraph_{ lineno } " , live_defs , parent_stmt = stmt
1163- )
1167+ thenGraph = self ._translate_block (stmt .body , f"thenGraph_{ lineno } " , live_defs )
11641168 thenAttr = self ._make_onnx_attr ("then_branch" , thenGraph )
1165- elseGraph = self ._translate_block (
1166- stmt .orelse , f"elseGraph_{ lineno } " , live_defs , parent_stmt = stmt
1167- )
1169+ elseGraph = self ._translate_block (stmt .orelse , f"elseGraph_{ lineno } " , live_defs )
11681170 elseAttr = self ._make_onnx_attr ("else_branch" , elseGraph )
11691171
11701172 def rename (x ):
@@ -1196,7 +1198,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11961198 if isinstance (loop_stmt , ast .For ):
11971199 if not isinstance (loop_stmt .target , ast .Name ):
11981200 self .fail (loop_stmt , "For loop target must be a single variable." )
1199- p_loop_var = loop_stmt .target .id
1201+ python_loop_var_name = loop_stmt .target .id
12001202 # iter
12011203 iter = loop_stmt .iter
12021204 assert isinstance (iter , ast .Call ), "Loop bound not a call."
@@ -1210,8 +1212,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12101212 self .fail (loop_stmt , "Unsupported loop bound, it should be 'range(?)'." )
12111213 assert not iter .keywords , "Unsupported loop bound."
12121214 o_loop_bound = self ._translate_expr (iter .args [0 ], "loop_bound" )
1213- o_cond_var = ir .Value (name = self .generate_unique_name ("cond_in" )) # TODO(Rama)
1214- i_cond_var = o_cond_var
1215+ onnx_cond_var = ir .Value (name = self .generate_unique_name ("cond_in" )) # TODO(Rama)
1216+ i_cond_var = onnx_cond_var
12151217 cond_while = None
12161218 o_loop_condition = None # No condition for a for loop.
12171219 elif isinstance (loop_stmt , ast .While ):
@@ -1222,11 +1224,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12221224 "Unexpected condition type {type(loop_stmt)!r} for a while loop, "
12231225 "it should be 'while <condition_name>:'." ,
12241226 )
1225- p_loop_var = "infinite_loop"
1227+ python_loop_var_name = "infinite_loop"
12261228 o_loop_bound = None
12271229 i_cond_var = ir .Value (name = test .id ) # TODO(Rama)
12281230 cond_while = ir .Value (name = test .id ) # TODO(Rama)
1229- o_cond_var = None
1231+ onnx_cond_var = None
12301232 o_loop_condition = self ._translate_name_expr (test )
12311233 # we need to go through all the instructions to see
12321234 # which instruction defines the condition test.id
@@ -1246,19 +1248,16 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12461248
12471249 # build loop_body
12481250 self ._enter_scope ("loop_body" , loop_stmt )
1249- o_loop_var = self .generate_unique_name (p_loop_var )
1250- self ._current_fn .append_parameter (
1251- make_value (
1252- o_loop_var ,
1253- onnx_types .INT64 ,
1254- self ._source_of (loop_stmt ),
1255- )
1251+ onnx_loop_var_name = self .generate_unique_name (python_loop_var_name )
1252+ onnx_loop_var = make_value (
1253+ onnx_loop_var_name ,
1254+ onnx_types .INT64 ,
1255+ self ._source_of (loop_stmt ),
12561256 )
1257+ self ._current_fn .append_parameter (onnx_loop_var )
12571258 self ._bind (
1258- p_loop_var ,
1259- values .Dynamic (
1260- ir .Value (name = o_loop_var ), values .DynamicKind .Loop , self ._source_of (loop_stmt )
1261- ),
1259+ python_loop_var_name ,
1260+ values .Dynamic (onnx_loop_var , values .DynamicKind .Loop , self ._source_of (loop_stmt )),
12621261 )
12631262
12641263 self ._current_fn .append_parameter (
@@ -1270,17 +1269,19 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12701269 )
12711270
12721271 for pv in loop_state_vars :
1273- ov = self .generate_unique_name (pv )
1272+ onnx_var_name = self .generate_unique_name (pv )
12741273 # TODO: retrieve the annotation for variable pv is any is specified.
12751274 # typeinfo = self._eval_constant_expr(pv.annotation)
12761275 typeinfo = None
12771276 self ._current_fn .append_parameter (
1278- make_value (ov , typeinfo , self ._source_of (loop_stmt ))
1277+ make_value (onnx_var_name , typeinfo , self ._source_of (loop_stmt ))
12791278 )
12801279 self ._bind (
12811280 pv ,
12821281 values .Dynamic (
1283- ir .Value (name = ov ), values .DynamicKind .Loop , self ._source_of (loop_stmt )
1282+ ir .Value (name = onnx_var_name ),
1283+ values .DynamicKind .Loop ,
1284+ self ._source_of (loop_stmt ),
12841285 ),
12851286 )
12861287
@@ -1313,7 +1314,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13131314 continue
13141315 self ._translate_stmt (s )
13151316
1316- o_cond_out = self .generate_unique_name ("cond_out" )
1317+ onnx_cond_out_name = self .generate_unique_name ("cond_out" )
13171318
13181319 if cond_while is not None :
13191320 # Loop while
@@ -1324,35 +1325,35 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13241325 f"Unable to find condition variable { cond_while .name } in known "
13251326 f"variables { list (current_scope )!r} ." ,
13261327 )
1327- o_cond_var = current_scope [cond_while .name ].value
1328+ onnx_cond_var = current_scope [cond_while .name ].value
13281329
13291330 self .emit (
1330- [o_cond_out ],
1331+ [onnx_cond_out_name ],
13311332 values .Op (self .default_opset , operator_name ),
1332- [condition_name or o_cond_var ],
1333+ [condition_name or onnx_cond_var ],
13331334 [],
13341335 )
13351336
13361337 self ._current_fn .outputs .append (
13371338 make_value (
1338- o_cond_out ,
1339+ onnx_cond_out_name ,
13391340 onnx_types .BOOL ,
13401341 self ._source_of (loop_stmt ),
13411342 )
13421343 )
13431344 for pv in loop_state_vars :
1344- ov = self ._py_var_to_onnx_var (pv , self ._source_of (loop_stmt ))
1345- if ov .name not in self ._current_fn .assigned_names :
1345+ onnx_var = self ._py_var_to_onnx_var (pv , self ._source_of (loop_stmt ))
1346+ if onnx_var .name not in self ._current_fn .assigned_names :
13461347 # When converting the loop-body into a graph, we need to handle
13471348 # identity assignments of the form "x = y" inside the loop body
13481349 # specially if y represents a value computed outside the loop body.
13491350 # In this case, we create a copy of y, treating the statement as
13501351 # shorthand for "x = op.Identity(y)".
1351- ov = self ._emit_copy (ov , pv )
1352+ onnx_var = self ._emit_copy (onnx_var , pv )
13521353 # TODO: retrieve variable type for the annotation if any.
13531354 typeinfo = None
13541355 self ._current_fn .outputs .append (
1355- make_value (ov .name , typeinfo , self ._source_of (loop_stmt ))
1356+ make_value (onnx_var .name , typeinfo , self ._source_of (loop_stmt ))
13561357 )
13571358 body = self ._exit_scope ()
13581359 inputs = [o_loop_bound , o_loop_condition ] + [
@@ -1382,49 +1383,44 @@ def _translate_block(
13821383 stmts : Sequence [ast .stmt ],
13831384 name : str ,
13841385 live_defs : Sequence [str ],
1385- parent_stmt : ast .stmt ,
1386- ):
1387- """Translation of a statement-block to GraphProto attribute."""
1388- info_stmt = stmts [0 ] if len (stmts ) > 0 else parent_stmt
1389- source = self ._source_of (info_stmt )
1386+ ) -> ir .Graph :
1387+ """Translation of a then/else statement-block to an ir.Graph."""
13901388 self ._enter_scope (name , None )
13911389 for s in stmts :
13921390 self ._translate_stmt (s )
1393- for pvar in live_defs :
1394- if pvar in self ._current_scope ():
1395- pv_val = self ._current_scope ()[pvar ]
1396- output = self ._to_onnx_var (pv_val , pvar )
1391+ for python_var in live_defs :
1392+ if python_var in self ._current_scope ():
1393+ python_var_value = self ._current_scope ()[python_var ]
1394+ output = self ._to_onnx_var (python_var_value , python_var )
13971395 if output .name not in self ._current_fn .assigned_names :
1396+ # TODO (Rama): Unclear how this can happen. If python_var is in current_scope,
1397+ # then it should have been assigned a value in the current graph.
1398+ #
13981399 # To return an outer-scope variable, an ONNX Graph has to
13991400 # use an explicit copy via Identity.
1400- output = self ._emit_copy (output , pvar )
1401- self ._current_fn .outputs .append (
1402- make_value (
1403- output .name ,
1404- pv_val .typeinfo ,
1405- source ,
1406- )
1407- )
1401+ output = self ._emit_copy (output , python_var )
1402+ self ._current_fn .outputs .append (output )
14081403 else :
1409- pv_val = None
1404+ python_var_value = None
14101405 for scope in self ._locals : # TODO: skip _current_scope
1411- if pvar in scope :
1412- pv_val = scope [pvar ]
1406+ if python_var in scope :
1407+ python_var_value = scope [python_var ]
14131408 break
1414- if pv_val is None :
1409+ if python_var_value is None :
14151410 self .fail (
14161411 stmts [0 ],
1417- f"ir.Value { pvar } is not assigned a value along a conditional "
1412+ f"ir.Value { python_var } is not assigned a value along a conditional "
14181413 f"branch, known variables: { list (self ._locals )} ." ,
14191414 )
14201415 # introduce a copy
1421- ovar = self ._emit_copy (self ._to_onnx_var (pv_val , pvar ), pvar )
1416+ output = self ._emit_copy (
1417+ self ._to_onnx_var (python_var_value , python_var ), python_var
1418+ )
14221419
14231420 # TODO: retrieve the annotation if any.
1424- typeinfo = None
1425- self ._current_fn .outputs .append (make_value (ovar .name , typeinfo , source ))
1426- graph = self ._exit_scope ()
1427- return graph .graph
1421+ self ._current_fn .outputs .append (output )
1422+ function_ir = self ._exit_scope ()
1423+ return function_ir .graph
14281424
14291425 def _translate_nested_function_def (self , fn : ast .FunctionDef ) -> None :
14301426 """Translate a nested function definition."""
@@ -1465,14 +1461,13 @@ def _translate_function_signature_common(
14651461 self ._current_fn .append_parameter (attr )
14661462 self ._bind (x .arg , values .AttrRef (x .arg , typeinfo , self ._source_of (x )))
14671463 else :
1468- self ._current_fn .append_parameter (
1469- make_value (x .arg , typeinfo , self ._source_of (x ))
1470- )
1464+ onnx_parameter = make_value (x .arg , typeinfo , self ._source_of (x ))
1465+ self ._current_fn .append_parameter (onnx_parameter )
14711466 self ._used_vars .add (x .arg )
14721467 self ._bind (
14731468 x .arg ,
14741469 values .Dynamic (
1475- ir . Value ( name = x . arg ) , values .DynamicKind .Input , self ._source_of (x )
1470+ onnx_parameter , values .DynamicKind .Input , self ._source_of (x )
14761471 ),
14771472 )
14781473 if fn .returns :
0 commit comments