Skip to content

Commit b645662

Browse files
committed
Python: Fix RPC issues
Enums get sent as strings and need to be converted to enums o the Python side.
1 parent be49de4 commit b645662

3 files changed

Lines changed: 95 additions & 35 deletions

File tree

rewrite-python/rewrite/src/rewrite/python/support_types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from dataclasses import replace, dataclass
4-
from enum import Enum
54
from typing import TypeVar, Any, Optional, TYPE_CHECKING
65

76
from rewrite import TreeVisitor, Markers

rewrite-python/rewrite/src/rewrite/rpc/python_receiver.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,12 @@
1818
This uses the visitor pattern with pre_visit handling common fields (id, prefix, markers)
1919
and type-specific visit methods handling only additional fields.
2020
"""
21+
from enum import Enum
2122
from pathlib import Path
22-
from typing import Any, Optional, TypeVar, List
23-
from uuid import UUID
23+
from typing import Any, Callable, Optional, Type, TypeVar
2424

25-
from rewrite import Markers
26-
from rewrite.utils import replace_if_changed
2725
from rewrite.java import Space, JRightPadded, JLeftPadded, JContainer, J
2826
from rewrite.python import CompilationUnit
29-
from rewrite.python.support_types import Py
3027
from rewrite.python.tree import (
3128
Async, Await, Binary, ChainedAssignment, ExceptionType,
3229
LiteralType, TypeHint, ExpressionStatement, ExpressionTypeTree,
@@ -36,8 +33,20 @@
3633
Star, NamedArgument, TypeHintedExpression, ErrorFrom, MatchCase, Slice
3734
)
3835
from rewrite.rpc.receive_queue import RpcReceiveQueue
36+
from rewrite.utils import replace_if_changed
3937

4038
T = TypeVar('T')
39+
E = TypeVar('E', bound=Enum)
40+
41+
42+
def _to_enum(enum_class: Type[E]) -> Callable[[Any], E]:
43+
"""Create a mapping function that converts string values to enum members.
44+
45+
Similar to Java's toEnum(EnumClass.class) used in RPC deserialization.
46+
"""
47+
def mapper(value: Any) -> E:
48+
return enum_class[value] if isinstance(value, str) else value
49+
return mapper
4150

4251

4352
class PythonRpcReceiver:
@@ -204,7 +213,7 @@ def _visit_await(self, await_: Await, q: RpcReceiveQueue) -> Await:
204213

205214
def _visit_binary(self, binary: Binary, q: RpcReceiveQueue) -> Binary:
206215
left = q.receive(binary.left)
207-
operator = q.receive(binary.padding.operator)
216+
operator = q.receive(binary.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(Binary.Type)))
208217
negation = q.receive(binary.negation)
209218
right = q.receive(binary.right)
210219
type_ = q.receive(binary.type)
@@ -262,7 +271,7 @@ def _visit_dict_literal(self, dl: DictLiteral, q: RpcReceiveQueue) -> DictLitera
262271
return replace_if_changed(dl, elements=elements, type=type_)
263272

264273
def _visit_collection_literal(self, cl: CollectionLiteral, q: RpcReceiveQueue) -> CollectionLiteral:
265-
kind = q.receive(cl.kind)
274+
kind = _to_enum(CollectionLiteral.Kind)(q.receive(cl.kind))
266275
elements = q.receive(cl.padding.elements)
267276
type_ = q.receive(cl.type)
268277
return replace_if_changed(cl, kind=kind, elements=elements, type=type_)
@@ -276,7 +285,7 @@ def _visit_formatted_string(self, fs: FormattedString, q: RpcReceiveQueue) -> Fo
276285
def _visit_formatted_string_value(self, v: FormattedString.Value, q: RpcReceiveQueue) -> FormattedString.Value:
277286
expression = q.receive(v.padding.expression)
278287
debug = q.receive(v.padding.debug)
279-
conversion = q.receive(v.conversion)
288+
conversion = _to_enum(FormattedString.Value.Conversion)(q.receive(v.conversion))
280289
format_ = q.receive(v.format)
281290
return replace_if_changed(v, expression=expression, debug=debug, conversion=conversion, format=format_)
282291

@@ -290,7 +299,7 @@ def _visit_trailing_else_wrapper(self, tew: TrailingElseWrapper, q: RpcReceiveQu
290299
return replace_if_changed(tew, statement=statement, else_block=else_block)
291300

292301
def _visit_comprehension_expression(self, ce: ComprehensionExpression, q: RpcReceiveQueue) -> ComprehensionExpression:
293-
kind = q.receive(ce.kind)
302+
kind = _to_enum(ComprehensionExpression.Kind)(q.receive(ce.kind))
294303
result = q.receive(ce.result)
295304
clauses = q.receive_list(ce.clauses)
296305
suffix = q.receive(ce.suffix)
@@ -325,7 +334,7 @@ def _visit_union_type(self, ut: UnionType, q: RpcReceiveQueue) -> UnionType:
325334
return replace_if_changed(ut, types=types, type=type_)
326335

327336
def _visit_variable_scope(self, vs: VariableScope, q: RpcReceiveQueue) -> VariableScope:
328-
kind = q.receive(vs.kind)
337+
kind = _to_enum(VariableScope.Kind)(q.receive(vs.kind))
329338
names = q.receive_list(vs.padding.names)
330339
return replace_if_changed(vs, kind=kind, names=names)
331340

@@ -334,13 +343,13 @@ def _visit_del(self, del_: Del, q: RpcReceiveQueue) -> Del:
334343
return replace_if_changed(del_, targets=targets)
335344

336345
def _visit_special_parameter(self, sp: SpecialParameter, q: RpcReceiveQueue) -> SpecialParameter:
337-
kind = q.receive(sp.kind)
346+
kind = _to_enum(SpecialParameter.Kind)(q.receive(sp.kind))
338347
type_hint = q.receive(sp.type_hint)
339348
type_ = q.receive(sp.type)
340349
return replace_if_changed(sp, kind=kind, type_hint=type_hint, type=type_)
341350

342351
def _visit_star(self, star: Star, q: RpcReceiveQueue) -> Star:
343-
kind = q.receive(star.kind)
352+
kind = _to_enum(Star.Kind)(q.receive(star.kind))
344353
expression = q.receive(star.expression)
345354
type_ = q.receive(star.type)
346355
return replace_if_changed(star, kind=kind, expression=expression, type=type_)
@@ -370,7 +379,7 @@ def _visit_match_case(self, mc: MatchCase, q: RpcReceiveQueue) -> MatchCase:
370379
return replace_if_changed(mc, pattern=pattern, guard=guard, type=type_)
371380

372381
def _visit_match_case_pattern(self, p: MatchCase.Pattern, q: RpcReceiveQueue) -> MatchCase.Pattern:
373-
kind = q.receive(p.kind)
382+
kind = _to_enum(MatchCase.Pattern.Kind)(q.receive(p.kind))
374383
children = q.receive(p.padding.children)
375384
type_ = q.receive(p.type)
376385
return replace_if_changed(p, kind=kind, children=children, type=type_)
@@ -533,14 +542,16 @@ def _visit_block(self, block, q: RpcReceiveQueue):
533542
return replace_if_changed(block, static=static, statements=statements, end=end)
534543

535544
def _visit_j_unary(self, unary, q: RpcReceiveQueue):
536-
operator = q.receive(unary.padding.operator)
545+
from rewrite.java.tree import Unary
546+
operator = q.receive(unary.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(Unary.Type)))
537547
expression = q.receive(unary.expression)
538548
type_ = q.receive(unary.type)
539549
return replace_if_changed(unary, operator=operator, expression=expression, type=type_)
540550

541551
def _visit_j_binary(self, binary, q: RpcReceiveQueue):
552+
from rewrite.java.tree import Binary as JBinary
542553
left = q.receive(binary.left)
543-
operator = q.receive(binary.padding.operator)
554+
operator = q.receive(binary.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(JBinary.Type)))
544555
right = q.receive(binary.right)
545556
type_ = q.receive(binary.type)
546557
return replace_if_changed(binary, left=left, operator=operator, right=right, type=type_)
@@ -552,8 +563,9 @@ def _visit_j_assignment(self, assign, q: RpcReceiveQueue):
552563
return replace_if_changed(assign, variable=variable, assignment=assignment, type=type_)
553564

554565
def _visit_j_assignment_operation(self, assign, q: RpcReceiveQueue):
566+
from rewrite.java.tree import AssignmentOperation
555567
variable = q.receive(assign.variable)
556-
operator = q.receive(assign.padding.operator)
568+
operator = q.receive(assign.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(AssignmentOperation.Type)))
557569
assignment = q.receive(assign.assignment)
558570
type_ = q.receive(assign.type)
559571
return replace_if_changed(assign, variable=variable, operator=operator, assignment=assignment, type=type_)
@@ -702,9 +714,10 @@ def _visit_j_class_declaration(self, class_decl, q: RpcReceiveQueue):
702714
permits=permits, body=body)
703715

704716
def _visit_j_class_declaration_kind(self, kind, q: RpcReceiveQueue):
717+
from rewrite.java.tree import ClassDeclaration
705718
# Note: _pre_visit is already called by _visit before this method
706719
annotations = q.receive_list(kind.annotations)
707-
type_ = q.receive(kind.type) # Enum type
720+
type_ = _to_enum(ClassDeclaration.Kind.Type)(q.receive(kind.type))
708721
return replace_if_changed(kind, annotations=annotations, type=type_)
709722

710723
def _visit_j_method_declaration(self, method, q: RpcReceiveQueue):
@@ -741,7 +754,8 @@ def _visit_j_switch(self, switch, q: RpcReceiveQueue):
741754
return replace_if_changed(switch, selector=selector, cases=cases)
742755

743756
def _visit_j_case(self, case, q: RpcReceiveQueue):
744-
type_ = q.receive(case.type) # Enum type
757+
from rewrite.java.tree import Case
758+
type_ = _to_enum(Case.Type)(q.receive(case.type))
745759
case_labels = q.receive(case.padding.case_labels)
746760
statements = q.receive(case.padding.statements)
747761
body = q.receive(case.padding.body, lambda rp: self._receive_right_padded(rp, q) if rp else None)
@@ -785,8 +799,9 @@ def _visit_j_control_parentheses(self, parens, q: RpcReceiveQueue):
785799
return replace_if_changed(parens, tree=tree)
786800

787801
def _visit_j_modifier(self, mod, q: RpcReceiveQueue):
802+
from rewrite.java.tree import Modifier
788803
keyword = q.receive(mod.keyword)
789-
type_ = q.receive(mod.type) # Enum type
804+
type_ = _to_enum(Modifier.Type)(q.receive(mod.type))
790805
annotations = q.receive_list(mod.annotations)
791806
return replace_if_changed(mod, keyword=keyword, type=type_, annotations=annotations)
792807

@@ -869,14 +884,16 @@ def _receive_right_padded(self, rp: JRightPadded, q: RpcReceiveQueue) -> Optiona
869884

870885
return JRightPadded(element, after, markers)
871886

872-
def _receive_left_padded(self, lp: JLeftPadded, q: RpcReceiveQueue) -> Optional[JLeftPadded]:
887+
def _receive_left_padded(self, lp: JLeftPadded, q: RpcReceiveQueue, element_mapping: Optional[Callable[[Any], Any]] = None) -> Optional[JLeftPadded]:
873888
"""Receive a JLeftPadded wrapper."""
874889
if lp is None:
875890
return None
876891

877892
# Codec registry handles type dispatch automatically
878893
before = q.receive_defined(lp.before)
879894
element = q.receive(lp.element)
895+
if element_mapping is not None:
896+
element = element_mapping(element)
880897
markers = q.receive_markers(lp.markers)
881898

882899
if before is lp.before and element is lp.element and markers is lp.markers:
@@ -966,7 +983,7 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
966983
# Class: flagsBitMap, kind, fullyQualifiedName, typeParameters, supertype,
967984
# owningClass, annotations, interfaces, members, methods
968985
flags = q.receive_defined(getattr(java_type, '_flags_bit_map', 0))
969-
kind = q.receive(getattr(java_type, '_kind', JT.FullyQualified.Kind.Class))
986+
kind = _to_enum(JT.FullyQualified.Kind)(q.receive(getattr(java_type, '_kind', JT.FullyQualified.Kind.Class)))
970987
fqn = q.receive_defined(getattr(java_type, '_fully_qualified_name', ''))
971988
type_params = q.receive_list(getattr(java_type, '_type_parameters', None) or [],
972989
lambda t: self._receive_type(t, q))
@@ -1279,7 +1296,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
12791296
# flagsBitMap, kind, fullyQualifiedName, typeParameters, supertype,
12801297
# owningClass, annotations, interfaces, members, methods
12811298
flags = q.receive_defined(getattr(cls, '_flags_bit_map', 0) if cls else 0)
1282-
kind = q.receive(getattr(cls, '_kind', JT.FullyQualified.Kind.Class) if cls else JT.FullyQualified.Kind.Class)
1299+
kind = _to_enum(JT.FullyQualified.Kind)(q.receive(getattr(cls, '_kind', JT.FullyQualified.Kind.Class) if cls else JT.FullyQualified.Kind.Class))
12831300
fqn = q.receive_defined(getattr(cls, '_fully_qualified_name', '') if cls else '')
12841301
type_params = q.receive_list(getattr(cls, '_type_parameters', None) if cls else None)
12851302
supertype = q.receive(getattr(cls, '_supertype', None) if cls else None)
@@ -1308,8 +1325,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
13081325
def _register_java_type_codecs():
13091326
"""Register codecs for JavaType classes."""
13101327
from rewrite.java.support_types import JavaType as JT
1311-
from rewrite.rpc.receive_queue import register_codec_with_both_names, register_receive_codec, register_send_codec
1312-
from rewrite.rpc.send_queue import RpcSendQueue
1328+
from rewrite.rpc.receive_queue import register_codec_with_both_names
13131329

13141330
# JavaType.Primitive - special handling to consume keyword
13151331
register_codec_with_both_names(
@@ -1572,7 +1588,7 @@ def _send_keyword_only_arguments(marker: KeywordOnlyArguments, q: RpcSendQueue)
15721588
# Quoted - has id and style (enum)
15731589
def _receive_quoted(marker: Quoted, q: RpcReceiveQueue) -> Quoted:
15741590
new_id = q.receive_defined(marker.id)
1575-
new_style = q.receive_defined(marker.style)
1591+
new_style = _to_enum(Quoted.Style)(q.receive_defined(marker.style))
15761592
if new_id is marker.id and new_style is marker.style:
15771593
return marker
15781594
result = marker

rewrite-python/rewrite/src/rewrite/rpc/server.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def handle_get_object(params: dict) -> List[dict]:
385385
if obj_id is None:
386386
return [{'state': 'DELETE'}, {'state': 'END_OF_OBJECT'}]
387387
obj = local_objects.get(obj_id)
388-
logger.info(f"handle_get_object: id={obj_id}, type={type(obj).__name__ if obj else 'None'}")
388+
logger.debug(f"handle_get_object: id={obj_id}, type={type(obj).__name__ if obj else 'None'}")
389389

390390
if obj is None:
391391
return [
@@ -446,7 +446,7 @@ def handle_print(params: dict) -> str:
446446
obj_id = params.get('treeId') or params.get('id')
447447
source_file_type = params.get('sourceFileType')
448448

449-
logger.info(f"handle_print: treeId={obj_id}, sourceFileType={source_file_type}")
449+
logger.debug(f"handle_print: treeId={obj_id}, sourceFileType={source_file_type}")
450450

451451
if obj_id is None:
452452
logger.warning("No treeId or id provided")
@@ -555,7 +555,7 @@ def handle_install_recipes(params: dict) -> dict:
555555
# For local paths, we look for the package name from setup.py/pyproject.toml
556556
package_name = _find_package_name(local_path)
557557
if package_name:
558-
_import_and_activate_package(package_name, marketplace)
558+
_import_and_activate_package(package_name, marketplace, local_path)
559559

560560
elif isinstance(recipes, dict):
561561
# Package spec with name and optional version - package should already be installed
@@ -588,6 +588,43 @@ def handle_install_recipes(params: dict) -> dict:
588588
}
589589

590590

591+
def _add_source_to_path(local_path: Path) -> None:
592+
"""Add the package source directory to sys.path so it can be imported.
593+
594+
Reads [tool.setuptools.packages.find] 'where' from pyproject.toml to
595+
determine the source directory. Falls back to adding the local_path itself.
596+
"""
597+
import sys
598+
if sys.version_info >= (3, 11):
599+
import tomllib
600+
else:
601+
try:
602+
import tomli as tomllib # type: ignore[import-not-found]
603+
except ModuleNotFoundError:
604+
src_dir = str(local_path)
605+
if src_dir not in sys.path:
606+
sys.path.insert(0, src_dir)
607+
return
608+
609+
source_dir = local_path
610+
pyproject_path = local_path / 'pyproject.toml'
611+
if pyproject_path.exists():
612+
try:
613+
with open(pyproject_path, 'rb') as f:
614+
data = tomllib.load(f)
615+
where = (data.get('tool', {}).get('setuptools', {})
616+
.get('packages', {}).get('find', {}).get('where'))
617+
if where and isinstance(where, list) and len(where) > 0:
618+
source_dir = local_path / where[0]
619+
except Exception as e:
620+
logger.warning(f"Failed to read source layout from pyproject.toml: {e}")
621+
622+
src_str = str(source_dir)
623+
if src_str not in sys.path:
624+
logger.info(f"Adding to sys.path: {src_str}")
625+
sys.path.insert(0, src_str)
626+
627+
591628
def _find_package_name(local_path: Path) -> Optional[str]:
592629
"""Find the package name from a local path."""
593630
import sys
@@ -630,14 +667,18 @@ def _find_package_name(local_path: Path) -> Optional[str]:
630667
return None
631668

632669

633-
def _import_and_activate_package(package_name: str, marketplace):
670+
def _import_and_activate_package(package_name: str, marketplace, local_path: Optional[Path] = None):
634671
"""Import a package and call its activate function using entry points.
635672
636673
Uses importlib.metadata to discover entry points registered under
637674
the 'openrewrite.recipes' group and calls their activate functions.
638675
Since matching package names to entry points is unreliable (hyphens vs
639676
underscores, different naming conventions), we activate ALL entry points
640677
but the marketplace handles deduplication.
678+
679+
If entry points aren't found (e.g., package not pip-installed) and a
680+
local_path is provided, the source directory is added to sys.path
681+
as a fallback so the module can be imported directly.
641682
"""
642683
from importlib.metadata import entry_points
643684

@@ -681,6 +722,10 @@ def normalize(name: str) -> str:
681722

682723
if not activated:
683724
# Fallback: try direct module import (for packages without entry points)
725+
# If a local path was provided, add its source directory to sys.path
726+
if local_path is not None:
727+
_add_source_to_path(local_path)
728+
684729
import importlib
685730
module_name = package_name.replace('-', '_')
686731
try:
@@ -826,7 +871,7 @@ def handle_prepare_recipe(params: dict) -> dict:
826871
_data_table_output_dir = params['dataTableOutputDir']
827872
logger.info(f"Data table output directory set to: {_data_table_output_dir}")
828873

829-
logger.info(f"PrepareRecipe: id={recipe_name}, options={options}")
874+
logger.debug(f"PrepareRecipe: id={recipe_name}, options={options}")
830875

831876
marketplace = _get_marketplace()
832877

@@ -865,7 +910,7 @@ def handle_prepare_recipe(params: dict) -> dict:
865910
'scanPreconditions': _get_preconditions(recipe, 'scan') if is_scanning else [],
866911
}
867912

868-
logger.info(f"PrepareRecipe response: {response}")
913+
logger.debug(f"PrepareRecipe response: {response}")
869914
return response
870915

871916

@@ -903,7 +948,7 @@ def handle_visit(params: dict) -> dict:
903948
if tree_id is None:
904949
raise ValueError("'treeId' is required")
905950

906-
logger.info(f"Visit: visitor={visitor_name}, treeId={tree_id}, p={p_id}")
951+
logger.debug(f"Visit: visitor={visitor_name}, treeId={tree_id}, p={p_id}")
907952

908953
# Get or create execution context
909954
if p_id and p_id in _execution_contexts:
@@ -954,7 +999,7 @@ def handle_visit(params: dict) -> dict:
954999
else:
9551000
modified = False
9561001

957-
logger.info(f"Visit result: modified={modified}, tree_id={tree_id}, before.id={before.id}, after.id={after.id if after else None}")
1002+
logger.debug(f"Visit result: modified={modified}, tree_id={tree_id}, before.id={before.id}, after.id={after.id if after else None}")
9581003
return {'modified': modified}
9591004

9601005

@@ -1036,7 +1081,7 @@ def handle_generate(params: dict) -> dict:
10361081
if recipe_id is None:
10371082
raise ValueError("'id' is required")
10381083

1039-
logger.info(f"Generate: id={recipe_id}, p={p_id}")
1084+
logger.debug(f"Generate: id={recipe_id}, p={p_id}")
10401085

10411086
recipe = _prepared_recipes.get(recipe_id)
10421087
if recipe is None:

0 commit comments

Comments
 (0)