Skip to content

Commit 3f5a3c3

Browse files
gramalingamCopilot
andauthored
Trivial cleanup of onnxscript converter (#2839)
* Change names to match suggested style (snake_case etc.) * Add _ to internal methods where missing * Remove a couple of lines of useless code --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 6bfbb4d commit 3f5a3c3

File tree

3 files changed

+109
-116
lines changed

3 files changed

+109
-116
lines changed

onnxscript/_internal/analysis.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
153153
return live
154154

155155
def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
156-
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
156+
def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
157157
for s in reversed(block):
158158
live_out = visit(s, live_out)
159159
return live_out
@@ -167,28 +167,28 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
167167
if isinstance(stmt, ast.If):
168168
constant_cond = self.constant_if_condition(stmt)
169169
if constant_cond is None:
170-
live1 = visitBlock(stmt.body, live_out)
171-
live2 = visitBlock(stmt.orelse, live_out)
170+
live1 = visit_block(stmt.body, live_out)
171+
live2 = visit_block(stmt.orelse, live_out)
172172
return live1 | live2 | _used_vars(stmt.test)
173173
elif constant_cond:
174-
return visitBlock(stmt.body, live_out)
174+
return visit_block(stmt.body, live_out)
175175
else:
176-
return visitBlock(stmt.orelse, live_out)
176+
return visit_block(stmt.orelse, live_out)
177177
if isinstance(stmt, ast.For):
178178
p_loop_var = _get_loop_var(stmt, self._formatter)
179179
prev = None
180180
curr = live_out
181181
while curr != prev:
182182
prev = curr
183-
curr = visitBlock(stmt.body, prev).difference({p_loop_var})
183+
curr = visit_block(stmt.body, prev).difference({p_loop_var})
184184
return curr
185185
if isinstance(stmt, ast.While):
186186
cond_vars = _used_vars(stmt.test)
187187
prev = None
188188
curr = live_out | cond_vars
189189
while curr != prev:
190190
prev = curr
191-
curr = visitBlock(stmt.body, prev) | cond_vars
191+
curr = visit_block(stmt.body, prev) | cond_vars
192192
return curr
193193
if isinstance(stmt, ast.Break):
194194
# The following is sufficient for the current restricted usage, where
@@ -228,7 +228,7 @@ def exposed_uses(self, stmts: Sequence[ast.stmt]) -> set[str]:
228228
(in the first statement). Hence x is included in the exposed_uses.
229229
"""
230230

231-
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
231+
def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
232232
for stmt in reversed(block):
233233
live_out = visit(stmt, live_out)
234234
return live_out
@@ -243,13 +243,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
243243
if isinstance(stmt, ast.If):
244244
constant_cond = self.constant_if_condition(stmt)
245245
if constant_cond is None:
246-
live1 = visitBlock(stmt.body, live_out)
247-
live2 = visitBlock(stmt.orelse, live_out)
246+
live1 = visit_block(stmt.body, live_out)
247+
live2 = visit_block(stmt.orelse, live_out)
248248
return (live1 | live2) | _used_vars(stmt.test)
249249
elif constant_cond:
250-
return visitBlock(stmt.body, live_out)
250+
return visit_block(stmt.body, live_out)
251251
else:
252-
return visitBlock(stmt.orelse, live_out)
252+
return visit_block(stmt.orelse, live_out)
253253
if ast_utils.is_print_call(stmt):
254254
return live_out
255255
if ast_utils.is_doc_string(stmt):
@@ -259,13 +259,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
259259
# for loops that execute at least once.
260260
loop_var_set = {_get_loop_var(stmt, self._formatter)}
261261
used_after_loop = live_out.difference(loop_var_set)
262-
used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set)
262+
used_inside_loop = visit_block(stmt.body, set()).difference(loop_var_set)
263263
used_in_loop_header = _used_vars(stmt.iter)
264264
return used_inside_loop | used_in_loop_header | used_after_loop
265265
if isinstance(stmt, ast.While):
266266
# Analysis assumes loop may execute zero times. Results can be improved
267267
# for loops that execute at least once.
268-
used_inside_loop = visitBlock(stmt.body, set())
268+
used_inside_loop = visit_block(stmt.body, set())
269269
used_in_loop_header = _used_vars(stmt.test)
270270
return used_inside_loop | used_in_loop_header | live_out
271271
if isinstance(stmt, ast.Break):
@@ -281,7 +281,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
281281
self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
282282
)
283283

284-
return visitBlock(stmts, set())
284+
return visit_block(stmts, set())
285285

286286
def outer_scope_variables(self, fun: ast.FunctionDef) -> set[str]:
287287
"""Return the set of outer-scope variables used in a nested function.

onnxscript/_internal/autocast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,15 @@ def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]:
187187
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
188188
castable, while X can serve as the target-type.
189189
"""
190-
return None if x is None or converter_.is_castable(x.name) else x
190+
return None if x is None or converter_._is_castable(x.name) else x # pylint: disable=protected-access
191191

192192
def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
193193
if x is None:
194194
return None
195-
if converter_.is_castable(x.name) and y is not None:
195+
if converter_._is_castable(x.name) and y is not None: # pylint: disable=protected-access
196196
# Polymorphic constant x is cast to the type of y:
197-
x_cast = converter_.generate_unique_name(f"{x.name}_cast")
198-
return converter_.emit1([x_cast], "CastLike", [x, y])
197+
x_cast = converter_._generate_unique_name(f"{x.name}_cast") # pylint: disable=protected-access
198+
return converter_._emit1([x_cast], "CastLike", [x, y]) # pylint: disable=protected-access
199199
return x
200200

201201
return cast_inputs(get_type_info, cast_like, op_signature, args)

0 commit comments

Comments
 (0)