Skip to content

Commit 4019cb1

Browse files
committed
Add parameter field_type attribution and call-site type arguments
- Add param_type_info() to PythonTypeMapping for function parameter Identifier nodes to get JavaType.Variable field_type - Update __convert_name() and map_arg() to flow field_type through to J.Identifier and NamedVariable - Use call signature returnTypeId for call-site-specific return types (e.g. int for identity(42) instead of generic T) - Populate _declared_formal_type_names on method invocation types from function descriptor type parameters - Bump ty-types to 0.0.19.dev20260223102528 for callSignature typeArguments/returnTypeId support
1 parent 63ed316 commit 4019cb1

6 files changed

Lines changed: 252 additions & 11 deletions

File tree

rewrite-python/rewrite/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ requires-python = ">=3.10"
2424
dependencies = [
2525
"cbor2>=5.6.5",
2626
"more_itertools>=10.0.0",
27-
"ty-types>=0.0.19.dev20260223093555", # Type inference CLI for Python type attribution
27+
"ty-types>=0.0.19.dev20260223102528", # Type inference CLI for Python type attribution
2828
"parso>=0.7.1,<0.8", # Python 2/3 parser with CST support (0.8+ dropped Python 2.7 grammar)
2929
]
3030

rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,8 @@ def visit_arguments(self, node, with_close_paren: bool = True) -> List[JRightPad
338338
def map_arg(self, node, default=None, vararg=False, kwarg=False):
339339
prefix = self.__source_before('**') if kwarg else self.__whitespace()
340340
vararg_prefix = self.__source_before('*') if vararg else None
341-
name = self.__convert_name(node.arg, self._type_mapping.type(node))
341+
expr_type, field_type = self._type_mapping.param_type_info(node)
342+
name = self.__convert_name(node.arg, expr_type, field_type)
342343
after_name = self.__source_before(':') if node.annotation else Space.EMPTY
343344
type_expression = self.__convert_type(node.annotation) if node.annotation else None
344345
initializer = self.__pad_left(self.__source_before('='), self.__convert(default)) if default else None
@@ -359,7 +360,7 @@ def map_arg(self, node, default=None, vararg=False, kwarg=False):
359360
cast(j.Identifier, name),
360361
[],
361362
initializer,
362-
self.__as_variable_type(self._type_mapping.type(node))
363+
field_type
363364
), after_name)],
364365
)
365366

@@ -3039,12 +3040,13 @@ def __as_method_type(t: Optional[JavaType]) -> Optional[JavaType]:
30393040
return t
30403041
return None
30413042

3042-
def __convert_name(self, name: str, name_type: Optional[JavaType] = None) -> NameTree:
3043+
def __convert_name(self, name: str, name_type: Optional[JavaType] = None,
3044+
field_type: Optional[JavaType.Variable] = None) -> NameTree:
30433045
def ident_or_field(parts: List[str]) -> NameTree:
30443046
if len(parts) == 1:
30453047
space, actual_name = self.__consume_identifier(parts[-1])
30463048
return j.Identifier(random_id(), space, Markers.EMPTY, [], actual_name,
3047-
name_type, None)
3049+
name_type, field_type)
30483050
else:
30493051
return j.FieldAccess(
30503052
random_id(),
@@ -3055,7 +3057,7 @@ def ident_or_field(parts: List[str]) -> NameTree:
30553057
self.__source_before('.'),
30563058
(lambda s, n: j.Identifier(random_id(), s, Markers.EMPTY, [], n,
30573059
name_type,
3058-
None))(*self.__consume_identifier(parts[-1])),
3060+
field_type))(*self.__consume_identifier(parts[-1])),
30593061
),
30603062
name_type
30613063
)

rewrite-python/rewrite/src/rewrite/python/type_mapping.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,17 @@ def _descriptor_to_java_type(self, descriptor: Dict[str, Any]) -> Optional[JavaT
477477
if interfaces:
478478
class_type._interfaces = interfaces
479479

480+
# Populate type parameters from typeVar descriptors
481+
type_params = descriptor.get('typeParameters', [])
482+
if type_params and getattr(class_type, '_type_parameters', None) is None:
483+
resolved_type_params = []
484+
for tp_id in type_params:
485+
tp_type = self._resolve_type(tp_id)
486+
if tp_type is not None:
487+
resolved_type_params.append(tp_type)
488+
if resolved_type_params:
489+
class_type._type_parameters = resolved_type_params
490+
480491
# Populate methods from function/boundMethod members
481492
members = descriptor.get('members', [])
482493
if members and getattr(class_type, '_methods', None) is None:
@@ -533,6 +544,10 @@ def _descriptor_to_java_type(self, descriptor: Dict[str, Any]) -> Optional[JavaT
533544
elif kind == 'property':
534545
return _UNKNOWN
535546

547+
elif kind == 'typeVar':
548+
name = descriptor.get('name', '')
549+
return self._create_class_type(name) if name else _UNKNOWN
550+
536551
else:
537552
return _UNKNOWN
538553

@@ -601,6 +616,16 @@ def name_type_info(self, node: ast.Name) -> Tuple[Optional[JavaType], Optional[J
601616
return expr_type, JavaType.Variable(_name=node.id, _type=expr_type)
602617
return expr_type, None
603618

619+
def param_type_info(self, node: ast.arg) -> Tuple[Optional[JavaType], Optional[JavaType.Variable]]:
620+
"""Get expression type and variable type for a function parameter.
621+
622+
Returns (expression_type, variable_field_type).
623+
"""
624+
expr_type = self.type(node)
625+
if expr_type is None:
626+
return None, None
627+
return expr_type, JavaType.Variable(_name=node.arg, _type=expr_type)
628+
604629
def attribute_type_info(self, node: ast.Attribute,
605630
receiver_type: Optional[JavaType] = None
606631
) -> Tuple[Optional[JavaType], Optional[JavaType.Variable]]:
@@ -660,7 +685,13 @@ def method_declaration_type(self, node: ast.FunctionDef) -> Optional[JavaType.Me
660685
if node.returns is not None:
661686
return_type = self.type(node.returns)
662687

663-
if not param_names and return_type is None:
688+
# Extract type parameter names from Python 3.12+ type_params
689+
type_param_names: List[str] = []
690+
for tp in getattr(node, 'type_params', []) or []:
691+
if hasattr(tp, 'name'):
692+
type_param_names.append(tp.name)
693+
694+
if not param_names and return_type is None and not type_param_names:
664695
return None
665696

666697
return JavaType.Method(
@@ -670,6 +701,7 @@ def method_declaration_type(self, node: ast.FunctionDef) -> Optional[JavaType.Me
670701
_return_type=return_type,
671702
_parameter_names=param_names if param_names else None,
672703
_parameter_types=param_types if param_types else None,
704+
_declared_formal_type_names=type_param_names if type_param_names else None,
673705
)
674706

675707
def _method_from_function_descriptor(
@@ -690,13 +722,16 @@ def _method_from_function_descriptor(
690722
if ret_id is not None:
691723
return_type = self._resolve_type(ret_id)
692724

725+
type_param_names = self._extract_type_param_names(descriptor)
726+
693727
return JavaType.Method(
694728
_flags_bit_map=0,
695729
_declaring_type=None,
696730
_name=name,
697731
_return_type=return_type,
698732
_parameter_names=param_names if param_names else None,
699733
_parameter_types=param_types if param_types else None,
734+
_declared_formal_type_names=type_param_names if type_param_names else None,
700735
)
701736

702737
def method_invocation_type(self, node: ast.Call) -> Optional[JavaType.Method]:
@@ -729,13 +764,22 @@ def method_invocation_type(self, node: ast.Call) -> Optional[JavaType.Method]:
729764
# Get return type
730765
return_type = self._get_return_type(node)
731766

767+
# Extract type parameter names from function descriptor
768+
type_param_names: List[str] = []
769+
func_type_id = self._lookup_func_type_id(node)
770+
if func_type_id is not None:
771+
func_desc = self._type_registry.get(func_type_id)
772+
if func_desc:
773+
type_param_names = self._extract_type_param_names(func_desc)
774+
732775
return JavaType.Method(
733776
_flags_bit_map=0,
734777
_declaring_type=declaring_type,
735778
_name=method_name,
736779
_return_type=return_type,
737780
_parameter_names=param_names if param_names else None,
738781
_parameter_types=param_types if param_types else None,
782+
_declared_formal_type_names=type_param_names if type_param_names else None,
739783
)
740784

741785
def _extract_method_name(self, node: ast.Call) -> Optional[str]:
@@ -1037,9 +1081,20 @@ def _get_parameter_types(self, node: ast.Call) -> Optional[List[JavaType]]:
10371081
def _get_return_type(self, node: ast.Call) -> Optional[JavaType]:
10381082
"""Get the return type of a method call.
10391083
1040-
First tries the ExprCall node type (which IS the return type),
1041-
then falls back to the function descriptor's returnType field.
1084+
Prefers the call-site-specific returnTypeId from the call signature
1085+
(which gives resolved types like int instead of generic T),
1086+
then tries the ExprCall node type, then falls back to the function
1087+
descriptor's returnType field.
10421088
"""
1089+
# Prefer call-site-specific return type from call signature
1090+
sig = self._lookup_call_signature(node)
1091+
if sig:
1092+
ret_id = sig.get('returnTypeId')
1093+
if ret_id is not None:
1094+
result = self._resolve_type(ret_id)
1095+
if result is not None:
1096+
return result
1097+
10431098
# The type of an ExprCall node in ty-types IS the return type
10441099
type_id = self._lookup_type_id(node)
10451100
if type_id is not None:
@@ -1083,15 +1138,29 @@ def _create_method_from_descriptor(self, descriptor: Dict[str, Any],
10831138
else:
10841139
param_types.append(_UNKNOWN)
10851140

1141+
type_param_names = self._extract_type_param_names(descriptor)
1142+
10861143
return JavaType.Method(
10871144
_flags_bit_map=0,
10881145
_declaring_type=declaring_type,
10891146
_name=name,
10901147
_return_type=return_type,
10911148
_parameter_names=param_names if param_names else None,
10921149
_parameter_types=param_types if param_types else None,
1150+
_declared_formal_type_names=type_param_names if type_param_names else None,
10931151
)
10941152

1153+
def _extract_type_param_names(self, descriptor: Dict[str, Any]) -> List[str]:
1154+
"""Extract type parameter names from a descriptor's typeParameters list."""
1155+
names: List[str] = []
1156+
for tp_id in descriptor.get('typeParameters', []):
1157+
tp_desc = self._type_registry.get(tp_id)
1158+
if tp_desc and tp_desc.get('kind') == 'typeVar':
1159+
name = tp_desc.get('name', '')
1160+
if name:
1161+
names.append(name)
1162+
return names
1163+
10951164
def _create_class_type(self, fqn: str) -> JavaType.Class:
10961165
"""Create a JavaType.Class from a fully qualified name."""
10971166
if fqn in self._type_cache:

rewrite-python/rewrite/tests/python/all/tree/class_test.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from rewrite.java.support_types import JavaType
6-
from rewrite.java.tree import Assignment
6+
from rewrite.java.tree import Assignment, Identifier
77
from rewrite.python.tree import CompilationUnit
88
from rewrite.python.visitor import PythonVisitor
99
from rewrite.test import RecipeSpec, python
@@ -125,6 +125,55 @@ class C3(Generic[T], metaclass=type, *[str]):
125125
))
126126

127127

128+
@requires_ty_cli
129+
def test_generic_class_type_params():
130+
"""Verify type parameters on a generic class like class Box[T]."""
131+
errors = []
132+
133+
def check_types(source_file):
134+
assert isinstance(source_file, CompilationUnit)
135+
136+
class TypeChecker(PythonVisitor):
137+
def visit_assignment(self, assignment, p):
138+
if not isinstance(assignment, Assignment):
139+
return assignment
140+
# Only check the `x = Box(42)` assignment
141+
if not isinstance(assignment.variable, Identifier) or assignment.variable.simple_name != 'x':
142+
return assignment
143+
if assignment.type is None:
144+
errors.append("Assignment.type is None for x = Box(42)")
145+
elif isinstance(assignment.type, JavaType.Class):
146+
if assignment.type._fully_qualified_name != 'Box':
147+
errors.append(f"Assignment.type fqn is '{assignment.type._fully_qualified_name}', expected 'Box'")
148+
type_params = getattr(assignment.type, '_type_parameters', None)
149+
if type_params is None:
150+
errors.append("Box class type has no _type_parameters")
151+
elif len(type_params) != 1:
152+
errors.append(f"Box class type has {len(type_params)} type params, expected 1")
153+
else:
154+
tp = type_params[0]
155+
if isinstance(tp, JavaType.Class):
156+
if tp._fully_qualified_name != 'T':
157+
errors.append(f"type_parameter fqn is '{tp._fully_qualified_name}', expected 'T'")
158+
else:
159+
errors.append(f"type_parameter is {type(tp).__name__}, expected Class")
160+
return assignment
161+
162+
TypeChecker().visit(source_file, None)
163+
164+
# language=python
165+
RecipeSpec(type_attribution=True).rewrite_run(python(
166+
"""\
167+
class Box[T]:
168+
def __init__(self, value: T) -> None:
169+
self.value = value
170+
x = Box(42)
171+
""",
172+
after_recipe=check_types,
173+
))
174+
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)
175+
176+
128177
@requires_ty_cli
129178
def test_class_instance_type_attribution():
130179
"""Verify that x = Foo() assigns a type with fqn 'Foo'."""

rewrite-python/rewrite/tests/python/all/tree/method_declaration_test.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from rewrite.java.support_types import JavaType
6-
from rewrite.java.tree import MethodDeclaration
6+
from rewrite.java.tree import MethodDeclaration, VariableDeclarations
77
from rewrite.python.tree import CompilationUnit
88
from rewrite.python.visitor import PythonVisitor
99
from rewrite.test import RecipeSpec, python
@@ -167,6 +167,45 @@ def f(
167167
))
168168

169169

170+
@requires_ty_cli
171+
def test_generic_function_type_params():
172+
"""Verify method_type.declared_formal_type_names for def identity[T](x: T) -> T."""
173+
errors = []
174+
175+
def check_types(source_file):
176+
assert isinstance(source_file, CompilationUnit)
177+
178+
class TypeChecker(PythonVisitor):
179+
def visit_method_declaration(self, method, p):
180+
if not isinstance(method, MethodDeclaration):
181+
return method
182+
if method.name.simple_name != 'identity':
183+
return method
184+
if method.method_type is None:
185+
errors.append("MethodDeclaration.method_type is None")
186+
else:
187+
mt = method.method_type
188+
if mt._declared_formal_type_names is None:
189+
errors.append("method_type._declared_formal_type_names is None")
190+
elif mt._declared_formal_type_names != ['T']:
191+
errors.append(f"method_type._declared_formal_type_names is {mt._declared_formal_type_names}, expected ['T']")
192+
if mt._parameter_names is not None and 'x' not in mt._parameter_names:
193+
errors.append(f"parameter_names {mt._parameter_names} does not contain 'x'")
194+
return method
195+
196+
TypeChecker().visit(source_file, None)
197+
198+
# language=python
199+
RecipeSpec(type_attribution=True).rewrite_run(python(
200+
"""\
201+
def identity[T](x: T) -> T:
202+
return x
203+
""",
204+
after_recipe=check_types,
205+
))
206+
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)
207+
208+
170209
@requires_ty_cli
171210
def test_method_declaration_type_attribution():
172211
"""Verify method_type on a function with typed parameters and return type."""
@@ -216,3 +255,43 @@ def foo(a: int, b: str) -> bool:
216255
after_recipe=check_types,
217256
))
218257
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)
258+
259+
260+
@requires_ty_cli
261+
def test_param_identifier_field_type():
262+
"""Verify J.Identifier.field_type is JavaType.Variable for typed parameters."""
263+
errors = []
264+
265+
def check_types(source_file):
266+
assert isinstance(source_file, CompilationUnit)
267+
268+
class TypeChecker(PythonVisitor):
269+
def visit_variable_declarations(self, vd, p):
270+
if not isinstance(vd, VariableDeclarations):
271+
return vd
272+
for named_var in vd.variables:
273+
ident = named_var.name
274+
if ident.simple_name not in ('x', 'y'):
275+
continue
276+
if ident.field_type is None:
277+
errors.append(f"Identifier '{ident.simple_name}' has field_type=None")
278+
elif not isinstance(ident.field_type, JavaType.Variable):
279+
errors.append(f"Identifier '{ident.simple_name}' field_type is {type(ident.field_type)}, expected JavaType.Variable")
280+
else:
281+
if ident.field_type._name != ident.simple_name:
282+
errors.append(f"field_type._name is '{ident.field_type._name}', expected '{ident.simple_name}'")
283+
if ident.field_type._type is None:
284+
errors.append(f"field_type._type is None for '{ident.simple_name}'")
285+
return vd
286+
287+
TypeChecker().visit(source_file, None)
288+
289+
# language=python
290+
RecipeSpec(type_attribution=True).rewrite_run(python(
291+
"""\
292+
def greet(x: int, y: str) -> bool:
293+
return True
294+
""",
295+
after_recipe=check_types,
296+
))
297+
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)

0 commit comments

Comments
 (0)