Skip to content

Commit f158123

Browse files
committed
ChangeImport: rewrite qualified references when old_name is specified
When old_name is set, ChangeImport now also handles the `import X` / `X.name()` pattern by rewriting qualified references and managing direct imports. Supports aliased imports (`import X as f` / `f.name()`). Type attribution is updated on rewritten nodes (Identifier._type, MethodInvocation._method_type). Also fixes two pre-existing bugs: - AddImport: insertion point now considers standalone Import nodes - RemoveImport: UsageCollector now skips standalone Import identifiers
1 parent 8b90b6e commit f158123

4 files changed

Lines changed: 252 additions & 26 deletions

File tree

rewrite-python/rewrite/src/rewrite/python/add_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def _add_import(self, cu: CompilationUnit) -> CompilationUnit:
287287
insert_idx = 0
288288
padded_stmts = list(cu.padding.statements)
289289
for i, padded in enumerate(padded_stmts):
290-
if isinstance(padded.element, MultiImport):
290+
if isinstance(padded.element, (Import, MultiImport)):
291291
insert_idx = i + 1
292292
elif insert_idx > 0:
293293
break # Stop after we've passed the import section

rewrite-python/rewrite/src/rewrite/python/recipes/change_import.py

Lines changed: 138 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Recipe to change Python imports from one module/name to another."""
1616

17-
from dataclasses import dataclass, field
17+
from dataclasses import dataclass, field, replace as dc_replace
1818
from typing import Any, Optional
1919

2020
from rewrite import ExecutionContext, Recipe, TreeVisitor
@@ -23,15 +23,26 @@
2323
from rewrite.marketplace import Python
2424
from rewrite.recipe import option
2525
from rewrite.java import J
26-
from rewrite.java.tree import Empty, FieldAccess, Identifier, Import
26+
from rewrite.java.support_types import JavaType
27+
from rewrite.java.tree import Empty, FieldAccess, Identifier, Import, MethodInvocation
2728
from rewrite.python.tree import CompilationUnit, MultiImport
2829
from rewrite.python.visitor import PythonVisitor
2930
from rewrite.python.add_import import AddImportOptions, maybe_add_import
31+
from rewrite.python.remove_import import RemoveImportOptions, maybe_remove_import
3032

3133

3234
_Imports = [*Python, CategoryDescriptor(display_name="Imports")]
3335

3436

37+
def _create_module_type(fqn: str) -> JavaType.Class:
38+
"""Create a JavaType.Class for a module from its fully qualified name."""
39+
class_type = JavaType.Class()
40+
class_type._flags_bit_map = 0
41+
class_type._fully_qualified_name = fqn
42+
class_type._kind = JavaType.FullyQualified.Kind.Class
43+
return class_type
44+
45+
3546
@categorize(_Imports)
3647
@dataclass
3748
class ChangeImport(Recipe):
@@ -124,48 +135,80 @@ def __init__(self):
124135
super().__init__()
125136
self.has_old_import = False
126137
self.old_alias: Optional[str] = None
138+
self.has_direct_module_import = False
139+
self.module_alias: Optional[str] = None
140+
self.rewrote_qualified_refs = False
141+
self._new_module_type: Optional[JavaType.Class] = None
127142

128143
def visit_compilation_unit(self, cu: CompilationUnit, p: ExecutionContext) -> J:
129-
# First pass: check if the old import exists
144+
# Reset state for each compilation unit
130145
self.has_old_import = False
131146
self.old_alias = None
147+
self.has_direct_module_import = False
148+
self.module_alias = None
149+
self.rewrote_qualified_refs = False
132150

151+
# Single pass: detect old imports and direct module imports
133152
for stmt in cu.statements:
134153
if isinstance(stmt, Import) and not isinstance(stmt, MultiImport):
135-
alias = self._check_for_old_single_import(stmt)
136-
if alias is not None:
137-
self.has_old_import = True
138-
self.old_alias = alias if alias != "" else None
139-
break
154+
if not self.has_old_import:
155+
alias = self._check_for_old_single_import(stmt)
156+
if alias is not None:
157+
self.has_old_import = True
158+
self.old_alias = alias if alias != "" else None
159+
if old_name and not self.has_direct_module_import:
160+
name = self._get_qualid_name(stmt.qualid)
161+
if name == old_module:
162+
self.has_direct_module_import = True
163+
self.module_alias = self._get_alias_name(stmt)
140164
elif isinstance(stmt, MultiImport):
141-
alias = self._check_for_old_import(stmt)
142-
if alias is not None:
143-
self.has_old_import = True
144-
self.old_alias = alias if alias != "" else None
145-
break
146-
147-
if not self.has_old_import:
165+
if not self.has_old_import:
166+
alias = self._check_for_old_import(stmt)
167+
if alias is not None:
168+
self.has_old_import = True
169+
self.old_alias = alias if alias != "" else None
170+
if old_name and not self.has_direct_module_import and stmt.from_ is None:
171+
for imp in stmt.names:
172+
name = self._get_qualid_name(imp.qualid)
173+
if name == old_module:
174+
self.has_direct_module_import = True
175+
self.module_alias = self._get_alias_name(imp)
176+
break
177+
178+
if not self.has_old_import and not self.has_direct_module_import:
148179
return cu
149180

150181
# Visit to transform imports
151182
result = super().visit_compilation_unit(cu, p)
152183
if not isinstance(result, CompilationUnit):
153184
return result
154185

155-
# Schedule adding the new import
156-
alias_to_use = new_alias or self.old_alias
157-
if new_name:
186+
# Schedule adding the new import (only for direct import changes)
187+
if self.has_old_import:
188+
alias_to_use = new_alias or self.old_alias
189+
if new_name:
190+
maybe_add_import(self, AddImportOptions(
191+
module=new_module,
192+
name=new_name,
193+
alias=alias_to_use,
194+
only_if_referenced=False
195+
))
196+
else:
197+
maybe_add_import(self, AddImportOptions(
198+
module=new_module,
199+
alias=alias_to_use,
200+
only_if_referenced=False
201+
))
202+
203+
# If we rewrote qualified references, manage the direct import
204+
if self.rewrote_qualified_refs:
158205
maybe_add_import(self, AddImportOptions(
159206
module=new_module,
160-
name=new_name,
161-
alias=alias_to_use,
207+
alias=new_alias,
162208
only_if_referenced=False
163209
))
164-
else:
165-
maybe_add_import(self, AddImportOptions(
166-
module=new_module,
167-
alias=alias_to_use,
168-
only_if_referenced=False
210+
maybe_remove_import(self, RemoveImportOptions(
211+
module=old_module,
169212
))
170213

171214
return result
@@ -196,6 +239,76 @@ def visit_multi_import(self, multi: MultiImport, p: ExecutionContext) -> Optiona
196239
# import X - remove entire import
197240
return self._remove_module_from_import(multi, old_module)
198241

242+
def visit_method_invocation(self, method: MethodInvocation, p: ExecutionContext) -> J:
243+
method = super().visit_method_invocation(method, p)
244+
if not old_name or not self.has_direct_module_import:
245+
return method
246+
if not isinstance(method, MethodInvocation):
247+
return method
248+
if not isinstance(method.select, Identifier):
249+
return method
250+
if not isinstance(method.name, Identifier):
251+
return method
252+
253+
select_name = method.select.simple_name
254+
expected_name = self.module_alias or old_module
255+
if select_name != expected_name:
256+
return method
257+
if method.name.simple_name != old_name:
258+
return method
259+
260+
self.rewrote_qualified_refs = True
261+
new_select_name = new_alias or new_module
262+
new_select = method.select.replace(_simple_name=new_select_name)
263+
# Update type attribution on the select identifier
264+
if method.select.type is not None:
265+
new_select = new_select.replace(_type=self._get_new_module_type())
266+
new_padded_select = method.padding.select.replace(_element=new_select)
267+
result = method.padding.replace(_select=new_padded_select)
268+
if new_name and new_name != old_name:
269+
result = result.replace(_name=result.name.replace(_simple_name=new_name))
270+
# Update method_type declaring type and name
271+
if result.method_type is not None:
272+
result = result.replace(_method_type=dc_replace(
273+
result.method_type,
274+
_declaring_type=self._get_new_module_type(),
275+
_name=new_name or old_name,
276+
))
277+
return result
278+
279+
def visit_field_access(self, field_access: FieldAccess, p: ExecutionContext) -> J:
280+
field_access = super().visit_field_access(field_access, p)
281+
if not old_name or not self.has_direct_module_import:
282+
return field_access
283+
if not isinstance(field_access, FieldAccess):
284+
return field_access
285+
if not isinstance(field_access.target, Identifier):
286+
return field_access
287+
288+
existing_name = field_access.target.simple_name
289+
expected_name = self.module_alias or old_module
290+
if existing_name != expected_name:
291+
return field_access
292+
if field_access.name.simple_name != old_name:
293+
return field_access
294+
295+
self.rewrote_qualified_refs = True
296+
new_target_name = new_alias or new_module
297+
new_target = field_access.target.replace(_simple_name=new_target_name)
298+
# Update type attribution on the target identifier
299+
if field_access.target.type is not None:
300+
new_target = new_target.replace(_type=self._get_new_module_type())
301+
result = field_access.replace(_target=new_target)
302+
if new_name and new_name != old_name:
303+
new_name_ident = result.name.replace(_simple_name=new_name)
304+
result = result.padding.replace(_name=result.padding.name.replace(_element=new_name_ident))
305+
return result
306+
307+
def _get_new_module_type(self) -> JavaType.Class:
308+
if self._new_module_type is None:
309+
self._new_module_type = _create_module_type(new_module)
310+
return self._new_module_type
311+
199312
def _check_for_old_single_import(self, imp: Import) -> Optional[str]:
200313
"""Check if a standalone J.Import matches the old import."""
201314
if old_name:

rewrite-python/rewrite/src/rewrite/python/remove_import.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ def __init__(self):
111111
super().__init__()
112112
self.in_import = False
113113

114+
def visit_import(self, import_: Import, p) -> J:
115+
# Don't collect identifiers from standalone import statements
116+
self.in_import = True
117+
result = super().visit_import(import_, p)
118+
self.in_import = False
119+
return result
120+
114121
def visit_multi_import(self, multi: MultiImport, p) -> J:
115122
# Don't collect identifiers from import statements
116123
self.in_import = True

rewrite-python/rewrite/tests/recipes/test_change_import.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,109 @@ def test_change_removes_one_name_from_multi_import(self):
120120
""",
121121
)
122122
)
123+
124+
def test_change_qualified_method_call(self):
125+
"""Change: import fractions / fractions.gcd() -> import math / math.gcd()"""
126+
spec = RecipeSpec(recipe=ChangeImport(
127+
old_module='fractions',
128+
old_name='gcd',
129+
new_module='math',
130+
))
131+
spec.rewrite_run(
132+
python(
133+
"""
134+
import fractions
135+
result = fractions.gcd(12, 8)
136+
""",
137+
"""
138+
import math
139+
result = math.gcd(12, 8)
140+
""",
141+
)
142+
)
143+
144+
def test_change_aliased_qualified_method_call(self):
145+
"""Change: import fractions as f / f.gcd() -> import math / math.gcd()"""
146+
spec = RecipeSpec(recipe=ChangeImport(
147+
old_module='fractions',
148+
old_name='gcd',
149+
new_module='math',
150+
))
151+
spec.rewrite_run(
152+
python(
153+
"""
154+
import fractions as f
155+
result = f.gcd(12, 8)
156+
""",
157+
"""
158+
import math
159+
result = math.gcd(12, 8)
160+
""",
161+
)
162+
)
163+
164+
def test_change_qualified_ref_keeps_import_when_other_usages(self):
165+
"""import fractions stays when fractions.Fraction is still used."""
166+
spec = RecipeSpec(recipe=ChangeImport(
167+
old_module='fractions',
168+
old_name='gcd',
169+
new_module='math',
170+
))
171+
spec.rewrite_run(
172+
python(
173+
"""
174+
import fractions
175+
result = fractions.gcd(12, 8)
176+
f = fractions.Fraction(1, 3)
177+
""",
178+
"""
179+
import fractions
180+
import math
181+
result = math.gcd(12, 8)
182+
f = fractions.Fraction(1, 3)
183+
""",
184+
)
185+
)
186+
187+
def test_change_qualified_field_access(self):
188+
"""Change: import fractions / fn = fractions.gcd -> import math / fn = math.gcd"""
189+
spec = RecipeSpec(recipe=ChangeImport(
190+
old_module='fractions',
191+
old_name='gcd',
192+
new_module='math',
193+
))
194+
spec.rewrite_run(
195+
python(
196+
"""
197+
import fractions
198+
fn = fractions.gcd
199+
""",
200+
"""
201+
import math
202+
fn = math.gcd
203+
""",
204+
)
205+
)
206+
207+
def test_change_qualified_ref_with_different_new_name(self):
208+
"""Qualified ref rewrite when new_name differs from old_name."""
209+
spec = RecipeSpec(recipe=ChangeImport(
210+
old_module='fractions',
211+
old_name='gcd',
212+
new_module='math',
213+
new_name='greatest_common_divisor',
214+
))
215+
spec.rewrite_run(
216+
python(
217+
"""
218+
import fractions
219+
result = fractions.gcd(12, 8)
220+
fn = fractions.gcd
221+
""",
222+
"""
223+
import math
224+
result = math.greatest_common_divisor(12, 8)
225+
fn = math.greatest_common_divisor
226+
""",
227+
)
228+
)

0 commit comments

Comments
 (0)