|
14 | 14 |
|
15 | 15 | """Recipe to change Python imports from one module/name to another.""" |
16 | 16 |
|
17 | | -from dataclasses import dataclass, field |
| 17 | +from dataclasses import dataclass, field, replace as dc_replace |
18 | 18 | from typing import Any, Optional |
19 | 19 |
|
20 | 20 | from rewrite import ExecutionContext, Recipe, TreeVisitor |
|
23 | 23 | from rewrite.marketplace import Python |
24 | 24 | from rewrite.recipe import option |
25 | 25 | from rewrite.java import J |
26 | | -from rewrite.java.tree import Empty, FieldAccess, Identifier, Import |
| 26 | +from rewrite.java.support_types import JavaType |
| 27 | +from rewrite.java.tree import Empty, FieldAccess, Identifier, Import, MethodInvocation |
27 | 28 | from rewrite.python.tree import CompilationUnit, MultiImport |
28 | 29 | from rewrite.python.visitor import PythonVisitor |
29 | 30 | from rewrite.python.add_import import AddImportOptions, maybe_add_import |
| 31 | +from rewrite.python.remove_import import RemoveImportOptions, maybe_remove_import |
30 | 32 |
|
31 | 33 |
|
32 | 34 | _Imports = [*Python, CategoryDescriptor(display_name="Imports")] |
33 | 35 |
|
34 | 36 |
|
| 37 | +def _create_module_type(fqn: str) -> JavaType.Class: |
| 38 | + """Create a JavaType.Class for a module from its fully qualified name.""" |
| 39 | + class_type = JavaType.Class() |
| 40 | + class_type._flags_bit_map = 0 |
| 41 | + class_type._fully_qualified_name = fqn |
| 42 | + class_type._kind = JavaType.FullyQualified.Kind.Class |
| 43 | + return class_type |
| 44 | + |
| 45 | + |
35 | 46 | @categorize(_Imports) |
36 | 47 | @dataclass |
37 | 48 | class ChangeImport(Recipe): |
@@ -124,48 +135,80 @@ def __init__(self): |
124 | 135 | super().__init__() |
125 | 136 | self.has_old_import = False |
126 | 137 | self.old_alias: Optional[str] = None |
| 138 | + self.has_direct_module_import = False |
| 139 | + self.module_alias: Optional[str] = None |
| 140 | + self.rewrote_qualified_refs = False |
| 141 | + self._new_module_type: Optional[JavaType.Class] = None |
127 | 142 |
|
128 | 143 | def visit_compilation_unit(self, cu: CompilationUnit, p: ExecutionContext) -> J: |
129 | | - # First pass: check if the old import exists |
| 144 | + # Reset state for each compilation unit |
130 | 145 | self.has_old_import = False |
131 | 146 | self.old_alias = None |
| 147 | + self.has_direct_module_import = False |
| 148 | + self.module_alias = None |
| 149 | + self.rewrote_qualified_refs = False |
132 | 150 |
|
| 151 | + # Single pass: detect old imports and direct module imports |
133 | 152 | for stmt in cu.statements: |
134 | 153 | if isinstance(stmt, Import) and not isinstance(stmt, MultiImport): |
135 | | - alias = self._check_for_old_single_import(stmt) |
136 | | - if alias is not None: |
137 | | - self.has_old_import = True |
138 | | - self.old_alias = alias if alias != "" else None |
139 | | - break |
| 154 | + if not self.has_old_import: |
| 155 | + alias = self._check_for_old_single_import(stmt) |
| 156 | + if alias is not None: |
| 157 | + self.has_old_import = True |
| 158 | + self.old_alias = alias if alias != "" else None |
| 159 | + if old_name and not self.has_direct_module_import: |
| 160 | + name = self._get_qualid_name(stmt.qualid) |
| 161 | + if name == old_module: |
| 162 | + self.has_direct_module_import = True |
| 163 | + self.module_alias = self._get_alias_name(stmt) |
140 | 164 | elif isinstance(stmt, MultiImport): |
141 | | - alias = self._check_for_old_import(stmt) |
142 | | - if alias is not None: |
143 | | - self.has_old_import = True |
144 | | - self.old_alias = alias if alias != "" else None |
145 | | - break |
146 | | - |
147 | | - if not self.has_old_import: |
| 165 | + if not self.has_old_import: |
| 166 | + alias = self._check_for_old_import(stmt) |
| 167 | + if alias is not None: |
| 168 | + self.has_old_import = True |
| 169 | + self.old_alias = alias if alias != "" else None |
| 170 | + if old_name and not self.has_direct_module_import and stmt.from_ is None: |
| 171 | + for imp in stmt.names: |
| 172 | + name = self._get_qualid_name(imp.qualid) |
| 173 | + if name == old_module: |
| 174 | + self.has_direct_module_import = True |
| 175 | + self.module_alias = self._get_alias_name(imp) |
| 176 | + break |
| 177 | + |
| 178 | + if not self.has_old_import and not self.has_direct_module_import: |
148 | 179 | return cu |
149 | 180 |
|
150 | 181 | # Visit to transform imports |
151 | 182 | result = super().visit_compilation_unit(cu, p) |
152 | 183 | if not isinstance(result, CompilationUnit): |
153 | 184 | return result |
154 | 185 |
|
155 | | - # Schedule adding the new import |
156 | | - alias_to_use = new_alias or self.old_alias |
157 | | - if new_name: |
| 186 | + # Schedule adding the new import (only for direct import changes) |
| 187 | + if self.has_old_import: |
| 188 | + alias_to_use = new_alias or self.old_alias |
| 189 | + if new_name: |
| 190 | + maybe_add_import(self, AddImportOptions( |
| 191 | + module=new_module, |
| 192 | + name=new_name, |
| 193 | + alias=alias_to_use, |
| 194 | + only_if_referenced=False |
| 195 | + )) |
| 196 | + else: |
| 197 | + maybe_add_import(self, AddImportOptions( |
| 198 | + module=new_module, |
| 199 | + alias=alias_to_use, |
| 200 | + only_if_referenced=False |
| 201 | + )) |
| 202 | + |
| 203 | + # If we rewrote qualified references, manage the direct import |
| 204 | + if self.rewrote_qualified_refs: |
158 | 205 | maybe_add_import(self, AddImportOptions( |
159 | 206 | module=new_module, |
160 | | - name=new_name, |
161 | | - alias=alias_to_use, |
| 207 | + alias=new_alias, |
162 | 208 | only_if_referenced=False |
163 | 209 | )) |
164 | | - else: |
165 | | - maybe_add_import(self, AddImportOptions( |
166 | | - module=new_module, |
167 | | - alias=alias_to_use, |
168 | | - only_if_referenced=False |
| 210 | + maybe_remove_import(self, RemoveImportOptions( |
| 211 | + module=old_module, |
169 | 212 | )) |
170 | 213 |
|
171 | 214 | return result |
@@ -196,6 +239,76 @@ def visit_multi_import(self, multi: MultiImport, p: ExecutionContext) -> Optiona |
196 | 239 | # import X - remove entire import |
197 | 240 | return self._remove_module_from_import(multi, old_module) |
198 | 241 |
|
| 242 | + def visit_method_invocation(self, method: MethodInvocation, p: ExecutionContext) -> J: |
| 243 | + method = super().visit_method_invocation(method, p) |
| 244 | + if not old_name or not self.has_direct_module_import: |
| 245 | + return method |
| 246 | + if not isinstance(method, MethodInvocation): |
| 247 | + return method |
| 248 | + if not isinstance(method.select, Identifier): |
| 249 | + return method |
| 250 | + if not isinstance(method.name, Identifier): |
| 251 | + return method |
| 252 | + |
| 253 | + select_name = method.select.simple_name |
| 254 | + expected_name = self.module_alias or old_module |
| 255 | + if select_name != expected_name: |
| 256 | + return method |
| 257 | + if method.name.simple_name != old_name: |
| 258 | + return method |
| 259 | + |
| 260 | + self.rewrote_qualified_refs = True |
| 261 | + new_select_name = new_alias or new_module |
| 262 | + new_select = method.select.replace(_simple_name=new_select_name) |
| 263 | + # Update type attribution on the select identifier |
| 264 | + if method.select.type is not None: |
| 265 | + new_select = new_select.replace(_type=self._get_new_module_type()) |
| 266 | + new_padded_select = method.padding.select.replace(_element=new_select) |
| 267 | + result = method.padding.replace(_select=new_padded_select) |
| 268 | + if new_name and new_name != old_name: |
| 269 | + result = result.replace(_name=result.name.replace(_simple_name=new_name)) |
| 270 | + # Update method_type declaring type and name |
| 271 | + if result.method_type is not None: |
| 272 | + result = result.replace(_method_type=dc_replace( |
| 273 | + result.method_type, |
| 274 | + _declaring_type=self._get_new_module_type(), |
| 275 | + _name=new_name or old_name, |
| 276 | + )) |
| 277 | + return result |
| 278 | + |
| 279 | + def visit_field_access(self, field_access: FieldAccess, p: ExecutionContext) -> J: |
| 280 | + field_access = super().visit_field_access(field_access, p) |
| 281 | + if not old_name or not self.has_direct_module_import: |
| 282 | + return field_access |
| 283 | + if not isinstance(field_access, FieldAccess): |
| 284 | + return field_access |
| 285 | + if not isinstance(field_access.target, Identifier): |
| 286 | + return field_access |
| 287 | + |
| 288 | + existing_name = field_access.target.simple_name |
| 289 | + expected_name = self.module_alias or old_module |
| 290 | + if existing_name != expected_name: |
| 291 | + return field_access |
| 292 | + if field_access.name.simple_name != old_name: |
| 293 | + return field_access |
| 294 | + |
| 295 | + self.rewrote_qualified_refs = True |
| 296 | + new_target_name = new_alias or new_module |
| 297 | + new_target = field_access.target.replace(_simple_name=new_target_name) |
| 298 | + # Update type attribution on the target identifier |
| 299 | + if field_access.target.type is not None: |
| 300 | + new_target = new_target.replace(_type=self._get_new_module_type()) |
| 301 | + result = field_access.replace(_target=new_target) |
| 302 | + if new_name and new_name != old_name: |
| 303 | + new_name_ident = result.name.replace(_simple_name=new_name) |
| 304 | + result = result.padding.replace(_name=result.padding.name.replace(_element=new_name_ident)) |
| 305 | + return result |
| 306 | + |
| 307 | + def _get_new_module_type(self) -> JavaType.Class: |
| 308 | + if self._new_module_type is None: |
| 309 | + self._new_module_type = _create_module_type(new_module) |
| 310 | + return self._new_module_type |
| 311 | + |
199 | 312 | def _check_for_old_single_import(self, imp: Import) -> Optional[str]: |
200 | 313 | """Check if a standalone J.Import matches the old import.""" |
201 | 314 | if old_name: |
|
0 commit comments