Skip to content

Commit 0068882

Browse files
Python: Fix RPC issues (#6721)
* Python: Fix RPC issues Enums get sent as strings and need to be converted to enums o the Python side. * Register missing SearchResult sender * Register missing Markup RPC codecs
1 parent e19e752 commit 0068882

3 files changed

Lines changed: 155 additions & 46 deletions

File tree

rewrite-python/rewrite/src/rewrite/python/support_types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from dataclasses import replace, dataclass
4-
from enum import Enum
54
from typing import TypeVar, Any, Optional, TYPE_CHECKING
65

76
from rewrite import TreeVisitor, Markers

rewrite-python/rewrite/src/rewrite/rpc/python_receiver.py

Lines changed: 101 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,12 @@
1818
This uses the visitor pattern with pre_visit handling common fields (id, prefix, markers)
1919
and type-specific visit methods handling only additional fields.
2020
"""
21+
from enum import Enum
2122
from pathlib import Path
22-
from typing import Any, Optional, TypeVar, List
23-
from uuid import UUID
23+
from typing import Any, Callable, Optional, Type, TypeVar
2424

25-
from rewrite import Markers
26-
from rewrite.utils import replace_if_changed
2725
from rewrite.java import Space, JRightPadded, JLeftPadded, JContainer, J
2826
from rewrite.python import CompilationUnit
29-
from rewrite.python.support_types import Py
3027
from rewrite.python.tree import (
3128
Async, Await, Binary, ChainedAssignment, ExceptionType,
3229
LiteralType, TypeHint, ExpressionStatement, ExpressionTypeTree,
@@ -36,8 +33,20 @@
3633
Star, NamedArgument, TypeHintedExpression, ErrorFrom, MatchCase, Slice
3734
)
3835
from rewrite.rpc.receive_queue import RpcReceiveQueue
36+
from rewrite.utils import replace_if_changed
3937

4038
T = TypeVar('T')
39+
E = TypeVar('E', bound=Enum)
40+
41+
42+
def _to_enum(enum_class: Type[E]) -> Callable[[Any], E]:
43+
"""Create a mapping function that converts string values to enum members.
44+
45+
Similar to Java's toEnum(EnumClass.class) used in RPC deserialization.
46+
"""
47+
def mapper(value: Any) -> E:
48+
return enum_class[value] if isinstance(value, str) else value
49+
return mapper
4150

4251

4352
class PythonRpcReceiver:
@@ -204,7 +213,7 @@ def _visit_await(self, await_: Await, q: RpcReceiveQueue) -> Await:
204213

205214
def _visit_binary(self, binary: Binary, q: RpcReceiveQueue) -> Binary:
206215
left = q.receive(binary.left)
207-
operator = q.receive(binary.padding.operator)
216+
operator = q.receive(binary.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(Binary.Type)))
208217
negation = q.receive(binary.negation)
209218
right = q.receive(binary.right)
210219
type_ = q.receive(binary.type)
@@ -262,7 +271,7 @@ def _visit_dict_literal(self, dl: DictLiteral, q: RpcReceiveQueue) -> DictLitera
262271
return replace_if_changed(dl, elements=elements, type=type_)
263272

264273
def _visit_collection_literal(self, cl: CollectionLiteral, q: RpcReceiveQueue) -> CollectionLiteral:
265-
kind = q.receive(cl.kind)
274+
kind = _to_enum(CollectionLiteral.Kind)(q.receive(cl.kind))
266275
elements = q.receive(cl.padding.elements)
267276
type_ = q.receive(cl.type)
268277
return replace_if_changed(cl, kind=kind, elements=elements, type=type_)
@@ -276,7 +285,7 @@ def _visit_formatted_string(self, fs: FormattedString, q: RpcReceiveQueue) -> Fo
276285
def _visit_formatted_string_value(self, v: FormattedString.Value, q: RpcReceiveQueue) -> FormattedString.Value:
277286
expression = q.receive(v.padding.expression)
278287
debug = q.receive(v.padding.debug)
279-
conversion = q.receive(v.conversion)
288+
conversion = _to_enum(FormattedString.Value.Conversion)(q.receive(v.conversion))
280289
format_ = q.receive(v.format)
281290
return replace_if_changed(v, expression=expression, debug=debug, conversion=conversion, format=format_)
282291

@@ -290,7 +299,7 @@ def _visit_trailing_else_wrapper(self, tew: TrailingElseWrapper, q: RpcReceiveQu
290299
return replace_if_changed(tew, statement=statement, else_block=else_block)
291300

292301
def _visit_comprehension_expression(self, ce: ComprehensionExpression, q: RpcReceiveQueue) -> ComprehensionExpression:
293-
kind = q.receive(ce.kind)
302+
kind = _to_enum(ComprehensionExpression.Kind)(q.receive(ce.kind))
294303
result = q.receive(ce.result)
295304
clauses = q.receive_list(ce.clauses)
296305
suffix = q.receive(ce.suffix)
@@ -325,7 +334,7 @@ def _visit_union_type(self, ut: UnionType, q: RpcReceiveQueue) -> UnionType:
325334
return replace_if_changed(ut, types=types, type=type_)
326335

327336
def _visit_variable_scope(self, vs: VariableScope, q: RpcReceiveQueue) -> VariableScope:
328-
kind = q.receive(vs.kind)
337+
kind = _to_enum(VariableScope.Kind)(q.receive(vs.kind))
329338
names = q.receive_list(vs.padding.names)
330339
return replace_if_changed(vs, kind=kind, names=names)
331340

@@ -334,13 +343,13 @@ def _visit_del(self, del_: Del, q: RpcReceiveQueue) -> Del:
334343
return replace_if_changed(del_, targets=targets)
335344

336345
def _visit_special_parameter(self, sp: SpecialParameter, q: RpcReceiveQueue) -> SpecialParameter:
337-
kind = q.receive(sp.kind)
346+
kind = _to_enum(SpecialParameter.Kind)(q.receive(sp.kind))
338347
type_hint = q.receive(sp.type_hint)
339348
type_ = q.receive(sp.type)
340349
return replace_if_changed(sp, kind=kind, type_hint=type_hint, type=type_)
341350

342351
def _visit_star(self, star: Star, q: RpcReceiveQueue) -> Star:
343-
kind = q.receive(star.kind)
352+
kind = _to_enum(Star.Kind)(q.receive(star.kind))
344353
expression = q.receive(star.expression)
345354
type_ = q.receive(star.type)
346355
return replace_if_changed(star, kind=kind, expression=expression, type=type_)
@@ -370,7 +379,7 @@ def _visit_match_case(self, mc: MatchCase, q: RpcReceiveQueue) -> MatchCase:
370379
return replace_if_changed(mc, pattern=pattern, guard=guard, type=type_)
371380

372381
def _visit_match_case_pattern(self, p: MatchCase.Pattern, q: RpcReceiveQueue) -> MatchCase.Pattern:
373-
kind = q.receive(p.kind)
382+
kind = _to_enum(MatchCase.Pattern.Kind)(q.receive(p.kind))
374383
children = q.receive(p.padding.children)
375384
type_ = q.receive(p.type)
376385
return replace_if_changed(p, kind=kind, children=children, type=type_)
@@ -533,14 +542,16 @@ def _visit_block(self, block, q: RpcReceiveQueue):
533542
return replace_if_changed(block, static=static, statements=statements, end=end)
534543

535544
def _visit_j_unary(self, unary, q: RpcReceiveQueue):
536-
operator = q.receive(unary.padding.operator)
545+
from rewrite.java.tree import Unary
546+
operator = q.receive(unary.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(Unary.Type)))
537547
expression = q.receive(unary.expression)
538548
type_ = q.receive(unary.type)
539549
return replace_if_changed(unary, operator=operator, expression=expression, type=type_)
540550

541551
def _visit_j_binary(self, binary, q: RpcReceiveQueue):
552+
from rewrite.java.tree import Binary as JBinary
542553
left = q.receive(binary.left)
543-
operator = q.receive(binary.padding.operator)
554+
operator = q.receive(binary.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(JBinary.Type)))
544555
right = q.receive(binary.right)
545556
type_ = q.receive(binary.type)
546557
return replace_if_changed(binary, left=left, operator=operator, right=right, type=type_)
@@ -552,8 +563,9 @@ def _visit_j_assignment(self, assign, q: RpcReceiveQueue):
552563
return replace_if_changed(assign, variable=variable, assignment=assignment, type=type_)
553564

554565
def _visit_j_assignment_operation(self, assign, q: RpcReceiveQueue):
566+
from rewrite.java.tree import AssignmentOperation
555567
variable = q.receive(assign.variable)
556-
operator = q.receive(assign.padding.operator)
568+
operator = q.receive(assign.padding.operator, lambda lp: self._receive_left_padded(lp, q, _to_enum(AssignmentOperation.Type)))
557569
assignment = q.receive(assign.assignment)
558570
type_ = q.receive(assign.type)
559571
return replace_if_changed(assign, variable=variable, operator=operator, assignment=assignment, type=type_)
@@ -702,9 +714,10 @@ def _visit_j_class_declaration(self, class_decl, q: RpcReceiveQueue):
702714
permits=permits, body=body)
703715

704716
def _visit_j_class_declaration_kind(self, kind, q: RpcReceiveQueue):
717+
from rewrite.java.tree import ClassDeclaration
705718
# Note: _pre_visit is already called by _visit before this method
706719
annotations = q.receive_list(kind.annotations)
707-
type_ = q.receive(kind.type) # Enum type
720+
type_ = _to_enum(ClassDeclaration.Kind.Type)(q.receive(kind.type))
708721
return replace_if_changed(kind, annotations=annotations, type=type_)
709722

710723
def _visit_j_method_declaration(self, method, q: RpcReceiveQueue):
@@ -741,7 +754,8 @@ def _visit_j_switch(self, switch, q: RpcReceiveQueue):
741754
return replace_if_changed(switch, selector=selector, cases=cases)
742755

743756
def _visit_j_case(self, case, q: RpcReceiveQueue):
744-
type_ = q.receive(case.type) # Enum type
757+
from rewrite.java.tree import Case
758+
type_ = _to_enum(Case.Type)(q.receive(case.type))
745759
case_labels = q.receive(case.padding.case_labels)
746760
statements = q.receive(case.padding.statements)
747761
body = q.receive(case.padding.body, lambda rp: self._receive_right_padded(rp, q) if rp else None)
@@ -785,8 +799,9 @@ def _visit_j_control_parentheses(self, parens, q: RpcReceiveQueue):
785799
return replace_if_changed(parens, tree=tree)
786800

787801
def _visit_j_modifier(self, mod, q: RpcReceiveQueue):
802+
from rewrite.java.tree import Modifier
788803
keyword = q.receive(mod.keyword)
789-
type_ = q.receive(mod.type) # Enum type
804+
type_ = _to_enum(Modifier.Type)(q.receive(mod.type))
790805
annotations = q.receive_list(mod.annotations)
791806
return replace_if_changed(mod, keyword=keyword, type=type_, annotations=annotations)
792807

@@ -869,14 +884,16 @@ def _receive_right_padded(self, rp: JRightPadded, q: RpcReceiveQueue) -> Optiona
869884

870885
return JRightPadded(element, after, markers)
871886

872-
def _receive_left_padded(self, lp: JLeftPadded, q: RpcReceiveQueue) -> Optional[JLeftPadded]:
887+
def _receive_left_padded(self, lp: JLeftPadded, q: RpcReceiveQueue, element_mapping: Optional[Callable[[Any], Any]] = None) -> Optional[JLeftPadded]:
873888
"""Receive a JLeftPadded wrapper."""
874889
if lp is None:
875890
return None
876891

877892
# Codec registry handles type dispatch automatically
878893
before = q.receive_defined(lp.before)
879894
element = q.receive(lp.element)
895+
if element_mapping is not None:
896+
element = element_mapping(element)
880897
markers = q.receive_markers(lp.markers)
881898

882899
if before is lp.before and element is lp.element and markers is lp.markers:
@@ -966,7 +983,7 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
966983
# Class: flagsBitMap, kind, fullyQualifiedName, typeParameters, supertype,
967984
# owningClass, annotations, interfaces, members, methods
968985
flags = q.receive_defined(getattr(java_type, '_flags_bit_map', 0))
969-
kind = q.receive(getattr(java_type, '_kind', JT.FullyQualified.Kind.Class))
986+
kind = _to_enum(JT.FullyQualified.Kind)(q.receive(getattr(java_type, '_kind', JT.FullyQualified.Kind.Class)))
970987
fqn = q.receive_defined(getattr(java_type, '_fully_qualified_name', ''))
971988
type_params = q.receive_list(getattr(java_type, '_type_parameters', None) or [],
972989
lambda t: self._receive_type(t, q))
@@ -1007,19 +1024,11 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
10071024
# Register marker codecs
10081025
def _register_marker_codecs():
10091026
"""Register receive and send codecs for Java marker types."""
1010-
from rewrite import Markers
10111027
from rewrite.java.markers import Semicolon, TrailingComma, OmitParentheses
10121028
from rewrite.java.support_types import Space
1013-
from rewrite.rpc.receive_queue import register_codec_with_both_names, register_send_codec
1029+
from rewrite.rpc.receive_queue import register_codec_with_both_names
10141030
from rewrite.rpc.send_queue import RpcSendQueue
10151031

1016-
# Markers send codec (Markers uses special handling in receive_queue, but needs send codec)
1017-
def _send_markers(markers: Markers, q: RpcSendQueue) -> None:
1018-
q.get_and_send(markers, lambda x: x.id)
1019-
q.get_and_send_list(markers, lambda x: x.markers, lambda m: m.id, None)
1020-
1021-
register_send_codec(Markers, _send_markers)
1022-
10231032
# Receive codecs
10241033
def _receive_semicolon(semicolon: Semicolon, q: RpcReceiveQueue) -> Semicolon:
10251034
new_id = q.receive_defined(semicolon.id)
@@ -1179,6 +1188,16 @@ def _receive_parse_exception_result(marker, q: RpcReceiveQueue):
11791188
)
11801189

11811190

1191+
def _send_search_result(marker, q):
1192+
"""Codec for sending SearchResult marker.
1193+
1194+
Fields are sent in the order expected by Java's SearchResult.rpcReceive():
1195+
id, description
1196+
"""
1197+
q.get_and_send(marker, lambda x: str(x.id))
1198+
q.get_and_send(marker, lambda x: x.description)
1199+
1200+
11821201
def _send_parse_exception_result(marker, q):
11831202
"""Codec for sending ParseExceptionResult marker.
11841203
@@ -1192,6 +1211,33 @@ def _send_parse_exception_result(marker, q):
11921211
q.get_and_send(marker, lambda x: x.tree_type)
11931212

11941213

1214+
def _receive_markup_marker(marker, q: RpcReceiveQueue, cls):
1215+
"""Generic codec for receiving Markup markers (Warn/Error/Info/Debug).
1216+
1217+
All four share the same fields in the same order: id, message, detail.
1218+
Matches Java's Markup.{Warn,Error,Info,Debug}.rpcReceive().
1219+
"""
1220+
from uuid import UUID
1221+
1222+
id_str = q.receive(str(marker.id) if marker else None)
1223+
message = q.receive(marker.message if marker else None)
1224+
detail = q.receive(marker.detail if marker else None)
1225+
1226+
new_id = UUID(id_str) if id_str else (marker.id if marker else None)
1227+
return cls(_id=new_id, _message=message, _detail=detail)
1228+
1229+
1230+
def _send_markup_marker(marker, q):
1231+
"""Generic codec for sending Markup markers (Warn/Error/Info/Debug).
1232+
1233+
All four share the same fields in the same order: id, message, detail.
1234+
Matches Java's Markup.{Warn,Error,Info,Debug}.rpcSend().
1235+
"""
1236+
q.get_and_send(marker, lambda x: str(x.id))
1237+
q.get_and_send(marker, lambda x: x.message)
1238+
q.get_and_send(marker, lambda x: x.detail)
1239+
1240+
11951241
def _receive_style(style, q: RpcReceiveQueue):
11961242
"""Codec for receiving Style objects."""
11971243
# For now, styles are passed through - full deserialization would need more work
@@ -1279,7 +1325,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
12791325
# flagsBitMap, kind, fullyQualifiedName, typeParameters, supertype,
12801326
# owningClass, annotations, interfaces, members, methods
12811327
flags = q.receive_defined(getattr(cls, '_flags_bit_map', 0) if cls else 0)
1282-
kind = q.receive(getattr(cls, '_kind', JT.FullyQualified.Kind.Class) if cls else JT.FullyQualified.Kind.Class)
1328+
kind = _to_enum(JT.FullyQualified.Kind)(q.receive(getattr(cls, '_kind', JT.FullyQualified.Kind.Class) if cls else JT.FullyQualified.Kind.Class))
12831329
fqn = q.receive_defined(getattr(cls, '_fully_qualified_name', '') if cls else '')
12841330
type_params = q.receive_list(getattr(cls, '_type_parameters', None) if cls else None)
12851331
supertype = q.receive(getattr(cls, '_supertype', None) if cls else None)
@@ -1308,8 +1354,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
13081354
def _register_java_type_codecs():
13091355
"""Register codecs for JavaType classes."""
13101356
from rewrite.java.support_types import JavaType as JT
1311-
from rewrite.rpc.receive_queue import register_codec_with_both_names, register_receive_codec, register_send_codec
1312-
from rewrite.rpc.send_queue import RpcSendQueue
1357+
from rewrite.rpc.receive_queue import register_codec_with_both_names
13131358

13141359
# JavaType.Primitive - special handling to consume keyword
13151360
register_codec_with_both_names(
@@ -1427,12 +1472,13 @@ def _register_core_marker_codecs():
14271472
_receive_markers,
14281473
lambda: Markers.EMPTY
14291474
)
1430-
# SearchResult - has specific fields to receive
1475+
# SearchResult - has specific fields to receive/send
14311476
register_codec_with_both_names(
14321477
'org.openrewrite.marker.SearchResult',
14331478
SearchResult,
14341479
_receive_search_result,
1435-
make_dataclass_factory(SearchResult)
1480+
make_dataclass_factory(SearchResult),
1481+
sender=_send_search_result
14361482
)
14371483
# ParseExceptionResult - has specific fields to receive/send
14381484
register_codec_with_both_names(
@@ -1444,6 +1490,24 @@ def _register_core_marker_codecs():
14441490
)
14451491

14461492

1493+
def _register_markup_marker_codecs():
1494+
"""Register codecs for Markup marker types (Warn, Error, Info, Debug)."""
1495+
from rewrite.markers import MarkupWarn, MarkupError, MarkupInfo, MarkupDebug
1496+
from rewrite.rpc.receive_queue import register_codec_with_both_names
1497+
from uuid import uuid4
1498+
1499+
for java_suffix, py_cls in [
1500+
('Warn', MarkupWarn),
1501+
('Error', MarkupError),
1502+
('Info', MarkupInfo),
1503+
('Debug', MarkupDebug),
1504+
]:
1505+
java_type = f'org.openrewrite.marker.Markup${java_suffix}'
1506+
receive = lambda marker, q, c=py_cls: _receive_markup_marker(marker, q, c)
1507+
factory = lambda c=py_cls: c(_id=uuid4(), _message='', _detail=None)
1508+
register_codec_with_both_names(java_type, py_cls, receive, factory, _send_markup_marker)
1509+
1510+
14471511
def _register_style_codecs():
14481512
"""Register codecs for style types."""
14491513
from rewrite.style import GeneralFormatStyle, NamedStyles
@@ -1572,7 +1636,7 @@ def _send_keyword_only_arguments(marker: KeywordOnlyArguments, q: RpcSendQueue)
15721636
# Quoted - has id and style (enum)
15731637
def _receive_quoted(marker: Quoted, q: RpcReceiveQueue) -> Quoted:
15741638
new_id = q.receive_defined(marker.id)
1575-
new_style = q.receive_defined(marker.style)
1639+
new_style = _to_enum(Quoted.Style)(q.receive_defined(marker.style))
15761640
if new_id is marker.id and new_style is marker.style:
15771641
return marker
15781642
result = marker
@@ -1668,6 +1732,7 @@ def _send_exec_syntax(marker: ExecSyntax, q: RpcSendQueue) -> None:
16681732
_register_support_type_codecs()
16691733
_register_java_type_codecs() # JavaType.Primitive handling
16701734
_register_core_marker_codecs()
1735+
_register_markup_marker_codecs() # Markup.Warn, Error, Info, Debug
16711736
_register_style_codecs()
16721737
_register_parse_error_codec() # ParseError handling
16731738
_register_python_marker_codecs() # Python-specific markers including PrintSyntax, ExecSyntax

0 commit comments

Comments
 (0)