Skip to content

Commit 23ceeab

Browse files
committed
Python: Add Union/Intersection types, fix ref tracking, and address review items
- Add JavaType.Union and JavaType.Intersection with full sender/receiver/codec support (maps to JavaType$MultiCatch and JavaType$Intersection over RPC)
- Fix _send_as_ref to transfer ref tracking from before to after on CHANGE path
- Make _type_cache an instance variable to prevent cross-batch contamination
- Remove mkdir(parents=True) from _ensure_file_on_disk for safer path handling
- Reuse shutdown() in TyTypesClient.initialize for proper process lifecycle
- Add fully_qualified_name property to JavaType.Class and Parameterized
- Remove uv.lock from Git tracking
1 parent d2f5514 commit 23ceeab

10 files changed

Lines changed: 147 additions & 1043 deletions

File tree

rewrite-python/rewrite/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.idea/
22
.tox/
3+
uv.lock

rewrite-python/rewrite/src/rewrite/java/support_types.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class Unknown(FullyQualified):
203203
class Class(FullyQualified):
204204
_flags_bit_map: int
205205
_fully_qualified_name: str
206-
_kind: FullyQualified.Kind
206+
_kind: JavaType.FullyQualified.Kind
207207
_type_parameters: Optional[List[JavaType]]
208208
_supertype: Optional[JavaType.FullyQualified]
209209
_owning_class: Optional[JavaType.FullyQualified]
@@ -212,6 +212,10 @@ class Class(FullyQualified):
212212
_members: Optional[List[JavaType.Variable]]
213213
_methods: Optional[List[JavaType.Method]]
214214

215+
@property
216+
def fully_qualified_name(self) -> str:
217+
return self._fully_qualified_name
218+
215219
class ShallowClass(Class):
216220
pass
217221

@@ -228,10 +232,10 @@ def type_parameters(self) -> Optional[List[JavaType]]:
228232
return self._type_parameters
229233

230234
@property
231-
def _fully_qualified_name(self) -> str:
235+
def fully_qualified_name(self) -> str:
232236
t = getattr(self, '_type', None)
233-
if t is not None and hasattr(t, '_fully_qualified_name'):
234-
return t._fully_qualified_name
237+
if t is not None and hasattr(t, 'fully_qualified_name'):
238+
return t.fully_qualified_name
235239
return ''
236240

237241
class GenericTypeVariable:
@@ -240,6 +244,24 @@ class Variance(Enum):
240244
Covariant = 1
241245
Contravariant = 2
242246

247+
@dataclass
248+
class Union:
249+
"""Union type (e.g. str | int). Maps to JavaType$MultiCatch over RPC."""
250+
_bounds: Optional[List[JavaType]] = field(default=None)
251+
252+
@property
253+
def bounds(self) -> List[JavaType]:
254+
return self._bounds if self._bounds is not None else []
255+
256+
@dataclass
257+
class Intersection:
258+
"""Intersection type (e.g. A & B). Maps to JavaType$Intersection over RPC."""
259+
_bounds: Optional[List[JavaType]] = field(default=None)
260+
261+
@property
262+
def bounds(self) -> List[JavaType]:
263+
return self._bounds if self._bounds is not None else []
264+
243265
class Primitive(Enum):
244266
Boolean = 0
245267
Byte = 1

rewrite-python/rewrite/src/rewrite/java/support_types.pyi

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class JavaType(ABC):
7777
_members: Optional[List[JavaType.Variable]]
7878
_methods: Optional[List[JavaType.Method]]
7979

80+
@property
81+
def fully_qualified_name(self) -> str: ...
8082

8183
class ShallowClass(Class):
8284
pass
@@ -85,12 +87,15 @@ class JavaType(ABC):
8587
_type: JavaType.FullyQualified
8688
_type_parameters: Optional[List[JavaType]]
8789

90+
@property
91+
def fully_qualified_name(self) -> str: ...
92+
8893
@property
8994
def type(self) -> JavaType.FullyQualified: ...
9095
@property
9196
def type_parameters(self) -> Optional[List[JavaType]]: ...
9297
@property
93-
def _fully_qualified_name(self) -> str: ...
98+
def fully_qualified_name(self) -> str: ...
9499

95100

96101
class GenericTypeVariable:
@@ -99,7 +104,17 @@ class JavaType(ABC):
99104
Covariant: Variance
100105
Contravariant: Variance
101106

107+
class Union:
108+
_bounds: Optional[List[JavaType]]
102109

110+
@property
111+
def bounds(self) -> List[JavaType]: ...
112+
113+
class Intersection:
114+
_bounds: Optional[List[JavaType]]
115+
116+
@property
117+
def bounds(self) -> List[JavaType]: ...
103118

104119
class Primitive(Enum):
105120
Boolean: Primitive

rewrite-python/rewrite/src/rewrite/python/method_matcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,6 @@ def _get_fqn(type_obj) -> Optional[str]:
419419
if type_obj is None:
420420
return None
421421

422-
if hasattr(type_obj, "_fully_qualified_name"):
423-
return type_obj._fully_qualified_name
424422
if hasattr(type_obj, "fully_qualified_name"):
425423
return type_obj.fully_qualified_name
426424

rewrite-python/rewrite/src/rewrite/python/ty_client.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,7 @@ def initialize(self, project_root: str) -> bool:
156156
return True
157157

158158
if self._initialized:
159-
self._send_request("shutdown")
160-
self._initialized = False
161-
self._project_root = None
162-
if self._process is not None:
163-
try:
164-
self._process.wait(timeout=5)
165-
except subprocess.TimeoutExpired:
166-
self._process.kill()
159+
self.shutdown()
167160
self._start_process()
168161

169162
result = self._send_request("initialize", {"projectRoot": project_root})

rewrite-python/rewrite/src/rewrite/python/type_mapping.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ class PythonTypeMapping:
7474
method_type = mapping.method_invocation_type(call_node)
7575
"""
7676

77-
# Cache for type mappings to avoid repeated class type creation
78-
_type_cache: Dict[str, JavaType] = {}
79-
8077
def __init__(self, source: str, file_path: Optional[str] = None, ty_client=None):
8178
"""Initialize type mapping for a source file.
8279
@@ -91,6 +88,7 @@ def __init__(self, source: str, file_path: Optional[str] = None, ty_client=None)
9188
self._file_path = file_path
9289
self._source_lines = source.splitlines()
9390
self._temp_file: Optional[Path] = None
91+
self._type_cache: Dict[str, JavaType] = {} # FQN -> JavaType (per-instance)
9492

9593
# Compute line byte offsets for position conversion
9694
self._line_byte_offsets = self._compute_line_byte_offsets(source)
@@ -137,16 +135,18 @@ def _ensure_file_on_disk(self, source: str, file_path: Optional[str]) -> Optiona
137135
"""Ensure the source is available as a file on disk for ty-types.
138136
139137
Returns the absolute file path, or None if unavailable.
138+
When file_path is given but doesn't exist, writes source there.
139+
Callers are responsible for providing safe paths (e.g. within a temp directory).
140140
"""
141141
if file_path:
142142
path = Path(file_path)
143143
if not path.is_absolute():
144144
path = path.resolve()
145145
if path.exists():
146146
return str(path)
147-
# File path given but doesn't exist — write source there
147+
# File path given but doesn't exist — write source there.
148+
# The parent directory must already exist (caller should ensure this).
148149
try:
149-
path.parent.mkdir(parents=True, exist_ok=True)
150150
path.write_text(source, encoding='utf-8')
151151
self._temp_file = path
152152
return str(path)
@@ -177,7 +177,7 @@ def _compute_line_byte_offsets(source: str) -> List[int]:
177177
def _pos_to_byte_offset(self, lineno: int, col_offset: int) -> int:
178178
"""Convert AST (lineno, col_offset) to an absolute byte offset.
179179
180-
Python's ast uses 1-based lineno and character-based col_offset.
180+
Python's ast (3.8+) uses 1-based lineno and 0-based col_offset; col_offset counts UTF-8 bytes within the line, not characters.
181181
ty-types uses absolute byte offsets (ruff convention).
182182
"""
183183
line_start = self._line_byte_offsets[lineno - 1]
@@ -276,7 +276,7 @@ def _resolve_type(self, type_id: int) -> Optional[JavaType]:
276276
if type_id in self._cycle_placeholders:
277277
placeholder = self._cycle_placeholders.pop(type_id)
278278
if isinstance(result, JavaType.Class):
279-
placeholder._fully_qualified_name = result._fully_qualified_name
279+
placeholder._fully_qualified_name = result.fully_qualified_name
280280
placeholder._kind = result._kind
281281
# Copy enriched fields so cycle placeholders retain supertypes/methods
282282
for attr in ('_supertype', '_methods', '_type_parameters', '_interfaces',
@@ -285,8 +285,8 @@ def _resolve_type(self, type_id: int) -> Optional[JavaType]:
285285
if val is not None:
286286
setattr(placeholder, attr, val)
287287
elif isinstance(result, JavaType.Parameterized):
288-
if hasattr(result._type, '_fully_qualified_name'):
289-
placeholder._fully_qualified_name = result._type._fully_qualified_name
288+
if hasattr(result._type, 'fully_qualified_name'):
289+
placeholder._fully_qualified_name = result._type.fully_qualified_name
290290
self._type_id_cache[type_id] = placeholder
291291
return placeholder
292292

@@ -351,16 +351,24 @@ def _descriptor_to_java_type(self, descriptor: Dict[str, Any]) -> Optional[JavaT
351351
return JavaType.Primitive.String
352352

353353
elif kind == 'union':
354-
# Unwrap union: take first non-None type
354+
# Resolve all non-None members into a Union type.
355+
# For Optional[X] (= X | None) with a single real member, unwrap to just X.
356+
resolved_bounds = []
355357
for member_id in descriptor.get('members', []):
356358
member = self._type_registry.get(member_id)
357359
if member:
358360
member_kind = member.get('kind')
359361
# Skip None/NoneType members
360362
if member_kind == 'instance' and member.get('className') in ('None', 'NoneType'):
361363
continue
362-
return self._resolve_type(member_id)
363-
return _UNKNOWN
364+
resolved = self._resolve_type(member_id)
365+
if resolved is not None:
366+
resolved_bounds.append(resolved)
367+
if not resolved_bounds:
368+
return _UNKNOWN
369+
if len(resolved_bounds) == 1:
370+
return resolved_bounds[0]
371+
return JavaType.Union(_bounds=resolved_bounds)
364372

365373
elif kind == 'module':
366374
module_name = descriptor.get('moduleName', '')
@@ -691,7 +699,12 @@ def _get_declaring_type(self, node: ast.Call) -> Optional[JavaType.FullyQualifie
691699
return self._infer_declaring_type_from_ast(node)
692700

693701
def _resolve_declaring_type(self, type_id: int) -> Optional[JavaType.FullyQualified]:
694-
"""Resolve a type ID to a declaring type, maximizing object reuse."""
702+
"""Resolve a type ID to a declaring type, maximizing object reuse.
703+
704+
NOTE: The cycle-detection pattern here mirrors _resolve_type intentionally.
705+
They use separate caches and placeholder dicts because declaring types are
706+
resolved independently (often to a simpler Class without methods/members).
707+
"""
695708
if type_id in self._declaring_type_id_cache:
696709
return self._declaring_type_id_cache[type_id]
697710

@@ -720,7 +733,7 @@ def _resolve_declaring_type(self, type_id: int) -> Optional[JavaType.FullyQualif
720733
if type_id in self._declaring_cycle_placeholders:
721734
placeholder = self._declaring_cycle_placeholders.pop(type_id)
722735
if isinstance(result, JavaType.Class):
723-
placeholder._fully_qualified_name = result._fully_qualified_name
736+
placeholder._fully_qualified_name = result.fully_qualified_name
724737
placeholder._kind = result._kind
725738
self._declaring_type_id_cache[type_id] = placeholder
726739
return placeholder

rewrite-python/rewrite/src/rewrite/rpc/python_receiver.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,10 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
921921
922922
This matches the sender's _visit_type which sends expanded type fields.
923923
The callback pattern ensures message counts match between sender and receiver.
924+
925+
IMPORTANT: isinstance ordering matters here. Parameterized is checked
926+
before Class; the two are currently separate types. If inheritance changes, review
927+
this ordering. Mirrors the sender's _visit_type chain.
924928
"""
925929
from rewrite.java.support_types import JavaType as JT
926930

@@ -994,7 +998,7 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
994998
# owningClass, annotations, interfaces, members, methods
995999
flags = q.receive_defined(getattr(java_type, '_flags_bit_map', 0))
9961000
kind = _to_enum(JT.FullyQualified.Kind)(q.receive(getattr(java_type, '_kind', JT.FullyQualified.Kind.Class)))
997-
fqn = q.receive_defined(getattr(java_type, '_fully_qualified_name', ''))
1001+
fqn = q.receive_defined(getattr(java_type, 'fully_qualified_name', ''))
9981002
type_params = q.receive_list(getattr(java_type, '_type_parameters', None) or [],
9991003
lambda t: self._receive_type(t, q))
10001004
supertype = q.receive(getattr(java_type, '_supertype', None),
@@ -1048,6 +1052,18 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
10481052
var._annotations = annotations
10491053
return var
10501054

1055+
elif isinstance(java_type, JT.Union):
1056+
# Union (MultiCatch in Java): bounds list
1057+
bounds = q.receive_list(getattr(java_type, '_bounds', None) or [],
1058+
lambda t: self._receive_type(t, q))
1059+
return JT.Union(_bounds=bounds)
1060+
1061+
elif isinstance(java_type, JT.Intersection):
1062+
# Intersection: bounds list
1063+
bounds = q.receive_list(getattr(java_type, '_bounds', None) or [],
1064+
lambda t: self._receive_type(t, q))
1065+
return JT.Intersection(_bounds=bounds)
1066+
10511067
elif isinstance(java_type, JT.Unknown):
10521068
# Unknown has no additional fields
10531069
return java_type
@@ -1361,7 +1377,7 @@ def _receive_java_type_class(cls, q: RpcReceiveQueue):
13611377
# owningClass, annotations, interfaces, members, methods
13621378
flags = q.receive_defined(getattr(cls, '_flags_bit_map', 0) if cls else 0)
13631379
kind = _to_enum(JT.FullyQualified.Kind)(q.receive(getattr(cls, '_kind', JT.FullyQualified.Kind.Class) if cls else JT.FullyQualified.Kind.Class))
1364-
fqn = q.receive_defined(getattr(cls, '_fully_qualified_name', '') if cls else '')
1380+
fqn = q.receive_defined(getattr(cls, 'fully_qualified_name', '') if cls else '')
13651381
type_params = q.receive_list(getattr(cls, '_type_parameters', None) if cls else None)
13661382
supertype = q.receive(getattr(cls, '_supertype', None) if cls else None)
13671383
owning_class = q.receive(getattr(cls, '_owning_class', None) if cls else None)
@@ -1431,6 +1447,22 @@ def _receive_java_type_variable(variable, q: RpcReceiveQueue):
14311447
return var
14321448

14331449

1450+
def _receive_java_type_union(union, q: RpcReceiveQueue):
1451+
"""Codec for receiving JavaType.Union (MultiCatch in Java) - consumes bounds list."""
1452+
from rewrite.java.support_types import JavaType as JT
1453+
1454+
bounds = q.receive_list(union._bounds)
1455+
return JT.Union(_bounds=bounds)
1456+
1457+
1458+
def _receive_java_type_intersection(intersection, q: RpcReceiveQueue):
1459+
"""Codec for receiving JavaType.Intersection - consumes bounds list."""
1460+
from rewrite.java.support_types import JavaType as JT
1461+
1462+
bounds = q.receive_list(intersection._bounds)
1463+
return JT.Intersection(_bounds=bounds)
1464+
1465+
14341466
def _register_java_type_codecs():
14351467
"""Register codecs for JavaType classes."""
14361468
from rewrite.java.support_types import JavaType as JT
@@ -1493,6 +1525,22 @@ def _register_java_type_codecs():
14931525
lambda: JT.Array() # Factory creates empty Array
14941526
)
14951527

1528+
# JavaType.Union (MultiCatch in Java) - bounds list
1529+
register_codec_with_both_names(
1530+
'org.openrewrite.java.tree.JavaType$MultiCatch',
1531+
JT.Union,
1532+
_receive_java_type_union,
1533+
lambda: JT.Union() # Factory creates empty Union
1534+
)
1535+
1536+
# JavaType.Intersection - bounds list
1537+
register_codec_with_both_names(
1538+
'org.openrewrite.java.tree.JavaType$Intersection',
1539+
JT.Intersection,
1540+
_receive_java_type_intersection,
1541+
lambda: JT.Intersection() # Factory creates empty Intersection
1542+
)
1543+
14961544

14971545
def _register_tree_codecs():
14981546
"""Register codecs for all AST types using reflection."""

rewrite-python/rewrite/src/rewrite/rpc/python_sender.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def _visit_type(self, java_type, q: 'RpcSendQueue') -> None:
875875
# owningClass, annotations, interfaces, members, methods
876876
q.get_and_send(java_type, lambda x: getattr(x, '_flags_bit_map', 0))
877877
q.get_and_send(java_type, lambda x: getattr(x, '_kind', JT.FullyQualified.Kind.Class))
878-
q.get_and_send(java_type, lambda x: getattr(x, '_fully_qualified_name', ''))
878+
q.get_and_send(java_type, lambda x: getattr(x, 'fully_qualified_name', ''))
879879
q.get_and_send_list_as_ref(java_type, lambda x: getattr(x, '_type_parameters', None) or [], self._type_signature, lambda t: self._visit_type(t, q))
880880
q.get_and_send_as_ref(java_type, lambda x: getattr(x, '_supertype', None), lambda t: self._visit_type(t, q))
881881
q.get_and_send_as_ref(java_type, lambda x: getattr(x, '_owning_class', None), lambda t: self._visit_type(t, q))
@@ -896,6 +896,14 @@ def _visit_type(self, java_type, q: 'RpcSendQueue') -> None:
896896
q.get_and_send_as_ref(java_type, lambda x: x._type, lambda t: self._visit_type(t, q))
897897
q.get_and_send_list_as_ref(java_type, lambda x: x._annotations or [], self._type_signature, lambda t: self._visit_type(t, q))
898898

899+
elif isinstance(java_type, JT.Union):
900+
# Union (MultiCatch in Java): bounds list
901+
q.get_and_send_list_as_ref(java_type, lambda x: x.bounds, self._type_signature, lambda t: self._visit_type(t, q))
902+
903+
elif isinstance(java_type, JT.Intersection):
904+
# Intersection: bounds list
905+
q.get_and_send_list_as_ref(java_type, lambda x: x.bounds, self._type_signature, lambda t: self._visit_type(t, q))
906+
899907
elif isinstance(java_type, JT.Unknown):
900908
# Unknown has no additional fields
901909
pass
@@ -909,7 +917,7 @@ def _type_signature(self, java_type) -> str:
909917
if isinstance(java_type, JT.Primitive):
910918
return java_type.name
911919
if isinstance(java_type, JT.Class):
912-
return getattr(java_type, '_fully_qualified_name', str(id(java_type)))
920+
return getattr(java_type, 'fully_qualified_name', str(id(java_type)))
913921
if isinstance(java_type, JT.Method):
914922
declaring = getattr(java_type, '_declaring_type', None)
915923
declaring_name = self._type_signature(declaring) if declaring else ''
@@ -924,6 +932,10 @@ def _type_signature(self, java_type) -> str:
924932
if isinstance(java_type, JT.Array):
925933
elem_sig = self._type_signature(java_type._elem_type) if java_type._elem_type else ''
926934
return f"{elem_sig}[]"
935+
if isinstance(java_type, JT.Union):
936+
return '|'.join(self._type_signature(b) for b in java_type.bounds)
937+
if isinstance(java_type, JT.Intersection):
938+
return '&'.join(self._type_signature(b) for b in java_type.bounds)
927939
return str(id(java_type))
928940

929941
def _visit_space(self, space: Space, q: 'RpcSendQueue') -> None:

0 commit comments

Comments
 (0)