1818This uses the visitor pattern with pre_visit handling common fields (id, prefix, markers)
1919and type-specific visit methods handling only additional fields.
2020"""
21+ from enum import Enum
2122from pathlib import Path
22- from typing import Any , Optional , TypeVar , List
23- from uuid import UUID
23+ from typing import Any , Callable , Optional , Type , TypeVar
2424
25- from rewrite import Markers
26- from rewrite .utils import replace_if_changed
2725from rewrite .java import Space , JRightPadded , JLeftPadded , JContainer , J
2826from rewrite .python import CompilationUnit
29- from rewrite .python .support_types import Py
3027from rewrite .python .tree import (
3128 Async , Await , Binary , ChainedAssignment , ExceptionType ,
3229 LiteralType , TypeHint , ExpressionStatement , ExpressionTypeTree ,
3633 Star , NamedArgument , TypeHintedExpression , ErrorFrom , MatchCase , Slice
3734)
3835from rewrite .rpc .receive_queue import RpcReceiveQueue
36+ from rewrite .utils import replace_if_changed
3937
4038T = TypeVar ('T' )
39+ E = TypeVar ('E' , bound = Enum )
40+
41+
42+ def _to_enum (enum_class : Type [E ]) -> Callable [[Any ], E ]:
43+ """Create a mapping function that converts string values to enum members.
44+
45+ Similar to Java's toEnum(EnumClass.class) used in RPC deserialization.
46+ """
47+ def mapper (value : Any ) -> E :
48+ return enum_class [value ] if isinstance (value , str ) else value
49+ return mapper
4150
4251
4352class PythonRpcReceiver :
@@ -204,7 +213,7 @@ def _visit_await(self, await_: Await, q: RpcReceiveQueue) -> Await:
204213
205214 def _visit_binary (self , binary : Binary , q : RpcReceiveQueue ) -> Binary :
206215 left = q .receive (binary .left )
207- operator = q .receive (binary .padding .operator )
216+ operator = q .receive (binary .padding .operator , lambda lp : self . _receive_left_padded ( lp , q , _to_enum ( Binary . Type )) )
208217 negation = q .receive (binary .negation )
209218 right = q .receive (binary .right )
210219 type_ = q .receive (binary .type )
@@ -262,7 +271,7 @@ def _visit_dict_literal(self, dl: DictLiteral, q: RpcReceiveQueue) -> DictLitera
262271 return replace_if_changed (dl , elements = elements , type = type_ )
263272
264273 def _visit_collection_literal (self , cl : CollectionLiteral , q : RpcReceiveQueue ) -> CollectionLiteral :
265- kind = q .receive (cl .kind )
274+ kind = _to_enum ( CollectionLiteral . Kind )( q .receive (cl .kind ) )
266275 elements = q .receive (cl .padding .elements )
267276 type_ = q .receive (cl .type )
268277 return replace_if_changed (cl , kind = kind , elements = elements , type = type_ )
@@ -276,7 +285,7 @@ def _visit_formatted_string(self, fs: FormattedString, q: RpcReceiveQueue) -> Fo
276285 def _visit_formatted_string_value (self , v : FormattedString .Value , q : RpcReceiveQueue ) -> FormattedString .Value :
277286 expression = q .receive (v .padding .expression )
278287 debug = q .receive (v .padding .debug )
279- conversion = q .receive (v .conversion )
288+ conversion = _to_enum ( FormattedString . Value . Conversion )( q .receive (v .conversion ) )
280289 format_ = q .receive (v .format )
281290 return replace_if_changed (v , expression = expression , debug = debug , conversion = conversion , format = format_ )
282291
@@ -290,7 +299,7 @@ def _visit_trailing_else_wrapper(self, tew: TrailingElseWrapper, q: RpcReceiveQu
290299 return replace_if_changed (tew , statement = statement , else_block = else_block )
291300
292301 def _visit_comprehension_expression (self , ce : ComprehensionExpression , q : RpcReceiveQueue ) -> ComprehensionExpression :
293- kind = q .receive (ce .kind )
302+ kind = _to_enum ( ComprehensionExpression . Kind )( q .receive (ce .kind ) )
294303 result = q .receive (ce .result )
295304 clauses = q .receive_list (ce .clauses )
296305 suffix = q .receive (ce .suffix )
@@ -325,7 +334,7 @@ def _visit_union_type(self, ut: UnionType, q: RpcReceiveQueue) -> UnionType:
325334 return replace_if_changed (ut , types = types , type = type_ )
326335
327336 def _visit_variable_scope (self , vs : VariableScope , q : RpcReceiveQueue ) -> VariableScope :
328- kind = q .receive (vs .kind )
337+ kind = _to_enum ( VariableScope . Kind )( q .receive (vs .kind ) )
329338 names = q .receive_list (vs .padding .names )
330339 return replace_if_changed (vs , kind = kind , names = names )
331340
@@ -334,13 +343,13 @@ def _visit_del(self, del_: Del, q: RpcReceiveQueue) -> Del:
334343 return replace_if_changed (del_ , targets = targets )
335344
336345 def _visit_special_parameter (self , sp : SpecialParameter , q : RpcReceiveQueue ) -> SpecialParameter :
337- kind = q .receive (sp .kind )
346+ kind = _to_enum ( SpecialParameter . Kind )( q .receive (sp .kind ) )
338347 type_hint = q .receive (sp .type_hint )
339348 type_ = q .receive (sp .type )
340349 return replace_if_changed (sp , kind = kind , type_hint = type_hint , type = type_ )
341350
342351 def _visit_star (self , star : Star , q : RpcReceiveQueue ) -> Star :
343- kind = q .receive (star .kind )
352+ kind = _to_enum ( Star . Kind )( q .receive (star .kind ) )
344353 expression = q .receive (star .expression )
345354 type_ = q .receive (star .type )
346355 return replace_if_changed (star , kind = kind , expression = expression , type = type_ )
@@ -370,7 +379,7 @@ def _visit_match_case(self, mc: MatchCase, q: RpcReceiveQueue) -> MatchCase:
370379 return replace_if_changed (mc , pattern = pattern , guard = guard , type = type_ )
371380
372381 def _visit_match_case_pattern (self , p : MatchCase .Pattern , q : RpcReceiveQueue ) -> MatchCase .Pattern :
373- kind = q .receive (p .kind )
382+ kind = _to_enum ( MatchCase . Pattern . Kind )( q .receive (p .kind ) )
374383 children = q .receive (p .padding .children )
375384 type_ = q .receive (p .type )
376385 return replace_if_changed (p , kind = kind , children = children , type = type_ )
@@ -533,14 +542,16 @@ def _visit_block(self, block, q: RpcReceiveQueue):
533542 return replace_if_changed (block , static = static , statements = statements , end = end )
534543
535544 def _visit_j_unary (self , unary , q : RpcReceiveQueue ):
536- operator = q .receive (unary .padding .operator )
545+ from rewrite .java .tree import Unary
546+ operator = q .receive (unary .padding .operator , lambda lp : self ._receive_left_padded (lp , q , _to_enum (Unary .Type )))
537547 expression = q .receive (unary .expression )
538548 type_ = q .receive (unary .type )
539549 return replace_if_changed (unary , operator = operator , expression = expression , type = type_ )
540550
541551 def _visit_j_binary (self , binary , q : RpcReceiveQueue ):
552+ from rewrite .java .tree import Binary as JBinary
542553 left = q .receive (binary .left )
543- operator = q .receive (binary .padding .operator )
554+ operator = q .receive (binary .padding .operator , lambda lp : self . _receive_left_padded ( lp , q , _to_enum ( JBinary . Type )) )
544555 right = q .receive (binary .right )
545556 type_ = q .receive (binary .type )
546557 return replace_if_changed (binary , left = left , operator = operator , right = right , type = type_ )
@@ -552,8 +563,9 @@ def _visit_j_assignment(self, assign, q: RpcReceiveQueue):
552563 return replace_if_changed (assign , variable = variable , assignment = assignment , type = type_ )
553564
554565 def _visit_j_assignment_operation (self , assign , q : RpcReceiveQueue ):
566+ from rewrite .java .tree import AssignmentOperation
555567 variable = q .receive (assign .variable )
556- operator = q .receive (assign .padding .operator )
568+ operator = q .receive (assign .padding .operator , lambda lp : self . _receive_left_padded ( lp , q , _to_enum ( AssignmentOperation . Type )) )
557569 assignment = q .receive (assign .assignment )
558570 type_ = q .receive (assign .type )
559571 return replace_if_changed (assign , variable = variable , operator = operator , assignment = assignment , type = type_ )
@@ -702,9 +714,10 @@ def _visit_j_class_declaration(self, class_decl, q: RpcReceiveQueue):
702714 permits = permits , body = body )
703715
704716 def _visit_j_class_declaration_kind (self , kind , q : RpcReceiveQueue ):
717+ from rewrite .java .tree import ClassDeclaration
705718 # Note: _pre_visit is already called by _visit before this method
706719 annotations = q .receive_list (kind .annotations )
707- type_ = q .receive (kind .type ) # Enum type
720+ type_ = _to_enum ( ClassDeclaration . Kind . Type )( q .receive (kind .type ))
708721 return replace_if_changed (kind , annotations = annotations , type = type_ )
709722
710723 def _visit_j_method_declaration (self , method , q : RpcReceiveQueue ):
@@ -741,7 +754,8 @@ def _visit_j_switch(self, switch, q: RpcReceiveQueue):
741754 return replace_if_changed (switch , selector = selector , cases = cases )
742755
743756 def _visit_j_case (self , case , q : RpcReceiveQueue ):
744- type_ = q .receive (case .type ) # Enum type
757+ from rewrite .java .tree import Case
758+ type_ = _to_enum (Case .Type )(q .receive (case .type ))
745759 case_labels = q .receive (case .padding .case_labels )
746760 statements = q .receive (case .padding .statements )
747761 body = q .receive (case .padding .body , lambda rp : self ._receive_right_padded (rp , q ) if rp else None )
@@ -785,8 +799,9 @@ def _visit_j_control_parentheses(self, parens, q: RpcReceiveQueue):
785799 return replace_if_changed (parens , tree = tree )
786800
787801 def _visit_j_modifier (self , mod , q : RpcReceiveQueue ):
802+ from rewrite .java .tree import Modifier
788803 keyword = q .receive (mod .keyword )
789- type_ = q .receive (mod .type ) # Enum type
804+ type_ = _to_enum ( Modifier . Type )( q .receive (mod .type ))
790805 annotations = q .receive_list (mod .annotations )
791806 return replace_if_changed (mod , keyword = keyword , type = type_ , annotations = annotations )
792807
@@ -869,14 +884,16 @@ def _receive_right_padded(self, rp: JRightPadded, q: RpcReceiveQueue) -> Optiona
869884
870885 return JRightPadded (element , after , markers )
871886
872- def _receive_left_padded (self , lp : JLeftPadded , q : RpcReceiveQueue ) -> Optional [JLeftPadded ]:
887+ def _receive_left_padded (self , lp : JLeftPadded , q : RpcReceiveQueue , element_mapping : Optional [ Callable [[ Any ], Any ]] = None ) -> Optional [JLeftPadded ]:
873888 """Receive a JLeftPadded wrapper."""
874889 if lp is None :
875890 return None
876891
877892 # Codec registry handles type dispatch automatically
878893 before = q .receive_defined (lp .before )
879894 element = q .receive (lp .element )
895+ if element_mapping is not None :
896+ element = element_mapping (element )
880897 markers = q .receive_markers (lp .markers )
881898
882899 if before is lp .before and element is lp .element and markers is lp .markers :
@@ -966,7 +983,7 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
966983 # Class: flagsBitMap, kind, fullyQualifiedName, typeParameters, supertype,
967984 # owningClass, annotations, interfaces, members, methods
968985 flags = q .receive_defined (getattr (java_type , '_flags_bit_map' , 0 ))
969- kind = q .receive (getattr (java_type , '_kind' , JT .FullyQualified .Kind .Class ))
986+ kind = _to_enum ( JT . FullyQualified . Kind )( q .receive (getattr (java_type , '_kind' , JT .FullyQualified .Kind .Class ) ))
970987 fqn = q .receive_defined (getattr (java_type , '_fully_qualified_name' , '' ))
971988 type_params = q .receive_list (getattr (java_type , '_type_parameters' , None ) or [],
972989 lambda t : self ._receive_type (t , q ))
@@ -1279,7 +1296,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
12791296 # flagsBitMap, kind, fullyQualifiedName, typeParameters, supertype,
12801297 # owningClass, annotations, interfaces, members, methods
12811298 flags = q .receive_defined (getattr (cls , '_flags_bit_map' , 0 ) if cls else 0 )
1282- kind = q .receive (getattr (cls , '_kind' , JT .FullyQualified .Kind .Class ) if cls else JT .FullyQualified .Kind .Class )
1299+ kind = _to_enum ( JT . FullyQualified . Kind )( q .receive (getattr (cls , '_kind' , JT .FullyQualified .Kind .Class ) if cls else JT .FullyQualified .Kind .Class ) )
12831300 fqn = q .receive_defined (getattr (cls , '_fully_qualified_name' , '' ) if cls else '' )
12841301 type_params = q .receive_list (getattr (cls , '_type_parameters' , None ) if cls else None )
12851302 supertype = q .receive (getattr (cls , '_supertype' , None ) if cls else None )
@@ -1308,8 +1325,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
13081325def _register_java_type_codecs ():
13091326 """Register codecs for JavaType classes."""
13101327 from rewrite .java .support_types import JavaType as JT
1311- from rewrite .rpc .receive_queue import register_codec_with_both_names , register_receive_codec , register_send_codec
1312- from rewrite .rpc .send_queue import RpcSendQueue
1328+ from rewrite .rpc .receive_queue import register_codec_with_both_names
13131329
13141330 # JavaType.Primitive - special handling to consume keyword
13151331 register_codec_with_both_names (
@@ -1572,7 +1588,7 @@ def _send_keyword_only_arguments(marker: KeywordOnlyArguments, q: RpcSendQueue)
15721588 # Quoted - has id and style (enum)
15731589 def _receive_quoted (marker : Quoted , q : RpcReceiveQueue ) -> Quoted :
15741590 new_id = q .receive_defined (marker .id )
1575- new_style = q .receive_defined (marker .style )
1591+ new_style = _to_enum ( Quoted . Style )( q .receive_defined (marker .style ) )
15761592 if new_id is marker .id and new_style is marker .style :
15771593 return marker
15781594 result = marker
0 commit comments