Skip to content

Commit ad53660

Browse files
Python: changing how imports are parsed, prefer J.Import (#6730)
* Python: Use J.Import for single imports instead of MultiImport Single `import foo` statements no longer get unnecessarily wrapped in a `MultiImport` node. MultiImport is now only used when there are multiple names (`import foo, bar`). Updated parser, printer, and import recipes (add, remove, change) accordingly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Added assertion --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7941fa7 commit ad53660

6 files changed

Lines changed: 76 additions & 5 deletions

File tree

rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,11 @@ def visit_Try(self, node):
789789
)
790790

791791
def visit_Import(self, node):
792-
# TODO only use `MultiImport` when necessary (requires corresponding changes to printer)
792+
if len(node.names) == 1:
793+
prefix = self.__source_before('import')
794+
imp = self.__convert(node.names[0])
795+
return imp.replace(prefix=prefix, qualid=imp.qualid.replace(prefix=imp.prefix))
796+
793797
prefix = self.__source_before('import')
794798
names_prefix = self.__whitespace()
795799
return py.MultiImport(

rewrite-python/rewrite/src/rewrite/python/add_import.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ def visit_compilation_unit(self, cu: CompilationUnit, p) -> J:
131131
def _import_exists(self, cu: CompilationUnit) -> bool:
132132
"""Check if the import already exists."""
133133
for stmt in cu.statements:
134-
if isinstance(stmt, MultiImport):
134+
if isinstance(stmt, Import) and not isinstance(stmt, MultiImport):
135+
if self.name is None and self._import_name_matches(stmt, self.module, self.alias):
136+
return True
137+
elif isinstance(stmt, MultiImport):
135138
if self._multi_import_matches(stmt):
136139
return True
137140
return False

rewrite-python/rewrite/src/rewrite/python/printer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,7 +1475,14 @@ def visit_import(self, import_: 'j.Import', p: PrintOutputCapture) -> J:
14751475

14761476
self._before_syntax(import_, p)
14771477

1478+
from rewrite.python import tree as _py
1479+
is_standalone = not self.get_cursor().first_enclosing(_py.MultiImport)
1480+
if is_standalone:
1481+
p.append("import")
1482+
14781483
if isinstance(import_.qualid.target, j.Empty):
1484+
if is_standalone:
1485+
self._visit_space(import_.qualid.prefix, p)
14791486
self.visit(import_.qualid.name, p)
14801487
else:
14811488
self.visit(import_.qualid, p)

rewrite-python/rewrite/src/rewrite/python/recipes/change_import.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,13 @@ def visit_compilation_unit(self, cu: CompilationUnit, p: ExecutionContext) -> J:
131131
self.old_alias = None
132132

133133
for stmt in cu.statements:
134-
if isinstance(stmt, MultiImport):
134+
if isinstance(stmt, Import) and not isinstance(stmt, MultiImport):
135+
alias = self._check_for_old_single_import(stmt)
136+
if alias is not None:
137+
self.has_old_import = True
138+
self.old_alias = alias if alias != "" else None
139+
break
140+
elif isinstance(stmt, MultiImport):
135141
alias = self._check_for_old_import(stmt)
136142
if alias is not None:
137143
self.has_old_import = True
@@ -164,6 +170,16 @@ def visit_compilation_unit(self, cu: CompilationUnit, p: ExecutionContext) -> J:
164170

165171
return result
166172

173+
def visit_import(self, imp: Import, p: ExecutionContext) -> Optional[J]:
174+
if not self.has_old_import or old_name:
175+
return imp
176+
if self.cursor.first_enclosing(MultiImport):
177+
return imp
178+
alias = self._check_for_old_single_import(imp)
179+
if alias is None:
180+
return imp
181+
return None
182+
167183
def visit_multi_import(self, multi: MultiImport, p: ExecutionContext) -> Optional[J]: # ty: ignore[invalid-method-override]
168184
if not self.has_old_import:
169185
return multi
@@ -180,6 +196,15 @@ def visit_multi_import(self, multi: MultiImport, p: ExecutionContext) -> Optiona
180196
# import X - remove entire import
181197
return self._remove_module_from_import(multi, old_module)
182198

199+
def _check_for_old_single_import(self, imp: Import) -> Optional[str]:
200+
"""Check if a standalone J.Import matches the old import."""
201+
if old_name:
202+
return None
203+
name = self._get_qualid_name(imp.qualid)
204+
if name == old_module:
205+
return self._get_alias_name(imp) or ""
206+
return None
207+
183208
def _check_for_old_import(self, multi: MultiImport) -> Optional[str]:
184209
"""Check if this MultiImport matches the old import.
185210

rewrite-python/rewrite/src/rewrite/python/remove_import.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@ def _remove_import(self, cu: CompilationUnit) -> CompilationUnit:
135135

136136
for padded in cu.padding.statements:
137137
stmt = padded.element
138+
if isinstance(stmt, Import) and not isinstance(stmt, MultiImport):
139+
result = self._process_single_import(stmt)
140+
if result is None:
141+
removed_prefix = stmt.prefix
142+
changed = True
143+
else:
144+
if removed_prefix is not None:
145+
padded = JRightPadded(
146+
padded.element.replace(prefix=removed_prefix),
147+
padded.after, padded.markers
148+
)
149+
removed_prefix = None
150+
new_padded_stmts.append(padded)
151+
continue
138152
if isinstance(stmt, MultiImport):
139153
result = self._process_multi_import(stmt)
140154
if result is None:
@@ -169,6 +183,15 @@ def _remove_import(self, cu: CompilationUnit) -> CompilationUnit:
169183
return cu.padding.replace(_statements=new_padded_stmts)
170184
return cu
171185

186+
def _process_single_import(self, imp: Import) -> Optional[Import]:
187+
"""Process a standalone J.Import. Return None to remove, or the original."""
188+
if self.name is not None:
189+
return imp
190+
name = self._get_qualid_name(imp.qualid)
191+
if name == self.module:
192+
return None
193+
return imp
194+
172195
def _process_multi_import(self, multi: MultiImport) -> Optional[MultiImport]:
173196
"""Process a MultiImport and return None to remove, modified, or original."""
174197
if self.name is None:

rewrite-python/rewrite/tests/python/all/tree/import_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
from rewrite.java import Import
2+
from rewrite.python import CompilationUnit
13
from rewrite.test import RecipeSpec, python
24

35

6+
def _assert_single_j_import(cu: CompilationUnit) -> None:
7+
assert len(cu.statements) == 1
8+
stmt = cu.statements[0]
9+
assert isinstance(stmt, Import), \
10+
f"Single import should be a J.Import, got {type(stmt).__name__}"
11+
12+
413
# noinspection PyUnresolvedReferences
514
def test_simple():
615
# language=python
7-
RecipeSpec().rewrite_run(python("import io"))
16+
RecipeSpec().rewrite_run(python("import io", after_recipe=_assert_single_j_import))
817

918

1019
# noinspection PyUnresolvedReferences
@@ -22,7 +31,7 @@ def test_unicode_char_normalization():
2231
# noinspection PyUnresolvedReferences
2332
def test_qualified():
2433
# language=python
25-
RecipeSpec().rewrite_run(python("import xml.dom"))
34+
RecipeSpec().rewrite_run(python("import xml.dom", after_recipe=_assert_single_j_import))
2635

2736

2837
# noinspection PyUnresolvedReferences

0 commit comments

Comments
 (0)