Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,10 +1208,12 @@ def visit_MatchValue(self, node):
def visit_MatchSequence(self, node):
prefix = self.__whitespace()
end_delim = None
if self.__skip('['):
if self.__at_token('[') and self.__is_own_sequence_delimiter(node, '['):
self.__skip('[')
kind = py.MatchCase.Pattern.Kind.SEQUENCE_LIST
end_delim = ']'
elif self.__skip('('):
elif self.__at_token('(') and self.__is_own_sequence_delimiter(node, '('):
self.__skip('(')
kind = py.MatchCase.Pattern.Kind.SEQUENCE_TUPLE
end_delim = ')'
else:
Expand Down Expand Up @@ -3321,3 +3323,26 @@ def __at_token(self, s: str) -> bool:
if self._token_idx >= len(self._tokens):
return False
return self._tokens[self._token_idx].string == s

def __is_own_sequence_delimiter(self, node, delim: str) -> bool:
    """Decide whether the delimiter at the current token opens *this* MatchSequence.

    A leading ``[`` or ``(`` may belong either to this sequence or to its
    first child pattern (compare ``[c], _`` with ``[c, _]``).  When the
    first child is itself a MatchSequence we disambiguate by peeking one
    token ahead: two consecutive delimiters (as in ``[[a], b]``) mean the
    current one is ours; anything else means it opens the child
    (as in ``[c], _``).

    NOTE: *delim* is currently not consulted; the peek-ahead test applies
    identically to both ``[`` and ``(``.
    """
    import ast as stdlib_ast

    first_child_is_sequence = bool(node.patterns) and isinstance(
        node.patterns[0], stdlib_ast.MatchSequence
    )
    if not first_child_is_sequence:
        # Unambiguous: the delimiter can only belong to this sequence.
        return True

    # Ambiguous case — the delimiter is ours only if another one follows.
    peek_idx = self._token_idx + 1
    if peek_idx >= len(self._tokens):
        return False
    return self._tokens[peek_idx].string in ('[', '(')
52 changes: 35 additions & 17 deletions rewrite-python/rewrite/src/rewrite/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,28 @@ def generate_id() -> str:
return str(uuid4())


def parse_python_file(path: str) -> dict:
def parse_python_file(path: str, relative_to: Optional[str] = None) -> dict:
    """Read the file at *path* and parse it into an LST.

    Delegates to :func:`parse_python_source`; *relative_to*, when given,
    is used there to relativize the stored source path.
    """
    contents = Path(path).read_text(encoding='utf-8')
    return parse_python_source(contents, path, relative_to)


def parse_python_source(source: str, path: str = "<unknown>") -> dict:
def parse_python_source(source: str, path: str = "<unknown>", relative_to: Optional[str] = None) -> dict:
"""Parse Python source code and return its LST.

The parser used depends on the REWRITE_PYTHON_VERSION environment variable:
- "2" or "2.7": Use parso-based Py2ParserVisitor for Python 2 code
- "3" (default): Use ast-based ParserVisitor for Python 3 code
"""
# Compute the source_path that will be stored on the LST
source_path = Path(path)
if relative_to is not None:
try:
source_path = source_path.relative_to(relative_to)
except ValueError:
pass # path is not under relative_to, keep absolute

try:
from rewrite import Markers

Expand Down Expand Up @@ -250,7 +258,7 @@ def parse_python_source(source: str, path: str = "<unknown>") -> dict:
# Convert to OpenRewrite LST
cu = ParserVisitor(source, path).visit(tree)

cu = cu.replace(source_path=Path(path))
cu = cu.replace(source_path=source_path)
cu = cu.replace(markers=Markers.EMPTY)

# Store and return
Expand All @@ -263,14 +271,14 @@ def parse_python_source(source: str, path: str = "<unknown>") -> dict:
except ImportError as e:
logger.error(f"Failed to import parser: {e}")
traceback.print_exc()
return _create_parse_error(path, str(e), source)
return _create_parse_error(str(source_path), str(e), source)
except SyntaxError as e:
logger.error(f"Syntax error parsing {path}: {e}")
return _create_parse_error(path, str(e), source)
return _create_parse_error(str(source_path), str(e), source)
except Exception as e:
logger.error(f"Error parsing {path}: {e}")
traceback.print_exc()
return _create_parse_error(path, str(e), source)
return _create_parse_error(str(source_path), str(e), source)


def _create_parse_error(path: str, message: str, source: str = '') -> dict:
Expand Down Expand Up @@ -313,19 +321,29 @@ def _create_parse_error(path: str, message: str, source: str = '') -> dict:
def handle_parse(params: dict) -> List[str]:
"""Handle a Parse RPC request."""
inputs = params.get('inputs', [])
relative_to = params.get('relativeTo')
logger.info(f"handle_parse: {len(inputs)} inputs, relativeTo={relative_to}")
results = []

for input_item in inputs:
if 'path' in input_item:
# File input
result = parse_python_file(input_item['path'])
for i, input_item in enumerate(inputs):
if isinstance(input_item, str):
# PathInput serialized via @JsonValue as a bare path string
logger.info(f" [{i}] parsing file: {input_item}")
result = parse_python_file(input_item, relative_to)
elif 'path' in input_item:
# File input as dict
logger.info(f" [{i}] parsing file: {input_item['path']}")
result = parse_python_file(input_item['path'], relative_to)
elif 'text' in input_item or 'source' in input_item:
# String input - Java sends 'text' and 'sourcePath'
source = input_item.get('text') or input_item.get('source')
path = input_item.get('sourcePath') or input_item.get('relativePath', '<unknown>')
result = parse_python_source(source, path)
logger.info(f" [{i}] parsing source: {path}")
result = parse_python_source(source, path, relative_to)
else:
logger.warning(f" [{i}] unknown input type: {type(input_item)}")
continue
logger.info(f" [{i}] result: {result}")
results.append(result['id'])

return results
Expand Down Expand Up @@ -387,18 +405,18 @@ def handle_get_object(params: dict) -> List[dict]:
before = remote_objects.get(obj_id)

q = RpcSendQueue(source_file_type)
logger.info(f"handle_get_object: starting generate for {obj_id}")
Comment thread
knutwannheden marked this conversation as resolved.
Outdated
result = q.generate(obj, before)
logger.debug(f"GetObject result: {len(result)} items")
for i, item in enumerate(result[:10]): # Log first 10 items
logger.debug(f" [{i}] {item}")
logger.info(f"handle_get_object: generate complete, {len(result)} items")

# Update remote_objects to track that Java now has this version
remote_objects[obj_id] = obj

return result

except Exception as e:
logger.error(f"Error serializing object: {e}")
except BaseException as e:
source_path = getattr(obj, 'source_path', None)
logger.error(f"Error serializing object {obj_id} (type={type(obj).__name__}, path={source_path}): {e}")
import traceback as tb
tb.print_exc()
return [{'state': 'END_OF_OBJECT'}]
Expand Down
109 changes: 109 additions & 0 deletions rewrite-python/rewrite/tests/python/all/tree/match_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,115 @@ def f(x):
))


def test_match_with_star_wildcard_and_capture():
    """A star wildcard followed by a capture variable round-trips unchanged."""
    # language=python
    snippet = """\
match x:
    case [*_, stmt]:
        pass
"""
    RecipeSpec().rewrite_run(python(snippet))


def test_match_with_star_capture_and_variable():
    """A star capture followed by a plain variable round-trips unchanged."""
    # language=python
    snippet = """\
match x:
    case [*prev, stmt]:
        pass
"""
    RecipeSpec().rewrite_run(python(snippet))


def test_match_with_star_wildcard_expression_not_none():
    """The Star node produced for ``*_`` must carry a non-None expression."""
    import ast as stdlib_ast
    from rewrite.python._parser_visitor import ParserVisitor

    source = """\
match x:
    case [*_, stmt]:
        pass
"""
    lst = ParserVisitor(source, 'test.py').visit(stdlib_ast.parse(source))

    # Walk the LST to find Star nodes
    found = []
    _collect_stars(lst, found)
    assert len(found) > 0, "Expected at least one Star node"
    for star_node in found:
        assert star_node.expression is not None, "Star expression should not be None for *_ pattern"


def test_match_with_nested_star_wildcard_expression_not_none():
    """Star nodes inside nested match-class patterns keep a non-None expression.

    This is the pattern from refurb that triggers the PythonValidator error.
    """
    import ast as stdlib_ast
    from rewrite.python._parser_visitor import ParserVisitor

    code = """\
match node:
    case IfStmt(else_body=Block(body=[*_, stmt])) | WithStmt(body=Block(body=[*_, stmt])):
        pass
    case ForStmt(body=Block(body=[*prev, stmt])) | WhileStmt(body=Block(body=[*prev, stmt])):
        pass
"""
    tree = stdlib_ast.parse(code)
    cu = ParserVisitor(code, 'test.py').visit(tree)

    stars = []
    _collect_stars(cu, stars)
    # Fixed: the original messages carried an `f` prefix with no
    # placeholders (ruff F541) — plain string literals suffice.
    assert len(stars) > 0, "Expected Star nodes, found none"
    for star in stars:
        # A None expression here is exactly the defect PythonValidator reports.
        assert star.expression is not None, "Star expression should not be None"


def _collect_stars(node, result, visited=None):
    """Depth-first walk of an LST, appending every Star node to *result*.

    Cycles are guarded against by tracking visited object ids in *visited*.
    Only dataclass fields are traversed; list entries are followed either
    directly (dataclass items) or through their ``element`` attribute.
    """
    from rewrite.python.tree import Star

    visited = set() if visited is None else visited
    if id(node) in visited:
        return
    visited.add(id(node))

    if isinstance(node, Star):
        result.append(node)

    field_names = getattr(node, '__dataclass_fields__', None)
    if field_names is None:
        return
    for field_name in field_names:
        child = getattr(node, field_name, None)
        # Scalars and None can contain no Star nodes — skip them.
        if child is None or isinstance(child, (str, int, float, bool, bytes)):
            continue
        if hasattr(child, '__dataclass_fields__'):
            _collect_stars(child, result, visited)
        elif isinstance(child, list):
            for entry in child:
                if hasattr(entry, '__dataclass_fields__'):
                    _collect_stars(entry, result, visited)
                elif hasattr(entry, 'element'):
                    _collect_stars(entry.element, result, visited)


def test_match_tuple_with_sequence_pattern():
    """Implicit tuple match whose case starts with a sequence pattern ``[c]``."""
    # language=python
    snippet = """\
match (y, z):
    case _, b:
        pass
    case [c], _:
        pass
"""
    RecipeSpec().rewrite_run(python(snippet))


def test_match_with_or_pattern_in_tuple():
# language=python - OR pattern as first element of implicit tuple
RecipeSpec().rewrite_run(python(
Expand Down