Skip to content

Commit f17ba12

Browse files
Python: bump ty-types to 0.0.31, consume new ParamSpec wire fields (#7431)
ty-types 0.0.31 adds two optional fields to `ParameterInfo`: `concatenatePrefix` (flags leading positionals of a Concatenate prefix) and `paramSpecName` (set on the synthetic *args/**kwargs pair that stands in for a `ParamSpec` tail). When a pair carries `paramSpecName`, collapse it into a single synthetic `JavaType.Method` parameter named after the ParamSpec with type `JavaType.Unknown`, rather than exposing `P.args`/`P.kwargs` as two distinct variadic params. `concatenatePrefix` is accepted but treated as a regular positional today — the wire flag exists so consumers can reconstruct Concatenate forms later if they choose. Adds unit tests for the collapse/pass-through logic on mock descriptors plus CLI-backed integration tests covering `Callable[P, R]`, `Callable[Concatenate[int, P], R]`, and a plain `def add(a, b)` regression.
1 parent fd8e2bf commit f17ba12

3 files changed

Lines changed: 179 additions & 33 deletions

File tree

rewrite-python/rewrite/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ requires-python = ">=3.10"
2424
dependencies = [
2525
"cbor2>=5.6.5",
2626
"more_itertools>=10.0.0",
27-
"ty-types>=0.0.21", # Type inference CLI for Python type attribution
27+
"ty-types>=0.0.31", # Type inference CLI for Python type attribution
2828
"parso>=0.7.1,<0.8", # Python 2/3 parser with CST support (0.8+ dropped Python 2.7 grammar)
2929
]
3030

rewrite-python/rewrite/src/rewrite/python/type_mapping.py

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -771,14 +771,8 @@ def _method_from_function_descriptor(
771771
self, descriptor: Dict[str, Any], name: str
772772
) -> JavaType.Method:
773773
"""Build a JavaType.Method from a function descriptor with parameters/returnType."""
774-
param_names: List[str] = []
775-
param_types: List[JavaType] = []
776-
for param in descriptor.get('parameters', []):
777-
p_name = param.get('name', '')
778-
if p_name in ('self', 'cls'):
779-
continue
780-
param_names.append(p_name)
781-
param_types.append(self._resolve_param_type(param))
774+
param_names, param_types = self._process_method_params(
775+
descriptor.get('parameters', []))
782776

783777
return_type = None
784778
ret_id = descriptor.get('returnType')
@@ -873,6 +867,40 @@ def _resolve_param_type(self, param: Dict[str, Any]) -> JavaType:
873867
return result
874868
return _UNKNOWN
875869

870+
def _process_method_params(
871+
self, params: List[Dict[str, Any]]
872+
) -> Tuple[List[str], List[JavaType]]:
873+
"""Normalize a ParameterInfo list into (names, types) for JavaType.Method.
874+
875+
Applies:
876+
- Skip `self` / `cls`.
877+
- Collapse the synthetic `*args` / `**kwargs` pair emitted for a
878+
`ParamSpec` tail (both carry the same ``paramSpecName``) into a
879+
single entry whose name is the ParamSpec's name and whose type
880+
is `_UNKNOWN`. This avoids exposing `P.args` / `P.kwargs` as two
881+
distinct variadic parameters on the produced method.
882+
- Treat `concatenatePrefix` params as ordinary positional params.
883+
"""
884+
names: List[str] = []
885+
types: List[JavaType] = []
886+
last_spec_emitted: Optional[str] = None
887+
for p in params:
888+
p_name = p.get('name', '')
889+
if p_name in ('self', 'cls'):
890+
continue
891+
spec_name = p.get('paramSpecName')
892+
if spec_name is not None:
893+
if spec_name == last_spec_emitted:
894+
continue
895+
names.append(spec_name)
896+
types.append(_UNKNOWN)
897+
last_spec_emitted = spec_name
898+
continue
899+
last_spec_emitted = None
900+
names.append(p_name)
901+
types.append(self._resolve_param_type(p))
902+
return names, types
903+
876904
def _get_method_signature(self, node: ast.Call) -> Tuple[List[str], List[JavaType]]:
877905
"""Get parameter names and types from the method signature.
878906
@@ -885,11 +913,7 @@ def _get_method_signature(self, node: ast.Call) -> Tuple[List[str], List[JavaTyp
885913
if sig:
886914
params = sig.get('parameters', [])
887915
if params:
888-
names = [p['name'] for p in params
889-
if p['name'] not in ('self', 'cls')]
890-
types = [self._resolve_param_type(p) for p in params
891-
if p['name'] not in ('self', 'cls')]
892-
return names, types
916+
return self._process_method_params(params)
893917

894918
# Try function/method descriptor parameters
895919
func_type_id = self._lookup_func_type_id(node)
@@ -898,11 +922,7 @@ def _get_method_signature(self, node: ast.Call) -> Tuple[List[str], List[JavaTyp
898922
if descriptor:
899923
params = descriptor.get('parameters', [])
900924
if params:
901-
names = [p['name'] for p in params
902-
if p['name'] not in ('self', 'cls')]
903-
types = [self._resolve_param_type(p) for p in params
904-
if p['name'] not in ('self', 'cls')]
905-
return names, types
925+
return self._process_method_params(params)
906926

907927
# Fall back to placeholder names
908928
return self._generate_placeholder_names(node)
@@ -1205,20 +1225,9 @@ def _create_method_from_descriptor(self, descriptor: Dict[str, Any],
12051225
if return_type_id is not None:
12061226
return_type = self._resolve_type(return_type_id)
12071227

1208-
# Resolve parameters (skip self/cls)
1209-
param_names = []
1210-
param_types = []
1211-
for param in descriptor.get('parameters', []):
1212-
p_name = param.get('name', '')
1213-
if p_name in ('self', 'cls'):
1214-
continue
1215-
param_names.append(p_name)
1216-
p_type_id = param.get('typeId')
1217-
if p_type_id is not None:
1218-
p_type = self._resolve_type(p_type_id)
1219-
param_types.append(p_type if p_type else _UNKNOWN)
1220-
else:
1221-
param_types.append(_UNKNOWN)
1228+
# Resolve parameters (skip self/cls, collapse ParamSpec *args/**kwargs pairs)
1229+
param_names, param_types = self._process_method_params(
1230+
descriptor.get('parameters', []))
12221231

12231232
type_param_names = self._extract_type_param_names(descriptor)
12241233

rewrite-python/rewrite/tests/python/test_type_attribution.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,143 @@ def test_primitive_not_wrapped_in_class(self):
12401240
assert result is JavaType.Primitive.Int
12411241

12421242

1243+
class TestParamSpecAndConcatenate:
1244+
"""Unit tests for the ty-types 0.0.31 `paramSpecName` / `concatenatePrefix` fields.
1245+
1246+
Use mock descriptors so the logic is exercised without the ty-types CLI.
1247+
End-to-end tests against real source live in
1248+
TestParamSpecAndConcatenateIntegration below.
1249+
"""
1250+
1251+
def test_paramspec_pair_collapses_to_single_entry(self):
1252+
"""The synthetic `*args` + `**kwargs` pair is folded into one entry
1253+
whose name is the ParamSpec's name and whose type is Unknown."""
1254+
mapping = PythonTypeMapping("", file_path=None)
1255+
params = [
1256+
{'name': 'args', 'kind': 'variadic', 'paramSpecName': 'P'},
1257+
{'name': 'kwargs', 'kind': 'keywordVariadic', 'paramSpecName': 'P'},
1258+
]
1259+
names, types = mapping._process_method_params(params)
1260+
assert names == ['P']
1261+
assert len(types) == 1
1262+
assert isinstance(types[0], JavaType.Unknown)
1263+
1264+
def test_concatenate_prefix_is_treated_as_positional(self):
1265+
"""A param flagged `concatenatePrefix: True` is emitted as-is;
1266+
the trailing ParamSpec pair still collapses behind it."""
1267+
mapping = PythonTypeMapping("", file_path=None)
1268+
mapping._type_registry[7] = {'kind': 'instance', 'className': 'int'}
1269+
params = [
1270+
{'name': '', 'kind': 'positionalOnly', 'typeId': 7,
1271+
'concatenatePrefix': True},
1272+
{'name': 'args', 'kind': 'variadic', 'paramSpecName': 'P'},
1273+
{'name': 'kwargs', 'kind': 'keywordVariadic', 'paramSpecName': 'P'},
1274+
]
1275+
names, types = mapping._process_method_params(params)
1276+
assert names == ['', 'P']
1277+
assert types[0] is JavaType.Primitive.Int
1278+
assert isinstance(types[1], JavaType.Unknown)
1279+
1280+
def test_plain_params_unchanged(self):
1281+
"""Regression: descriptors without the new fields still produce
1282+
one entry per parameter with the declared type."""
1283+
mapping = PythonTypeMapping("", file_path=None)
1284+
mapping._type_registry[1] = {'kind': 'instance', 'className': 'int'}
1285+
mapping._type_registry[2] = {'kind': 'instance', 'className': 'int'}
1286+
params = [
1287+
{'name': 'a', 'typeId': 1},
1288+
{'name': 'b', 'typeId': 2},
1289+
]
1290+
names, types = mapping._process_method_params(params)
1291+
assert names == ['a', 'b']
1292+
assert types == [JavaType.Primitive.Int, JavaType.Primitive.Int]
1293+
1294+
def test_self_and_cls_still_filtered(self):
1295+
"""Filtering of self/cls still works when new fields are present."""
1296+
mapping = PythonTypeMapping("", file_path=None)
1297+
params = [
1298+
{'name': 'self'},
1299+
{'name': 'args', 'kind': 'variadic', 'paramSpecName': 'P'},
1300+
{'name': 'kwargs', 'kind': 'keywordVariadic', 'paramSpecName': 'P'},
1301+
]
1302+
names, _ = mapping._process_method_params(params)
1303+
assert names == ['P']
1304+
1305+
1306+
@requires_ty_types_cli
1307+
class TestParamSpecAndConcatenateIntegration:
1308+
"""End-to-end tests exercising ty-types 0.0.31 ParamSpec/Concatenate output."""
1309+
1310+
def test_callable_paramspec_collapses_in_invocation(self):
1311+
"""`cb()` where `cb: Callable[P, R]` yields a method type with a
1312+
single collapsed `P` parameter rather than two variadic entries."""
1313+
source = '''from typing import Callable, ParamSpec, TypeVar
1314+
P = ParamSpec('P')
1315+
R = TypeVar('R')
1316+
1317+
def run(cb: Callable[P, R]) -> R:
1318+
return cb()
1319+
1320+
run(lambda: 42)
1321+
'''
1322+
mapping, tree, tmpdir, client = _make_mapping(source)
1323+
try:
1324+
cb_call = tree.body[3].body[0].value # cb() inside run
1325+
result = mapping.method_invocation_type(cb_call)
1326+
assert result is not None
1327+
assert result._parameter_names == ['P']
1328+
assert result._parameter_types is not None
1329+
assert len(result._parameter_types) == 1
1330+
assert isinstance(result._parameter_types[0], JavaType.Unknown)
1331+
finally:
1332+
_cleanup_mapping(mapping, tmpdir, client)
1333+
1334+
def test_concatenate_keeps_prefix_and_collapses_tail(self):
1335+
"""`cb(1)` where `cb: Callable[Concatenate[int, P], R]` yields a
1336+
method type with the leading `int` plus a single collapsed `P`."""
1337+
source = '''from typing import Callable, Concatenate, ParamSpec, TypeVar
1338+
P = ParamSpec('P')
1339+
R = TypeVar('R')
1340+
1341+
def run(cb: Callable[Concatenate[int, P], R]) -> R:
1342+
return cb(1)
1343+
1344+
run(lambda x: x)
1345+
'''
1346+
mapping, tree, tmpdir, client = _make_mapping(source)
1347+
try:
1348+
cb_call = tree.body[3].body[0].value # cb(1) inside run
1349+
result = mapping.method_invocation_type(cb_call)
1350+
assert result is not None
1351+
assert result._parameter_names is not None
1352+
assert len(result._parameter_names) == 2
1353+
assert result._parameter_names[-1] == 'P'
1354+
assert result._parameter_types is not None
1355+
assert result._parameter_types[0] == JavaType.Primitive.Int
1356+
assert isinstance(result._parameter_types[1], JavaType.Unknown)
1357+
finally:
1358+
_cleanup_mapping(mapping, tmpdir, client)
1359+
1360+
def test_plain_function_method_type_unchanged(self):
1361+
"""Regression: a function with no ParamSpec/Concatenate produces the
1362+
same (name, type) pairs it did before the 0.0.31 field additions."""
1363+
source = '''def add(a: int, b: int) -> int:
1364+
return a + b
1365+
1366+
add(1, 2)
1367+
'''
1368+
mapping, tree, tmpdir, client = _make_mapping(source)
1369+
try:
1370+
call = tree.body[1].value
1371+
result = mapping.method_invocation_type(call)
1372+
assert result is not None
1373+
assert result._parameter_names == ['a', 'b']
1374+
assert result._parameter_types == [
1375+
JavaType.Primitive.Int, JavaType.Primitive.Int]
1376+
finally:
1377+
_cleanup_mapping(mapping, tmpdir, client)
1378+
1379+
12431380
@requires_ty_types_cli
12441381
class TestClassKind:
12451382
"""Tests for JavaType.Class.Kind inference from ty-types data."""

0 commit comments

Comments
 (0)