Skip to content

Commit 77d6ec5

Browse files
Python: Add parameter field_type attribution and call-site type arguments (#6795)
* Add type attribution tests and fix method declaration/type hint types - Add type attribution tests to 13 existing parser test files covering method invocations, binary ops, type hints, collection literals, imports, field access, class instances, method declarations, async defs, for loops, unary ops, ternaries, and lambdas - Add `method_declaration_type()` to type_mapping.py to build JavaType.Method for function declarations using ty-types descriptor data with annotation fallback - Add type attribution to ParameterizedType nodes in type hint expressions - Fix pre-existing assign_test.py bug (simple_name access, FQN startswith) - Add typing.Text test to test_type_attribution.py - Bump ty-types dependency to >=0.0.19.dev20260223093555 * Add parameter field_type attribution and call-site type arguments - Add param_type_info() to PythonTypeMapping for function parameter Identifier nodes to get JavaType.Variable field_type - Update __convert_name() and map_arg() to flow field_type through to J.Identifier and NamedVariable - Use call signature returnTypeId for call-site-specific return types (e.g. int for identity(42) instead of generic T) - Populate _declared_formal_type_names on method invocation types from function descriptor type parameters - Bump ty-types to 0.0.19.dev20260223102528 for callSignature typeArguments/returnTypeId support
1 parent ea79af4 commit 77d6ec5

6 files changed

Lines changed: 252 additions & 11 deletions

File tree

rewrite-python/rewrite/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ requires-python = ">=3.10"
2424
dependencies = [
2525
"cbor2>=5.6.5",
2626
"more_itertools>=10.0.0",
27-
"ty-types>=0.0.19.dev20260223093555", # Type inference CLI for Python type attribution
27+
"ty-types>=0.0.19.dev20260223102528", # Type inference CLI for Python type attribution
2828
"parso>=0.7.1,<0.8", # Python 2/3 parser with CST support (0.8+ dropped Python 2.7 grammar)
2929
]
3030

rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,8 @@ def visit_arguments(self, node, with_close_paren: bool = True) -> List[JRightPad
338338
def map_arg(self, node, default=None, vararg=False, kwarg=False):
339339
prefix = self.__source_before('**') if kwarg else self.__whitespace()
340340
vararg_prefix = self.__source_before('*') if vararg else None
341-
name = self.__convert_name(node.arg, self._type_mapping.type(node))
341+
expr_type, field_type = self._type_mapping.param_type_info(node)
342+
name = self.__convert_name(node.arg, expr_type, field_type)
342343
after_name = self.__source_before(':') if node.annotation else Space.EMPTY
343344
type_expression = self.__convert_type(node.annotation) if node.annotation else None
344345
initializer = self.__pad_left(self.__source_before('='), self.__convert(default)) if default else None
@@ -359,7 +360,7 @@ def map_arg(self, node, default=None, vararg=False, kwarg=False):
359360
cast(j.Identifier, name),
360361
[],
361362
initializer,
362-
self.__as_variable_type(self._type_mapping.type(node))
363+
field_type
363364
), after_name)],
364365
)
365366

@@ -3039,12 +3040,13 @@ def __as_method_type(t: Optional[JavaType]) -> Optional[JavaType]:
30393040
return t
30403041
return None
30413042

3042-
def __convert_name(self, name: str, name_type: Optional[JavaType] = None) -> NameTree:
3043+
def __convert_name(self, name: str, name_type: Optional[JavaType] = None,
3044+
field_type: Optional[JavaType.Variable] = None) -> NameTree:
30433045
def ident_or_field(parts: List[str]) -> NameTree:
30443046
if len(parts) == 1:
30453047
space, actual_name = self.__consume_identifier(parts[-1])
30463048
return j.Identifier(random_id(), space, Markers.EMPTY, [], actual_name,
3047-
name_type, None)
3049+
name_type, field_type)
30483050
else:
30493051
return j.FieldAccess(
30503052
random_id(),
@@ -3055,7 +3057,7 @@ def ident_or_field(parts: List[str]) -> NameTree:
30553057
self.__source_before('.'),
30563058
(lambda s, n: j.Identifier(random_id(), s, Markers.EMPTY, [], n,
30573059
name_type,
3058-
None))(*self.__consume_identifier(parts[-1])),
3060+
field_type))(*self.__consume_identifier(parts[-1])),
30593061
),
30603062
name_type
30613063
)

rewrite-python/rewrite/src/rewrite/python/type_mapping.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,17 @@ def _descriptor_to_java_type(self, descriptor: Dict[str, Any]) -> Optional[JavaT
477477
if interfaces:
478478
class_type._interfaces = interfaces
479479

480+
# Populate type parameters from typeVar descriptors
481+
type_params = descriptor.get('typeParameters', [])
482+
if type_params and getattr(class_type, '_type_parameters', None) is None:
483+
resolved_type_params = []
484+
for tp_id in type_params:
485+
tp_type = self._resolve_type(tp_id)
486+
if tp_type is not None:
487+
resolved_type_params.append(tp_type)
488+
if resolved_type_params:
489+
class_type._type_parameters = resolved_type_params
490+
480491
# Populate methods from function/boundMethod members
481492
members = descriptor.get('members', [])
482493
if members and getattr(class_type, '_methods', None) is None:
@@ -533,6 +544,10 @@ def _descriptor_to_java_type(self, descriptor: Dict[str, Any]) -> Optional[JavaT
533544
elif kind == 'property':
534545
return _UNKNOWN
535546

547+
elif kind == 'typeVar':
548+
name = descriptor.get('name', '')
549+
return self._create_class_type(name) if name else _UNKNOWN
550+
536551
else:
537552
return _UNKNOWN
538553

@@ -601,6 +616,16 @@ def name_type_info(self, node: ast.Name) -> Tuple[Optional[JavaType], Optional[J
601616
return expr_type, JavaType.Variable(_name=node.id, _type=expr_type)
602617
return expr_type, None
603618

619+
def param_type_info(self, node: ast.arg) -> Tuple[Optional[JavaType], Optional[JavaType.Variable]]:
620+
"""Get expression type and variable type for a function parameter.
621+
622+
Returns (expression_type, variable_field_type).
623+
"""
624+
expr_type = self.type(node)
625+
if expr_type is None:
626+
return None, None
627+
return expr_type, JavaType.Variable(_name=node.arg, _type=expr_type)
628+
604629
def attribute_type_info(self, node: ast.Attribute,
605630
receiver_type: Optional[JavaType] = None
606631
) -> Tuple[Optional[JavaType], Optional[JavaType.Variable]]:
@@ -660,7 +685,13 @@ def method_declaration_type(self, node: ast.FunctionDef) -> Optional[JavaType.Me
660685
if node.returns is not None:
661686
return_type = self.type(node.returns)
662687

663-
if not param_names and return_type is None:
688+
# Extract type parameter names from Python 3.12+ type_params
689+
type_param_names: List[str] = []
690+
for tp in getattr(node, 'type_params', []) or []:
691+
if hasattr(tp, 'name'):
692+
type_param_names.append(tp.name)
693+
694+
if not param_names and return_type is None and not type_param_names:
664695
return None
665696

666697
return JavaType.Method(
@@ -670,6 +701,7 @@ def method_declaration_type(self, node: ast.FunctionDef) -> Optional[JavaType.Me
670701
_return_type=return_type,
671702
_parameter_names=param_names if param_names else None,
672703
_parameter_types=param_types if param_types else None,
704+
_declared_formal_type_names=type_param_names if type_param_names else None,
673705
)
674706

675707
def _method_from_function_descriptor(
@@ -690,13 +722,16 @@ def _method_from_function_descriptor(
690722
if ret_id is not None:
691723
return_type = self._resolve_type(ret_id)
692724

725+
type_param_names = self._extract_type_param_names(descriptor)
726+
693727
return JavaType.Method(
694728
_flags_bit_map=0,
695729
_declaring_type=None,
696730
_name=name,
697731
_return_type=return_type,
698732
_parameter_names=param_names if param_names else None,
699733
_parameter_types=param_types if param_types else None,
734+
_declared_formal_type_names=type_param_names if type_param_names else None,
700735
)
701736

702737
def method_invocation_type(self, node: ast.Call) -> Optional[JavaType.Method]:
@@ -729,13 +764,22 @@ def method_invocation_type(self, node: ast.Call) -> Optional[JavaType.Method]:
729764
# Get return type
730765
return_type = self._get_return_type(node)
731766

767+
# Extract type parameter names from function descriptor
768+
type_param_names: List[str] = []
769+
func_type_id = self._lookup_func_type_id(node)
770+
if func_type_id is not None:
771+
func_desc = self._type_registry.get(func_type_id)
772+
if func_desc:
773+
type_param_names = self._extract_type_param_names(func_desc)
774+
732775
return JavaType.Method(
733776
_flags_bit_map=0,
734777
_declaring_type=declaring_type,
735778
_name=method_name,
736779
_return_type=return_type,
737780
_parameter_names=param_names if param_names else None,
738781
_parameter_types=param_types if param_types else None,
782+
_declared_formal_type_names=type_param_names if type_param_names else None,
739783
)
740784

741785
def _extract_method_name(self, node: ast.Call) -> Optional[str]:
@@ -1037,9 +1081,20 @@ def _get_parameter_types(self, node: ast.Call) -> Optional[List[JavaType]]:
10371081
def _get_return_type(self, node: ast.Call) -> Optional[JavaType]:
10381082
"""Get the return type of a method call.
10391083
1040-
First tries the ExprCall node type (which IS the return type),
1041-
then falls back to the function descriptor's returnType field.
1084+
Prefers the call-site-specific returnTypeId from the call signature
1085+
(which gives resolved types like int instead of generic T),
1086+
then tries the ExprCall node type, then falls back to the function
1087+
descriptor's returnType field.
10421088
"""
1089+
# Prefer call-site-specific return type from call signature
1090+
sig = self._lookup_call_signature(node)
1091+
if sig:
1092+
ret_id = sig.get('returnTypeId')
1093+
if ret_id is not None:
1094+
result = self._resolve_type(ret_id)
1095+
if result is not None:
1096+
return result
1097+
10431098
# The type of an ExprCall node in ty-types IS the return type
10441099
type_id = self._lookup_type_id(node)
10451100
if type_id is not None:
@@ -1083,15 +1138,29 @@ def _create_method_from_descriptor(self, descriptor: Dict[str, Any],
10831138
else:
10841139
param_types.append(_UNKNOWN)
10851140

1141+
type_param_names = self._extract_type_param_names(descriptor)
1142+
10861143
return JavaType.Method(
10871144
_flags_bit_map=0,
10881145
_declaring_type=declaring_type,
10891146
_name=name,
10901147
_return_type=return_type,
10911148
_parameter_names=param_names if param_names else None,
10921149
_parameter_types=param_types if param_types else None,
1150+
_declared_formal_type_names=type_param_names if type_param_names else None,
10931151
)
10941152

1153+
def _extract_type_param_names(self, descriptor: Dict[str, Any]) -> List[str]:
1154+
"""Extract type parameter names from a descriptor's typeParameters list."""
1155+
names: List[str] = []
1156+
for tp_id in descriptor.get('typeParameters', []):
1157+
tp_desc = self._type_registry.get(tp_id)
1158+
if tp_desc and tp_desc.get('kind') == 'typeVar':
1159+
name = tp_desc.get('name', '')
1160+
if name:
1161+
names.append(name)
1162+
return names
1163+
10951164
def _create_class_type(self, fqn: str) -> JavaType.Class:
10961165
"""Create a JavaType.Class from a fully qualified name."""
10971166
if fqn in self._type_cache:

rewrite-python/rewrite/tests/python/all/tree/class_test.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from rewrite.java.support_types import JavaType
6-
from rewrite.java.tree import Assignment
6+
from rewrite.java.tree import Assignment, Identifier
77
from rewrite.python.tree import CompilationUnit
88
from rewrite.python.visitor import PythonVisitor
99
from rewrite.test import RecipeSpec, python
@@ -125,6 +125,55 @@ class C3(Generic[T], metaclass=type, *[str]):
125125
))
126126

127127

128+
@requires_ty_cli
129+
def test_generic_class_type_params():
130+
"""Verify type parameters on a generic class like class Box[T]."""
131+
errors = []
132+
133+
def check_types(source_file):
134+
assert isinstance(source_file, CompilationUnit)
135+
136+
class TypeChecker(PythonVisitor):
137+
def visit_assignment(self, assignment, p):
138+
if not isinstance(assignment, Assignment):
139+
return assignment
140+
# Only check the `x = Box(42)` assignment
141+
if not isinstance(assignment.variable, Identifier) or assignment.variable.simple_name != 'x':
142+
return assignment
143+
if assignment.type is None:
144+
errors.append("Assignment.type is None for x = Box(42)")
145+
elif isinstance(assignment.type, JavaType.Class):
146+
if assignment.type._fully_qualified_name != 'Box':
147+
errors.append(f"Assignment.type fqn is '{assignment.type._fully_qualified_name}', expected 'Box'")
148+
type_params = getattr(assignment.type, '_type_parameters', None)
149+
if type_params is None:
150+
errors.append("Box class type has no _type_parameters")
151+
elif len(type_params) != 1:
152+
errors.append(f"Box class type has {len(type_params)} type params, expected 1")
153+
else:
154+
tp = type_params[0]
155+
if isinstance(tp, JavaType.Class):
156+
if tp._fully_qualified_name != 'T':
157+
errors.append(f"type_parameter fqn is '{tp._fully_qualified_name}', expected 'T'")
158+
else:
159+
errors.append(f"type_parameter is {type(tp).__name__}, expected Class")
160+
return assignment
161+
162+
TypeChecker().visit(source_file, None)
163+
164+
# language=python
165+
RecipeSpec(type_attribution=True).rewrite_run(python(
166+
"""\
167+
class Box[T]:
168+
def __init__(self, value: T) -> None:
169+
self.value = value
170+
x = Box(42)
171+
""",
172+
after_recipe=check_types,
173+
))
174+
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)
175+
176+
128177
@requires_ty_cli
129178
def test_class_instance_type_attribution():
130179
"""Verify that x = Foo() assigns a type with fqn 'Foo'."""

rewrite-python/rewrite/tests/python/all/tree/method_declaration_test.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from rewrite.java.support_types import JavaType
6-
from rewrite.java.tree import MethodDeclaration
6+
from rewrite.java.tree import MethodDeclaration, VariableDeclarations
77
from rewrite.python.tree import CompilationUnit
88
from rewrite.python.visitor import PythonVisitor
99
from rewrite.test import RecipeSpec, python
@@ -167,6 +167,45 @@ def f(
167167
))
168168

169169

170+
@requires_ty_cli
171+
def test_generic_function_type_params():
172+
"""Verify method_type.declared_formal_type_names for def identity[T](x: T) -> T."""
173+
errors = []
174+
175+
def check_types(source_file):
176+
assert isinstance(source_file, CompilationUnit)
177+
178+
class TypeChecker(PythonVisitor):
179+
def visit_method_declaration(self, method, p):
180+
if not isinstance(method, MethodDeclaration):
181+
return method
182+
if method.name.simple_name != 'identity':
183+
return method
184+
if method.method_type is None:
185+
errors.append("MethodDeclaration.method_type is None")
186+
else:
187+
mt = method.method_type
188+
if mt._declared_formal_type_names is None:
189+
errors.append("method_type._declared_formal_type_names is None")
190+
elif mt._declared_formal_type_names != ['T']:
191+
errors.append(f"method_type._declared_formal_type_names is {mt._declared_formal_type_names}, expected ['T']")
192+
if mt._parameter_names is not None and 'x' not in mt._parameter_names:
193+
errors.append(f"parameter_names {mt._parameter_names} does not contain 'x'")
194+
return method
195+
196+
TypeChecker().visit(source_file, None)
197+
198+
# language=python
199+
RecipeSpec(type_attribution=True).rewrite_run(python(
200+
"""\
201+
def identity[T](x: T) -> T:
202+
return x
203+
""",
204+
after_recipe=check_types,
205+
))
206+
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)
207+
208+
170209
@requires_ty_cli
171210
def test_method_declaration_type_attribution():
172211
"""Verify method_type on a function with typed parameters and return type."""
@@ -216,3 +255,43 @@ def foo(a: int, b: str) -> bool:
216255
after_recipe=check_types,
217256
))
218257
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)
258+
259+
260+
@requires_ty_cli
261+
def test_param_identifier_field_type():
262+
"""Verify J.Identifier.field_type is JavaType.Variable for typed parameters."""
263+
errors = []
264+
265+
def check_types(source_file):
266+
assert isinstance(source_file, CompilationUnit)
267+
268+
class TypeChecker(PythonVisitor):
269+
def visit_variable_declarations(self, vd, p):
270+
if not isinstance(vd, VariableDeclarations):
271+
return vd
272+
for named_var in vd.variables:
273+
ident = named_var.name
274+
if ident.simple_name not in ('x', 'y'):
275+
continue
276+
if ident.field_type is None:
277+
errors.append(f"Identifier '{ident.simple_name}' has field_type=None")
278+
elif not isinstance(ident.field_type, JavaType.Variable):
279+
errors.append(f"Identifier '{ident.simple_name}' field_type is {type(ident.field_type)}, expected JavaType.Variable")
280+
else:
281+
if ident.field_type._name != ident.simple_name:
282+
errors.append(f"field_type._name is '{ident.field_type._name}', expected '{ident.simple_name}'")
283+
if ident.field_type._type is None:
284+
errors.append(f"field_type._type is None for '{ident.simple_name}'")
285+
return vd
286+
287+
TypeChecker().visit(source_file, None)
288+
289+
# language=python
290+
RecipeSpec(type_attribution=True).rewrite_run(python(
291+
"""\
292+
def greet(x: int, y: str) -> bool:
293+
return True
294+
""",
295+
after_recipe=check_types,
296+
))
297+
assert not errors, "Type attribution errors:\n" + "\n".join(f" - {e}" for e in errors)

0 commit comments

Comments
 (0)