diff --git a/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py b/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py index 8c71a4793ca..f03d4f590c1 100644 --- a/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py +++ b/rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py @@ -1208,10 +1208,12 @@ def visit_MatchValue(self, node): def visit_MatchSequence(self, node): prefix = self.__whitespace() end_delim = None - if self.__skip('['): + if self.__at_token('[') and self.__is_own_sequence_delimiter(node, '['): + self.__skip('[') kind = py.MatchCase.Pattern.Kind.SEQUENCE_LIST end_delim = ']' - elif self.__skip('('): + elif self.__at_token('(') and self.__is_own_sequence_delimiter(node, '('): + self.__skip('(') kind = py.MatchCase.Pattern.Kind.SEQUENCE_TUPLE end_delim = ')' else: @@ -3321,3 +3323,26 @@ def __at_token(self, s: str) -> bool: if self._token_idx >= len(self._tokens): return False return self._tokens[self._token_idx].string == s + + def __is_own_sequence_delimiter(self, node, delim: str) -> bool: + """Check if the delimiter at the current token belongs to this MatchSequence. + + When the current token is '[' (or '('), it could belong to this + sequence or to its first child (e.g., ``[c], _`` vs ``[c, _]``). + + If the first child pattern is itself a MatchSequence, the delimiter + might belong to the child. We disambiguate by peeking at the next + token: if it is also a delimiter (``[`` or ``(``), the current one + opens this sequence (e.g., ``[[a], b]``); otherwise the current + delimiter belongs to the child (e.g., ``[c], _``). + """ + import ast as stdlib_ast + if node.patterns and isinstance(node.patterns[0], stdlib_ast.MatchSequence): + # The first child is also a sequence — check whether there are + # two consecutive delimiters, meaning the outer one is ours. + next_idx = self._token_idx + 1 + if next_idx < len(self._tokens): + next_tok = self._tokens[next_idx].string + return next_tok in ('[', '(') + return False + return True diff --git a/rewrite-python/rewrite/src/rewrite/rpc/server.py b/rewrite-python/rewrite/src/rewrite/rpc/server.py index 277154e0e57..25f47bedace 100644 --- a/rewrite-python/rewrite/src/rewrite/rpc/server.py +++ b/rewrite-python/rewrite/src/rewrite/rpc/server.py @@ -209,20 +209,28 @@ def generate_id() -> str: return str(uuid4()) -def parse_python_file(path: str) -> dict: +def parse_python_file(path: str, relative_to: Optional[str] = None) -> dict: """Parse a Python file and return its LST.""" with open(path, 'r', encoding='utf-8') as f: source = f.read() - return parse_python_source(source, path) + return parse_python_source(source, path, relative_to) -def parse_python_source(source: str, path: str = "") -> dict: +def parse_python_source(source: str, path: str = "", relative_to: Optional[str] = None) -> dict: """Parse Python source code and return its LST. The parser used depends on the REWRITE_PYTHON_VERSION environment variable: - "2" or "2.7": Use parso-based Py2ParserVisitor for Python 2 code - "3" (default): Use ast-based ParserVisitor for Python 3 code """ + # Compute the source_path that will be stored on the LST + source_path = Path(path) + if relative_to is not None: + try: + source_path = source_path.relative_to(relative_to) + except ValueError: + pass # path is not under relative_to, keep absolute + try: from rewrite import Markers @@ -250,7 +258,7 @@ def parse_python_source(source: str, path: str = "") -> dict: # Convert to OpenRewrite LST cu = ParserVisitor(source, path).visit(tree) - cu = cu.replace(source_path=Path(path)) + cu = cu.replace(source_path=source_path) cu = cu.replace(markers=Markers.EMPTY) # Store and return @@ -263,14 +271,14 @@ def parse_python_source(source: str, path: str = "") -> dict: except ImportError as e: logger.error(f"Failed to import parser: {e}") traceback.print_exc() - return _create_parse_error(path, str(e), source) + return _create_parse_error(str(source_path), str(e), source) except SyntaxError as e: logger.error(f"Syntax error parsing {path}: {e}") - return _create_parse_error(path, str(e), source) + return _create_parse_error(str(source_path), str(e), source) except Exception as e: logger.error(f"Error parsing {path}: {e}") traceback.print_exc() - return _create_parse_error(path, str(e), source) + return _create_parse_error(str(source_path), str(e), source) def _create_parse_error(path: str, message: str, source: str = '') -> dict: @@ -313,18 +321,23 @@ def _create_parse_error(path: str, message: str, source: str = '') -> dict: def handle_parse(params: dict) -> List[str]: """Handle a Parse RPC request.""" inputs = params.get('inputs', []) + relative_to = params.get('relativeTo') results = [] - for input_item in inputs: - if 'path' in input_item: - # File input - result = parse_python_file(input_item['path']) + for i, input_item in enumerate(inputs): + if isinstance(input_item, str): + # PathInput serialized via @JsonValue as a bare path string + result = parse_python_file(input_item, relative_to) + elif 'path' in input_item: + # File input as dict + result = parse_python_file(input_item['path'], relative_to) elif 'text' in input_item or 'source' in input_item: # String input - Java sends 'text' and 'sourcePath' source = input_item.get('text') or input_item.get('source') path = input_item.get('sourcePath') or input_item.get('relativePath', '') - result = parse_python_source(source, path) + result = parse_python_source(source, path, relative_to) else: + logger.warning(f" [{i}] unknown input type: {type(input_item)}") continue results.append(result['id']) @@ -388,17 +401,15 @@ def handle_get_object(params: dict) -> List[dict]: q = RpcSendQueue(source_file_type) result = q.generate(obj, before) - logger.debug(f"GetObject result: {len(result)} items") - for i, item in enumerate(result[:10]): # Log first 10 items - logger.debug(f" [{i}] {item}") # Update remote_objects to track that Java now has this version remote_objects[obj_id] = obj return result - except Exception as e: - logger.error(f"Error serializing object: {e}") + except BaseException as e: + source_path = getattr(obj, 'source_path', None) + logger.error(f"Error serializing object {obj_id} (type={type(obj).__name__}, path={source_path}): {e}") import traceback as tb tb.print_exc() return [{'state': 'END_OF_OBJECT'}] diff --git a/rewrite-python/rewrite/tests/python/all/tree/match_test.py b/rewrite-python/rewrite/tests/python/all/tree/match_test.py index 9d63b1b0911..91a8237235e 100644 --- a/rewrite-python/rewrite/tests/python/all/tree/match_test.py +++ b/rewrite-python/rewrite/tests/python/all/tree/match_test.py @@ -117,6 +117,115 @@ def f(x): )) +def test_match_with_star_wildcard_and_capture(): + # language=python - star wildcard followed by a capture variable + RecipeSpec().rewrite_run(python( + """\ +match x: + case [*_, stmt]: + pass +""" + )) + + +def test_match_with_star_capture_and_variable(): + # language=python - star capture followed by a variable + RecipeSpec().rewrite_run(python( + """\ +match x: + case [*prev, stmt]: + pass +""" + )) + + +def test_match_with_star_wildcard_expression_not_none(): + # Verify that the Star node for *_ has a non-None expression + import ast as stdlib_ast + from rewrite.python._parser_visitor import ParserVisitor + + code = """\ +match x: + case [*_, stmt]: + pass +""" + tree = stdlib_ast.parse(code) + cu = ParserVisitor(code, 'test.py').visit(tree) + + # Walk the LST to find Star nodes + stars = [] + _collect_stars(cu, stars) + assert len(stars) > 0, "Expected at least one Star node" + for star in stars: + assert star.expression is not None, "Star expression should not be None for *_ pattern" + + +def test_match_with_nested_star_wildcard_expression_not_none(): + # Verify that Star nodes in nested match-class patterns have non-None expression + # This is the pattern from refurb that triggers the PythonValidator error + import ast as stdlib_ast + from rewrite.python._parser_visitor import ParserVisitor + + code = """\ +match node: + case IfStmt(else_body=Block(body=[*_, stmt])) | WithStmt(body=Block(body=[*_, stmt])): + pass + case ForStmt(body=Block(body=[*prev, stmt])) | WhileStmt(body=Block(body=[*prev, stmt])): + pass +""" + tree = stdlib_ast.parse(code) + cu = ParserVisitor(code, 'test.py').visit(tree) + + stars = [] + _collect_stars(cu, stars) + assert len(stars) > 0, f"Expected Star nodes, found none" + for star in stars: + assert star.expression is not None, f"Star expression should not be None" + + +def _collect_stars(node, result, visited=None): + """Recursively collect Star nodes from an LST.""" + from rewrite.python.tree import Star + + if visited is None: + visited = set() + node_id = id(node) + if node_id in visited: + return + visited.add(node_id) + + if isinstance(node, Star): + result.append(node) + + # Check dataclass fields + if hasattr(node, '__dataclass_fields__'): + for field_name in node.__dataclass_fields__: + val = getattr(node, field_name, None) + if val is None or isinstance(val, (str, int, float, bool, bytes)): + continue + if hasattr(val, '__dataclass_fields__'): + _collect_stars(val, result, visited) + elif isinstance(val, list): + for item in val: + if hasattr(item, '__dataclass_fields__'): + _collect_stars(item, result, visited) + elif hasattr(item, 'element'): + _collect_stars(item.element, result, visited) + + +def test_match_tuple_with_sequence_pattern(): + # language=python - implicit tuple match with sequence pattern [c] as first element + RecipeSpec().rewrite_run(python( + """\ +match (y, z): + case _, b: + pass + case [c], _: + pass +""" + )) + + def test_match_with_or_pattern_in_tuple(): # language=python - OR pattern as first element of implicit tuple RecipeSpec().rewrite_run(python(