Skip to content

Commit 636926d

Browse files
Add Variable type attribution for Python identifiers (#6772)
* ChangeImport: rename bare references when from-import name changes When `from X import old_name` is rewritten to `from X import new_name`, bare references to `old_name` in the code (e.g., function calls like `clock()`) are now renamed to `new_name` (e.g., `perf_counter()`). This is done via a new `visit_identifier` method in ChangeImportVisitor that renames matching identifiers outside of import statements. An xfail test documents that scope analysis for shadowed locals is not yet implemented. * Add Variable type attribution for Python identifiers Populate field_type with JavaType.Variable on Identifiers for local variables, module-level constants, and class fields (self.x). Uses a single ty hover query per name/attribute to avoid double lookups. - Implement JavaType.Variable as a full dataclass with name, type, owner - Add name_type_info() and attribute_type_info() to PythonTypeMapping - Wire visit_Name and visit_Attribute to use the new combined methods - Add RPC sender/receiver support for Variable serialization - Add tests for variable hover detection, Variable creation, and integration tests with ty for name/attribute type info * ChangeImport: skip renaming local variables that shadow imports Use field_type from Variable type attribution to distinguish local variable references from imported symbol references. Also exclude Unknown hover results from being treated as variables. * Address code review feedback - Fix frozen mismatch: remove frozen=True from Variable .pyi stub - Use constructor kwargs in _make_variable instead of post-mutation - Remove duplicate Variable receiver code in python_receiver.py - Make remove_import.py visit_import/visit_multi_import exception-safe with try/finally - Document limitations: field_type shadow detection needs ty, visit_method_invocation only handles simple Identifier selects, attribute_type_info column offset for multiline expressions - Add test for both from-import and direct import simultaneously * ChangeImport: scope shadow check to function bodies only Module-level bare references to an imported name were incorrectly skipped when type attribution populated field_type. Limit the field_type shadow check to identifiers inside MethodDeclaration scopes so module-level references are always renamed. Also add default initializers to visitor fields, expand inline documentation, and remove redundant truthiness guards in the RPC variable receiver. * Skip shadow detection test when ty CLI is not available
1 parent 3d47629 commit 636926d

10 files changed

Lines changed: 623 additions & 29 deletions

File tree

rewrite-python/rewrite/src/rewrite/java/support_types.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,33 @@ def default_value(self) -> Optional[List[str]]:
298298
def declared_formal_type_names(self) -> Optional[List[str]]:
299299
return self._declared_formal_type_names
300300

301+
@dataclass
301302
class Variable:
302-
pass
303+
_flags_bit_map: int = field(default=0)
304+
_name: str = field(default="")
305+
_owner: Optional[JavaType] = field(default=None)
306+
_type: Optional[JavaType] = field(default=None)
307+
_annotations: Optional[List[JavaType.FullyQualified]] = field(default=None)
308+
309+
@property
310+
def flags_bit_map(self) -> int:
311+
return self._flags_bit_map
312+
313+
@property
314+
def name(self) -> str:
315+
return self._name
316+
317+
@property
318+
def owner(self) -> Optional[JavaType]:
319+
return self._owner
320+
321+
@property
322+
def type(self) -> Optional[JavaType]:
323+
return self._type
324+
325+
@property
326+
def annotations(self) -> Optional[List[JavaType.FullyQualified]]:
327+
return self._annotations
303328

304329
class Array:
305330
pass

rewrite-python/rewrite/src/rewrite/java/support_types.pyi

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ J2 = TypeVar('J2', bound=J)
1313
J3 = TypeVar('J3', bound=J)
1414

1515
from abc import abstractmethod, ABC
16-
from enum import Enum, auto
16+
from enum import Enum
1717
from rewrite import Markers
1818
from rewrite import Tree, SourceFile, TreeVisitor
1919
from rewrite.utils import replace_if_changed
@@ -144,8 +144,24 @@ class JavaType(ABC):
144144
@property
145145
def declared_formal_type_names(self) -> Optional[List[str]]: ...
146146

147+
@dataclass
147148
class Variable:
148-
pass
149+
_flags_bit_map: int = ...
150+
_name: str = ...
151+
_owner: Optional[JavaType] = ...
152+
_type: Optional[JavaType] = ...
153+
_annotations: Optional[List[JavaType.FullyQualified]] = ...
154+
155+
@property
156+
def flags_bit_map(self) -> int: ...
157+
@property
158+
def name(self) -> str: ...
159+
@property
160+
def owner(self) -> Optional[JavaType]: ...
161+
@property
162+
def type(self) -> Optional[JavaType]: ...
163+
@property
164+
def annotations(self) -> Optional[List[JavaType.FullyQualified]]: ...
149165

150166
class Array:
151167
pass

rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import dataclasses
23
import sys
34
import token
45
from argparse import ArgumentError
@@ -1050,13 +1051,24 @@ def visit_TypeIgnore(self, node):
10501051
raise NotImplementedError("Implement visit_TypeIgnore!")
10511052

10521053
def visit_Attribute(self, node):
1054+
prefix = self.__whitespace()
1055+
target = self.__convert(node.value)
1056+
receiver_type = getattr(target, 'type', None) if hasattr(target, 'type') else None
1057+
dot_space = self.__source_before('.')
1058+
name_ident = self.__convert_name(node.attr)
1059+
1060+
expr_type, field_type = self._type_mapping.attribute_type_info(node, receiver_type)
1061+
1062+
if isinstance(name_ident, j.Identifier):
1063+
name_ident = dataclasses.replace(name_ident, _type=expr_type, _field_type=field_type)
1064+
10531065
return j.FieldAccess(
10541066
random_id(),
1055-
self.__whitespace(),
1067+
prefix,
10561068
Markers.EMPTY,
1057-
self.__convert(node.value),
1058-
self.__pad_left(self.__source_before('.'), self.__convert_name(node.attr)),
1059-
self._type_mapping.type(node),
1069+
target,
1070+
self.__pad_left(dot_space, name_ident),
1071+
expr_type,
10601072
)
10611073

10621074
def visit_Del(self, node):
@@ -2548,14 +2560,15 @@ def visit_Module(self, node: ast.Module) -> py.CompilationUnit:
25482560

25492561
def visit_Name(self, node):
25502562
space, actual_name = self.__consume_identifier(node.id)
2563+
expr_type, field_type = self._type_mapping.name_type_info(node)
25512564
return j.Identifier(
25522565
random_id(),
25532566
space,
25542567
Markers.EMPTY,
25552568
[],
25562569
actual_name,
2557-
self._type_mapping.type(node),
2558-
None
2570+
expr_type,
2571+
field_type
25592572
)
25602573

25612574
def visit_NamedExpr(self, node):

rewrite-python/rewrite/src/rewrite/python/recipes/change_import.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from rewrite.recipe import option
2525
from rewrite.java import J
2626
from rewrite.java.support_types import JavaType
27-
from rewrite.java.tree import Empty, FieldAccess, Identifier, Import, MethodInvocation
27+
from rewrite.java.tree import Empty, FieldAccess, Identifier, Import, MethodDeclaration, MethodInvocation
2828
from rewrite.python.tree import CompilationUnit, MultiImport
2929
from rewrite.python.visitor import PythonVisitor
3030
from rewrite.python.add_import import AddImportOptions, maybe_add_import
@@ -35,7 +35,11 @@
3535

3636

3737
def _create_module_type(fqn: str) -> JavaType.Class:
38-
"""Create a JavaType.Class for a module from its fully qualified name."""
38+
"""Create a JavaType.Class for a module from its fully qualified name.
39+
40+
JavaType.Class is not a dataclass, so fields are set directly after
41+
construction. This matches the pattern used elsewhere in the codebase.
42+
"""
3943
class_type = JavaType.Class()
4044
class_type._flags_bit_map = 0
4145
class_type._fully_qualified_name = fqn
@@ -131,12 +135,12 @@ def editor(self) -> TreeVisitor[Any, ExecutionContext]:
131135
new_alias = self.new_alias
132136

133137
class ChangeImportVisitor(PythonVisitor[ExecutionContext]):
134-
has_old_import: bool
135-
old_alias: Optional[str]
136-
has_direct_module_import: bool
137-
module_alias: Optional[str]
138-
rewrote_qualified_refs: bool
139-
new_module_type: Optional[JavaType.Class]
138+
has_old_import: bool = False
139+
old_alias: Optional[str] = None
140+
has_direct_module_import: bool = False
141+
module_alias: Optional[str] = None
142+
rewrote_qualified_refs: bool = False
143+
new_module_type: Optional[JavaType.Class] = None
140144

141145
def visit_compilation_unit(self, cu: CompilationUnit, p: ExecutionContext) -> J:
142146
self.has_old_import = False
@@ -237,18 +241,51 @@ def visit_multi_import(self, multi: MultiImport, p: ExecutionContext) -> Optiona
237241
# import X - remove entire import
238242
return self._remove_module_from_import(multi, old_module)
239243

244+
def visit_identifier(self, ident: Identifier, p: ExecutionContext) -> J:
245+
ident = super().visit_identifier(ident, p)
246+
if not isinstance(ident, Identifier):
247+
return ident
248+
if not old_name or not new_name or not self.has_old_import:
249+
return ident
250+
old_ref_name = self.old_alias or old_name
251+
new_ref_name = new_alias or self.old_alias or new_name
252+
if old_ref_name == new_ref_name:
253+
return ident
254+
if ident.simple_name != old_ref_name:
255+
return ident
256+
# Skip identifiers inside import statements
257+
if self.cursor.first_enclosing(Import):
258+
return ident
259+
# Skip local variables that shadow the imported name.
260+
# Only check field_type inside function scopes — at module level,
261+
# bare references to the imported name always need renaming.
262+
# When ty is unavailable, field_type is None for all identifiers
263+
# and shadowed locals may be incorrectly renamed.
264+
if self.cursor.first_enclosing(MethodDeclaration) is not None:
265+
if ident.field_type is not None:
266+
return ident
267+
return ident.replace(_simple_name=new_ref_name)
268+
240269
def visit_method_invocation(self, method: MethodInvocation, p: ExecutionContext) -> J:
241270
method = super().visit_method_invocation(method, p)
242-
if not old_name or not self.has_direct_module_import:
243-
return method
244271
if not isinstance(method, MethodInvocation):
245272
return method
273+
if not old_name or not self.has_direct_module_import:
274+
return method
275+
# Only matches simple module.func() calls where the select is an
276+
# Identifier. Nested attribute chains like pkg.module.func()
277+
# (where select is a FieldAccess) are not currently handled.
246278
if not isinstance(method.select, Identifier):
247279
return method
248280
if not isinstance(method.name, Identifier):
249281
return method
250282

251283
select_name = method.select.simple_name
284+
# For dotted modules without aliases (e.g. `import os.path`),
285+
# `old_module` is a dotted string like "os.path" which will
286+
# never match a simple Identifier name — but those cases are
287+
# already excluded by the `isinstance(method.select, Identifier)`
288+
# guard above (the select would be a FieldAccess instead).
252289
expected_name = self.module_alias or old_module
253290
if select_name != expected_name:
254291
return method

rewrite-python/rewrite/src/rewrite/python/remove_import.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,18 @@ def __init__(self):
114114
def visit_import(self, import_: Import, p) -> J:
115115
# Don't collect identifiers from standalone import statements
116116
self.in_import = True
117-
result = super().visit_import(import_, p)
118-
self.in_import = False
119-
return result
117+
try:
118+
return super().visit_import(import_, p)
119+
finally:
120+
self.in_import = False
120121

121122
def visit_multi_import(self, multi: MultiImport, p) -> J:
122123
# Don't collect identifiers from import statements
123124
self.in_import = True
124-
result = super().visit_multi_import(multi, p)
125-
self.in_import = False
126-
return result
125+
try:
126+
return super().visit_multi_import(multi, p)
127+
finally:
128+
self.in_import = False
127129

128130
def visit_identifier(self, ident: Identifier, p) -> J:
129131
if not self.in_import:

rewrite-python/rewrite/src/rewrite/python/type_mapping.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,54 @@ def _constant_type(self, node: ast.Constant) -> Optional[JavaType]:
266266
return JavaType.Primitive.None_
267267
return None
268268

269+
def _is_variable_hover(self, hover: str) -> bool:
270+
"""Check if hover text describes a variable (not a function, class, or module).
271+
272+
ty hover patterns:
273+
- Variable: just the type, e.g. ``int``, ``Literal[5]``, ``str``
274+
- Function: ``def func(args) -> return_type``
275+
- Class definition: ``class MyClass``
276+
- Module: ``<module 'name'>``
277+
"""
278+
if not hover:
279+
return False
280+
clean = self._strip_markdown(hover)
281+
if not clean:
282+
return False
283+
if clean.startswith('def ') or clean.startswith('class ') or clean.startswith('<module'):
284+
return False
285+
if clean == 'Unknown':
286+
return False
287+
return True
288+
289+
def _make_variable(self, name: str, var_type: Optional[JavaType],
290+
owner: Optional[JavaType] = None) -> JavaType.Variable:
291+
"""Create a JavaType.Variable instance."""
292+
return JavaType.Variable(_name=name, _type=var_type, _owner=owner)
293+
294+
def name_type_info(self, node: ast.Name) -> Tuple[Optional[JavaType], Optional[JavaType.Variable]]:
295+
"""Get expression type and variable type for a name reference.
296+
297+
Returns (expression_type, variable_field_type) from a single hover query.
298+
"""
299+
if self._ty_client is None:
300+
return None, None
301+
302+
hover = self._ty_client.get_hover(
303+
self._uri,
304+
node.lineno - 1, # LSP uses 0-based lines
305+
node.col_offset
306+
)
307+
if not hover:
308+
return None, None
309+
310+
expr_type = self._parse_hover_type(hover)
311+
312+
if self._is_variable_hover(hover):
313+
return expr_type, self._make_variable(node.id, expr_type)
314+
315+
return expr_type, None
316+
269317
def _name_type(self, node: ast.Name) -> Optional[JavaType]:
270318
"""Get the type for a name reference."""
271319
if self._ty_client is None:
@@ -281,6 +329,40 @@ def _name_type(self, node: ast.Name) -> Optional[JavaType]:
281329
return self._parse_hover_type(hover)
282330
return None
283331

332+
def attribute_type_info(self, node: ast.Attribute,
333+
receiver_type: Optional[JavaType] = None
334+
) -> Tuple[Optional[JavaType], Optional[JavaType.Variable]]:
335+
"""Get expression type and variable type for an attribute access.
336+
337+
Args:
338+
node: The ast.Attribute node.
339+
receiver_type: The type of the receiver (e.g., type of 'self').
340+
341+
Returns (expression_type, variable_field_type).
342+
343+
Note:
344+
The hover column is computed as ``node.col_offset + len(receiver_text) + 1``.
345+
This may be inaccurate for multiline or parenthesized receiver expressions
346+
where ``_get_node_text`` returns a substring that differs from the full
347+
source span.
348+
"""
349+
if self._ty_client is None:
350+
return None, None
351+
352+
hover = self._ty_client.get_hover(
353+
self._uri, node.lineno - 1,
354+
node.col_offset + len(self._get_node_text(node.value)) + 1
355+
)
356+
if not hover:
357+
return None, None
358+
359+
expr_type = self._parse_hover_type(hover)
360+
361+
if self._is_variable_hover(hover):
362+
return expr_type, self._make_variable(node.attr, expr_type, owner=receiver_type)
363+
364+
return expr_type, None
365+
284366
def _attribute_type(self, node: ast.Attribute) -> Optional[JavaType]:
285367
"""Get the type for an attribute access."""
286368
if self._ty_client is None:

0 commit comments

Comments
 (0)