Skip to content

Commit 0880a2a

Browse files
committed
Python: Map typeVar descriptors to GenericTypeVariable with variance and upper bounds
1 parent 2cf6fd5 commit 0880a2a

8 files changed

Lines changed: 234 additions & 2 deletions

File tree

rewrite-python/rewrite/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ requires-python = ">=3.10"
2424
dependencies = [
2525
"cbor2>=5.6.5",
2626
"more_itertools>=10.0.0",
27-
"ty-types>=0.0.19.dev20260223122104", # Type inference CLI for Python type attribution
27+
"ty-types>=0.0.19.dev20260224073439", # Type inference CLI for Python type attribution
2828
"parso>=0.7.1,<0.8", # Python 2/3 parser with CST support (0.8+ dropped Python 2.7 grammar)
2929
]
3030

rewrite-python/rewrite/src/rewrite/java/support_types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,29 @@ def fully_qualified_name(self) -> str:
238238
return t.fully_qualified_name
239239
return ''
240240

241+
@dataclass
241242
class GenericTypeVariable:
243+
_name: str = field(default="")
244+
_variance: GenericTypeVariable.Variance = field(default=None)
245+
_bounds: Optional[List[JavaType]] = field(default=None)
246+
242247
class Variance(Enum):
243248
Invariant = 0
244249
Covariant = 1
245250
Contravariant = 2
246251

252+
@property
253+
def name(self) -> str:
254+
return self._name
255+
256+
@property
257+
def variance(self) -> GenericTypeVariable.Variance:
258+
return self._variance
259+
260+
@property
261+
def bounds(self) -> List[JavaType]:
262+
return self._bounds if self._bounds is not None else []
263+
247264
@dataclass
248265
class Union:
249266
"""Union type (e.g. str | int). Maps to JavaType$MultiCatch over RPC."""

rewrite-python/rewrite/src/rewrite/java/support_types.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,24 @@ class JavaType(ABC):
9898
def fully_qualified_name(self) -> str: ...
9999

100100

101+
@dataclass
101102
class GenericTypeVariable:
103+
_name: str
104+
_variance: GenericTypeVariable.Variance
105+
_bounds: Optional[List[JavaType]]
106+
102107
class Variance(Enum):
103108
Invariant: Variance
104109
Covariant: Variance
105110
Contravariant: Variance
106111

112+
@property
113+
def name(self) -> str: ...
114+
@property
115+
def variance(self) -> GenericTypeVariable.Variance: ...
116+
@property
117+
def bounds(self) -> List[JavaType]: ...
118+
107119
class Union:
108120
_bounds: Optional[List[JavaType]]
109121

rewrite-python/rewrite/src/rewrite/python/type_mapping.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,21 @@ def _descriptor_to_java_type(self, descriptor: Dict[str, Any]) -> Optional[JavaT
551551

552552
elif kind == 'typeVar':
553553
name = descriptor.get('name', '')
554-
return self._create_class_type(name) if name else _UNKNOWN
554+
if not name:
555+
return _UNKNOWN
556+
variance_str = descriptor.get('variance', 'invariant')
557+
variance_map = {
558+
'covariant': JavaType.GenericTypeVariable.Variance.Covariant,
559+
'contravariant': JavaType.GenericTypeVariable.Variance.Contravariant,
560+
}
561+
variance = variance_map.get(variance_str, JavaType.GenericTypeVariable.Variance.Invariant)
562+
bounds = None
563+
upper_bound_id = descriptor.get('upperBound')
564+
if upper_bound_id is not None:
565+
bound_type = self._resolve_type(upper_bound_id)
566+
if bound_type is not None:
567+
bounds = [bound_type]
568+
return JavaType.GenericTypeVariable(_name=name, _variance=variance, _bounds=bounds)
555569

556570
else:
557571
return _UNKNOWN

rewrite-python/rewrite/src/rewrite/rpc/python_receiver.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,25 @@ def mapper(value: Any) -> E:
4949
return mapper
5050

5151

52+
_JAVA_VARIANCE_MAP = {
53+
'COVARIANT': 'Covariant',
54+
'CONTRAVARIANT': 'Contravariant',
55+
'INVARIANT': 'Invariant',
56+
}
57+
58+
59+
def _to_variance(value: Any, JT) -> Any:
60+
"""Convert a Java wire-format variance string to a Python Variance enum.
61+
62+
Handles both the ADD/CHANGE path (receives a Java UPPER_CASE string) and the
63+
NO_CHANGE path (receives the existing Python Variance enum unchanged).
64+
"""
65+
if isinstance(value, JT.GenericTypeVariable.Variance):
66+
return value # NO_CHANGE path — already a Variance enum
67+
mapped = _JAVA_VARIANCE_MAP.get(str(value), 'Invariant')
68+
return JT.GenericTypeVariable.Variance[mapped]
69+
70+
5271
class PythonRpcReceiver:
5372
"""Receiver that mirrors Java's PythonReceiver for RPC deserialization."""
5473

@@ -1052,6 +1071,15 @@ def _receive_type(self, java_type, q: RpcReceiveQueue):
10521071
var._annotations = annotations
10531072
return var
10541073

1074+
elif isinstance(java_type, JT.GenericTypeVariable):
1075+
# GenericTypeVariable: name, variance, bounds
1076+
name = q.receive(getattr(java_type, '_name', ''))
1077+
variance_raw = q.receive(getattr(java_type, '_variance', JT.GenericTypeVariable.Variance.Invariant))
1078+
variance = _to_variance(variance_raw, JT)
1079+
bounds = q.receive_list(getattr(java_type, '_bounds', None) or [],
1080+
lambda t: self._receive_type(t, q))
1081+
return JT.GenericTypeVariable(_name=name, _variance=variance, _bounds=bounds)
1082+
10551083
elif isinstance(java_type, JT.Union):
10561084
# Union (MultiCatch in Java): bounds list
10571085
bounds = q.receive_list(getattr(java_type, '_bounds', None) or [],
@@ -1447,6 +1475,17 @@ def _receive_java_type_variable(variable, q: RpcReceiveQueue):
14471475
return var
14481476

14491477

1478+
def _receive_java_type_generic_type_variable(gtv, q: RpcReceiveQueue):
1479+
"""Codec for receiving JavaType.GenericTypeVariable - consumes name, variance, bounds."""
1480+
from rewrite.java.support_types import JavaType as JT
1481+
1482+
name = q.receive(gtv._name)
1483+
variance_raw = q.receive(gtv._variance)
1484+
variance = _to_variance(variance_raw, JT)
1485+
bounds = q.receive_list(gtv._bounds or [])
1486+
return JT.GenericTypeVariable(_name=name, _variance=variance, _bounds=bounds)
1487+
1488+
14501489
def _receive_java_type_union(union, q: RpcReceiveQueue):
14511490
"""Codec for receiving JavaType.Union (MultiCatch in Java) - consumes bounds list."""
14521491
from rewrite.java.support_types import JavaType as JT
@@ -1525,6 +1564,14 @@ def _register_java_type_codecs():
15251564
lambda: JT.Array() # Factory creates empty Array
15261565
)
15271566

1567+
# JavaType.GenericTypeVariable - name, variance, bounds
1568+
register_codec_with_both_names(
1569+
'org.openrewrite.java.tree.JavaType$GenericTypeVariable',
1570+
JT.GenericTypeVariable,
1571+
_receive_java_type_generic_type_variable,
1572+
lambda: JT.GenericTypeVariable() # Factory creates empty GenericTypeVariable
1573+
)
1574+
15281575
# JavaType.Union (MultiCatch in Java) - bounds list
15291576
register_codec_with_both_names(
15301577
'org.openrewrite.java.tree.JavaType$MultiCatch',

rewrite-python/rewrite/src/rewrite/rpc/python_sender.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,17 @@ def _visit_type(self, java_type, q: 'RpcSendQueue') -> None:
896896
q.get_and_send_as_ref(java_type, lambda x: x._type, lambda t: self._visit_type(t, q))
897897
q.get_and_send_list_as_ref(java_type, lambda x: x._annotations or [], self._type_signature, lambda t: self._visit_type(t, q))
898898

899+
elif isinstance(java_type, JT.GenericTypeVariable):
900+
# GenericTypeVariable: name, variance, bounds
901+
q.get_and_send(java_type, lambda x: x._name)
902+
# Java expects UPPER_SNAKE_CASE variance enum names
903+
_variance_to_java = {
904+
JT.GenericTypeVariable.Variance.Covariant: 'COVARIANT',
905+
JT.GenericTypeVariable.Variance.Contravariant: 'CONTRAVARIANT',
906+
}
907+
q.get_and_send(java_type, lambda x, m=_variance_to_java: m.get(x._variance, 'INVARIANT'))
908+
q.get_and_send_list_as_ref(java_type, lambda x: x.bounds, self._type_signature, lambda t: self._visit_type(t, q))
909+
899910
elif isinstance(java_type, JT.Union):
900911
# Union (MultiCatch in Java): bounds list
901912
q.get_and_send_list_as_ref(java_type, lambda x: x.bounds, self._type_signature, lambda t: self._visit_type(t, q))
@@ -932,6 +943,9 @@ def _type_signature(self, java_type) -> str:
932943
if isinstance(java_type, JT.Array):
933944
elem_sig = self._type_signature(java_type._elem_type) if java_type._elem_type else ''
934945
return f"{elem_sig}[]"
946+
if isinstance(java_type, JT.GenericTypeVariable):
947+
bounds_sig = ' & '.join(self._type_signature(b) for b in java_type.bounds) if java_type.bounds else ''
948+
return f"Generic{{{java_type._name}{' extends ' + bounds_sig if bounds_sig else ''}}}"
935949
if isinstance(java_type, JT.Union):
936950
return '|'.join(self._type_signature(b) for b in java_type.bounds)
937951
if isinstance(java_type, JT.Intersection):

rewrite-python/rewrite/src/rewrite/rpc/send_queue.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def _get_value_type(self, obj: Any) -> Optional[str]:
295295
return 'org.openrewrite.java.tree.JavaType$Parameterized'
296296
if isinstance(obj, JavaType.Class):
297297
return 'org.openrewrite.java.tree.JavaType$Class'
298+
if isinstance(obj, JavaType.GenericTypeVariable):
299+
return 'org.openrewrite.java.tree.JavaType$GenericTypeVariable'
298300
if isinstance(obj, JavaType.Union):
299301
return 'org.openrewrite.java.tree.JavaType$MultiCatch'
300302
if isinstance(obj, JavaType.Intersection):

rewrite-python/rewrite/tests/python/test_type_attribution.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,3 +1339,129 @@ class Multi(Base, Mixin):
13391339
assert any(n.endswith('Mixin') for n in iface_names)
13401340
finally:
13411341
_cleanup_mapping(mapping, tmpdir, client)
1342+
1343+
1344+
class TestGenericTypeVariable:
1345+
"""Tests for typeVar → GenericTypeVariable conversion."""
1346+
1347+
def test_plain_typevar_creates_generic_type_variable(self):
1348+
"""A typeVar descriptor with just a name should create a GenericTypeVariable."""
1349+
mapping = PythonTypeMapping("", file_path=None)
1350+
mapping._type_registry[100] = {
1351+
'kind': 'typeVar',
1352+
'name': 'T',
1353+
'variance': 'invariant',
1354+
}
1355+
1356+
result = mapping._resolve_type(100)
1357+
assert isinstance(result, JavaType.GenericTypeVariable)
1358+
assert result.name == 'T'
1359+
assert result.variance == JavaType.GenericTypeVariable.Variance.Invariant
1360+
assert result.bounds == []
1361+
1362+
def test_covariant_typevar(self):
1363+
"""A typeVar with covariant variance should map correctly."""
1364+
mapping = PythonTypeMapping("", file_path=None)
1365+
mapping._type_registry[100] = {
1366+
'kind': 'typeVar',
1367+
'name': 'T_co',
1368+
'variance': 'covariant',
1369+
}
1370+
1371+
result = mapping._resolve_type(100)
1372+
assert isinstance(result, JavaType.GenericTypeVariable)
1373+
assert result.name == 'T_co'
1374+
assert result.variance == JavaType.GenericTypeVariable.Variance.Covariant
1375+
1376+
def test_contravariant_typevar(self):
1377+
"""A typeVar with contravariant variance should map correctly."""
1378+
mapping = PythonTypeMapping("", file_path=None)
1379+
mapping._type_registry[100] = {
1380+
'kind': 'typeVar',
1381+
'name': 'T_contra',
1382+
'variance': 'contravariant',
1383+
}
1384+
1385+
result = mapping._resolve_type(100)
1386+
assert isinstance(result, JavaType.GenericTypeVariable)
1387+
assert result.name == 'T_contra'
1388+
assert result.variance == JavaType.GenericTypeVariable.Variance.Contravariant
1389+
1390+
def test_typevar_with_upper_bound(self):
1391+
"""A typeVar with an upperBound should have a bounds list."""
1392+
mapping = PythonTypeMapping("", file_path=None)
1393+
mapping._type_registry[100] = {
1394+
'kind': 'typeVar',
1395+
'name': 'T',
1396+
'variance': 'invariant',
1397+
'upperBound': 101,
1398+
}
1399+
mapping._type_registry[101] = {
1400+
'kind': 'instance',
1401+
'className': 'int',
1402+
}
1403+
1404+
result = mapping._resolve_type(100)
1405+
assert isinstance(result, JavaType.GenericTypeVariable)
1406+
assert result.name == 'T'
1407+
assert len(result.bounds) == 1
1408+
assert result.bounds[0] is JavaType.Primitive.Int
1409+
1410+
def test_typevar_with_class_upper_bound(self):
1411+
"""A typeVar bounded by a class should resolve the bound."""
1412+
mapping = PythonTypeMapping("", file_path=None)
1413+
mapping._type_registry[100] = {
1414+
'kind': 'typeVar',
1415+
'name': 'T',
1416+
'variance': 'covariant',
1417+
'upperBound': 101,
1418+
}
1419+
mapping._type_registry[101] = {
1420+
'kind': 'instance',
1421+
'className': 'Comparable',
1422+
'moduleName': 'builtins',
1423+
}
1424+
1425+
result = mapping._resolve_type(100)
1426+
assert isinstance(result, JavaType.GenericTypeVariable)
1427+
assert result.variance == JavaType.GenericTypeVariable.Variance.Covariant
1428+
assert len(result.bounds) == 1
1429+
assert isinstance(result.bounds[0], JavaType.Class)
1430+
assert result.bounds[0].fully_qualified_name == 'Comparable'
1431+
1432+
def test_typevar_without_variance_defaults_to_invariant(self):
1433+
"""A typeVar without explicit variance should default to Invariant."""
1434+
mapping = PythonTypeMapping("", file_path=None)
1435+
mapping._type_registry[100] = {
1436+
'kind': 'typeVar',
1437+
'name': 'T',
1438+
}
1439+
1440+
result = mapping._resolve_type(100)
1441+
assert isinstance(result, JavaType.GenericTypeVariable)
1442+
assert result.name == 'T'
1443+
assert result.variance == JavaType.GenericTypeVariable.Variance.Invariant
1444+
1445+
def test_typevar_without_name_returns_unknown(self):
1446+
"""A typeVar without a name should return Unknown."""
1447+
mapping = PythonTypeMapping("", file_path=None)
1448+
mapping._type_registry[100] = {
1449+
'kind': 'typeVar',
1450+
'name': '',
1451+
}
1452+
1453+
result = mapping._resolve_type(100)
1454+
assert isinstance(result, JavaType.Unknown)
1455+
1456+
def test_typevar_cached_by_type_id(self):
1457+
"""Resolved GenericTypeVariable should be cached by type_id."""
1458+
mapping = PythonTypeMapping("", file_path=None)
1459+
mapping._type_registry[100] = {
1460+
'kind': 'typeVar',
1461+
'name': 'T',
1462+
'variance': 'invariant',
1463+
}
1464+
1465+
result1 = mapping._resolve_type(100)
1466+
result2 = mapping._resolve_type(100)
1467+
assert result1 is result2

0 commit comments

Comments
 (0)