Skip to content

Commit 11b6cfd

Browse files
[ty] Fix wrong inlay hints for overloaded function arguments (#23179)
## Summary Inlay hints for function call arguments were showing parameter names from the wrong overload variant because they only performed arity-based matching. Closes astral-sh/ty#1985.
1 parent 6fa0873 commit 11b6cfd

3 files changed

Lines changed: 238 additions & 53 deletions

File tree

crates/ty_ide/src/inlay_hints.rs

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6191,13 +6191,14 @@ mod tests {
61916191
|
61926192
61936193
info[inlay-hint-location]: Inlay Hint Target
6194-
--> main.py:5:9
6194+
--> main.py:7:9
61956195
|
6196-
4 | @overload
61976196
5 | def foo(x: int) -> str: ...
6198-
| ^
61996197
6 | @overload
62006198
7 | def foo(x: str) -> int: ...
6199+
| ^
6200+
8 | def foo(x):
6201+
9 | return x
62016202
|
62026203
info: Source
62036204
--> main2.py:12:6
@@ -6209,6 +6210,146 @@ mod tests {
62096210
");
62106211
}
62116212

6213+
#[test]
6214+
fn test_overloaded_function_calls_different_params() {
6215+
let mut test = inlay_hint_test(
6216+
"
6217+
from typing import overload, Optional, Sequence
6218+
6219+
@overload
6220+
def S(name: str, is_symmetric: Optional[bool] = None) -> str: ...
6221+
@overload
6222+
def S(*names: str, is_symmetric: Optional[bool] = None) -> Sequence[str]: ...
6223+
def S():
6224+
pass
6225+
6226+
b = S('x', 'y')",
6227+
);
6228+
6229+
// The call S('x', 'y') should match the second overload (*names: str),
6230+
// and since *names is variadic, no parameter name hints should be shown.
6231+
// Before the fix, this incorrectly showed `name=` and `is_symmetric=` hints
6232+
// from the first overload.
6233+
assert_snapshot!(test.inlay_hints(), @"
6234+
6235+
from typing import overload, Optional, Sequence
6236+
6237+
@overload
6238+
def S(name: str, is_symmetric: Optional[bool] = None) -> str: ...
6239+
@overload
6240+
def S(*names: str, is_symmetric: Optional[bool] = None) -> Sequence[str]: ...
6241+
def S():
6242+
pass
6243+
6244+
b[: Sequence[str]] = S('x', 'y')
6245+
---------------------------------------------
6246+
info[inlay-hint-location]: Inlay Hint Target
6247+
--> stdlib/typing.pyi:1565:7
6248+
|
6249+
1563 | def __len__(self) -> int: ...
6250+
1564 |
6251+
1565 | class Sequence(Reversible[_T_co], Collection[_T_co]):
6252+
| ^^^^^^^^
6253+
1566 | \"\"\"All the operations on a read-only sequence.
6254+
|
6255+
info: Source
6256+
--> main2.py:11:5
6257+
|
6258+
9 | pass
6259+
10 |
6260+
11 | b[: Sequence[str]] = S('x', 'y')
6261+
| ^^^^^^^^
6262+
|
6263+
6264+
info[inlay-hint-location]: Inlay Hint Target
6265+
--> stdlib/builtins.pyi:915:7
6266+
|
6267+
914 | @disjoint_base
6268+
915 | class str(Sequence[str]):
6269+
| ^^^
6270+
916 | \"\"\"str(object='') -> str
6271+
917 | str(bytes_or_buffer[, encoding[, errors]]) -> str
6272+
|
6273+
info: Source
6274+
--> main2.py:11:14
6275+
|
6276+
9 | pass
6277+
10 |
6278+
11 | b[: Sequence[str]] = S('x', 'y')
6279+
| ^^^
6280+
|
6281+
6282+
---------------------------------------------
6283+
info[inlay-hint-edit]: File after edits
6284+
info: Source
6285+
6286+
from typing import overload, Optional, Sequence
6287+
6288+
@overload
6289+
def S(name: str, is_symmetric: Optional[bool] = None) -> str: ...
6290+
@overload
6291+
def S(*names: str, is_symmetric: Optional[bool] = None) -> Sequence[str]: ...
6292+
def S():
6293+
pass
6294+
6295+
b: Sequence[str] = S('x', 'y')
6296+
");
6297+
}
6298+
6299+
#[test]
6300+
fn test_overloaded_function_calls_no_matching_overload() {
6301+
let mut test = inlay_hint_test(
6302+
"
6303+
from typing import overload
6304+
6305+
@overload
6306+
def f(x: int) -> str: ...
6307+
@overload
6308+
def f(x: str, y: str) -> int: ...
6309+
def f(x):
6310+
return x
6311+
6312+
f([])
6313+
",
6314+
);
6315+
6316+
// Neither overload matches via type checking (list[Unknown] is neither int nor str),
6317+
// so `matching_overloads()` returns empty. The arity-based fallback picks the first
6318+
// overload (1 matched arg out of 1 required), and we should see the `x=` hint.
6319+
assert_snapshot!(test.inlay_hints(), @r"
6320+
6321+
from typing import overload
6322+
6323+
@overload
6324+
def f(x: int) -> str: ...
6325+
@overload
6326+
def f(x: str, y: str) -> int: ...
6327+
def f(x):
6328+
return x
6329+
6330+
f([x=][])
6331+
6332+
---------------------------------------------
6333+
info[inlay-hint-location]: Inlay Hint Target
6334+
--> main.py:5:7
6335+
|
6336+
4 | @overload
6337+
5 | def f(x: int) -> str: ...
6338+
| ^
6339+
6 | @overload
6340+
7 | def f(x: str, y: str) -> int: ...
6341+
|
6342+
info: Source
6343+
--> main2.py:11:4
6344+
|
6345+
9 | return x
6346+
10 |
6347+
11 | f([x=][])
6348+
| ^
6349+
|
6350+
");
6351+
}
6352+
62126353
#[test]
62136354
fn test_disabled_function_argument_names() {
62146355
let mut test = inlay_hint_test(

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,23 @@ impl<'a, 'db> CallArguments<'a, 'db> {
7373
/// typechecking.
7474
pub(crate) fn from_arguments_typed(
7575
arguments: &'a ast::Arguments,
76-
mut infer_argument_type: impl FnMut(Option<&ast::Expr>, &ast::Expr) -> Type<'db>,
76+
mut infer_argument_type: impl FnMut(&ast::Expr) -> Type<'db>,
7777
) -> Self {
7878
arguments
7979
.arguments_source_order()
8080
.map(|arg_or_keyword| match arg_or_keyword {
8181
ast::ArgOrKeyword::Arg(arg) => match arg {
8282
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
83-
let ty = infer_argument_type(Some(arg), value);
83+
let ty = infer_argument_type(value);
8484
(Argument::Variadic, Some(ty))
8585
}
8686
_ => {
87-
let ty = infer_argument_type(None, arg);
87+
let ty = infer_argument_type(arg);
8888
(Argument::Positional, Some(ty))
8989
}
9090
},
9191
ast::ArgOrKeyword::Keyword(ast::Keyword { arg, value, .. }) => {
92-
let ty = infer_argument_type(None, value);
92+
let ty = infer_argument_type(value);
9393
if let Some(arg) = arg {
9494
(Argument::Keyword(&arg.id), Some(ty))
9595
} else {

crates/ty_python_semantic/src/types/ide_support.rs

Lines changed: 90 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::place::builtins_module_scope;
55
use crate::semantic_index::definition::Definition;
66
use crate::semantic_index::definition::DefinitionKind;
77
use crate::semantic_index::{attribute_scopes, global_scope, semantic_index, use_def_map};
8-
use crate::types::call::{CallArguments, MatchedArgument};
8+
use crate::types::call::{CallArguments, CallError, MatchedArgument};
99
use crate::types::signatures::{ParameterKind, Signature};
1010
use crate::types::{
1111
CallDunderError, CallableTypes, ClassBase, ClassLiteral, ClassType, KnownUnion, Type,
@@ -514,7 +514,29 @@ pub struct CallSignatureDetails<'db> {
514514
pub argument_to_parameter_mapping: Vec<MatchedArgument<'db>>,
515515
}
516516

517-
impl CallSignatureDetails<'_> {
517+
impl<'db> CallSignatureDetails<'db> {
518+
fn from_binding(db: &'db dyn Db, binding: &crate::types::call::Binding<'db>) -> Self {
519+
let argument_to_parameter_mapping = binding.argument_matches().to_vec();
520+
let signature = binding.signature.clone();
521+
let display_details = signature.display(db).to_string_parts();
522+
let (parameter_kinds, parameter_types): (Vec<ParameterKind>, Vec<Type>) = signature
523+
.parameters()
524+
.iter()
525+
.map(|param| (param.kind().clone(), param.annotated_type()))
526+
.unzip();
527+
528+
CallSignatureDetails {
529+
definition: signature.definition(),
530+
signature,
531+
label: display_details.label,
532+
parameter_label_offsets: display_details.parameter_ranges,
533+
parameter_names: display_details.parameter_names,
534+
parameter_kinds,
535+
parameter_types,
536+
argument_to_parameter_mapping,
537+
}
538+
}
539+
518540
fn get_definition_parameter_range(&self, db: &dyn Db, name: &str) -> Option<FileRange> {
519541
let definition = self.signature.definition()?;
520542
let file = definition.file(db);
@@ -542,10 +564,12 @@ pub fn call_signature_details<'db>(
542564
return Vec::new();
543565
};
544566

567+
let db = model.db();
568+
545569
// Use into_callable to handle all the complex type conversions
546570
if let Some(callable_type) = func_type
547-
.try_upcast_to_callable(model.db())
548-
.map(|callables| callables.into_type(model.db()))
571+
.try_upcast_to_callable(db)
572+
.map(|callables| callables.into_type(db))
549573
{
550574
let call_arguments =
551575
CallArguments::from_arguments(&call_expr.arguments, |_, splatted_value| {
@@ -554,36 +578,14 @@ pub fn call_signature_details<'db>(
554578
.unwrap_or(Type::unknown())
555579
});
556580
let bindings = callable_type
557-
.bindings(model.db())
558-
.match_parameters(model.db(), &call_arguments);
581+
.bindings(db)
582+
.match_parameters(db, &call_arguments);
559583

560584
// Extract signature details from all callable bindings
561585
bindings
562586
.into_iter()
563587
.flatten()
564-
.map(|binding| {
565-
let argument_to_parameter_mapping = binding.argument_matches().to_vec();
566-
let signature = binding.signature;
567-
let display_details = signature.display(model.db()).to_string_parts();
568-
let parameter_label_offsets = display_details.parameter_ranges;
569-
let parameter_names = display_details.parameter_names;
570-
let (parameter_kinds, parameter_types): (Vec<ParameterKind>, Vec<Type>) = signature
571-
.parameters()
572-
.iter()
573-
.map(|param| (param.kind().clone(), param.annotated_type()))
574-
.unzip();
575-
576-
CallSignatureDetails {
577-
definition: signature.definition(),
578-
signature,
579-
label: display_details.label,
580-
parameter_label_offsets,
581-
parameter_names,
582-
parameter_kinds,
583-
parameter_types,
584-
argument_to_parameter_mapping,
585-
}
586-
})
588+
.map(|binding| CallSignatureDetails::from_binding(db, &binding))
587589
.collect()
588590
} else {
589591
// Type is not callable, return empty signatures
@@ -621,7 +623,7 @@ pub fn call_type_simplified_by_overloads(
621623
}
622624

623625
// Hand the overload resolution system as much type info as we have
624-
let args = CallArguments::from_arguments_typed(&call_expr.arguments, |_, splatted_value| {
626+
let args = CallArguments::from_arguments_typed(&call_expr.arguments, |splatted_value| {
625627
splatted_value
626628
.inferred_type(model)
627629
.unwrap_or(Type::unknown())
@@ -796,6 +798,60 @@ pub fn find_active_signature_from_details(
796798
Some(best_index)
797799
}
798800

801+
/// Resolve a call expression to its matching overload's signature details,
802+
/// using full type checking (not just arity matching) for overload resolution.
803+
///
804+
/// Falls back to arity-based matching if type-based resolution fails.
805+
fn resolve_call_signature<'db>(
806+
model: &SemanticModel<'db>,
807+
call_expr: &ast::ExprCall,
808+
) -> Option<CallSignatureDetails<'db>> {
809+
let db = model.db();
810+
let func_type = call_expr.func.inferred_type(model)?;
811+
let callable_type = func_type.try_upcast_to_callable(db)?.into_type(db);
812+
813+
let args = CallArguments::from_arguments_typed(&call_expr.arguments, |splatted_value| {
814+
splatted_value
815+
.inferred_type(model)
816+
.unwrap_or(Type::unknown())
817+
});
818+
819+
// Extract the `Bindings` regardless of whether type checking succeeded or failed.
820+
let bindings = callable_type
821+
.bindings(db)
822+
.match_parameters(db, &args)
823+
.check_types(db, &args, TypeContext::default(), &[])
824+
.unwrap_or_else(|CallError(_, bindings)| *bindings);
825+
826+
// First, try to find the matching overload after full type checking.
827+
let type_checked_details: Vec<_> = bindings
828+
.iter()
829+
.flat_map(|binding| binding.matching_overloads().map(|(_, overload)| overload))
830+
.map(|binding| CallSignatureDetails::from_binding(db, binding))
831+
.collect();
832+
833+
if !type_checked_details.is_empty() {
834+
let active = find_active_signature_from_details(&type_checked_details)?;
835+
return type_checked_details.into_iter().nth(active);
836+
}
837+
838+
// If all overloads have type-checking errors (e.g., `InvalidArgumentType`),
839+
// `matching_overloads()` returns empty. Fall back to arity-based matching
840+
// across all overloads to pick the best candidate for showing hints.
841+
let all_details: Vec<_> = bindings
842+
.iter()
843+
.flatten()
844+
.map(|binding| CallSignatureDetails::from_binding(db, binding))
845+
.collect();
846+
847+
if all_details.is_empty() {
848+
return None;
849+
}
850+
851+
let active = find_active_signature_from_details(&all_details)?;
852+
all_details.into_iter().nth(active)
853+
}
854+
799855
#[derive(Default)]
800856
pub struct InlayHintCallArgumentDetails {
801857
/// The position of the arguments mapped to their name and the range of the argument definition in the signature.
@@ -807,25 +863,14 @@ pub fn inlay_hint_call_argument_details<'db>(
807863
model: &SemanticModel<'db>,
808864
call_expr: &ast::ExprCall,
809865
) -> Option<InlayHintCallArgumentDetails> {
810-
let signature_details = call_signature_details(model, call_expr);
866+
let resolved = resolve_call_signature(model, call_expr)?;
811867

812-
if signature_details.is_empty() {
813-
return None;
814-
}
815-
816-
let active_signature_index = find_active_signature_from_details(&signature_details)?;
817-
818-
let call_signature_details = signature_details.get(active_signature_index)?;
819-
820-
let parameters = call_signature_details.signature.parameters();
868+
let parameters = resolved.signature.parameters();
821869

822870
let mut argument_names = HashMap::new();
823871

824872
for arg_index in 0..call_expr.arguments.args.len() {
825-
let Some(arg_mapping) = call_signature_details
826-
.argument_to_parameter_mapping
827-
.get(arg_index)
828-
else {
873+
let Some(arg_mapping) = resolved.argument_to_parameter_mapping.get(arg_index) else {
829874
continue;
830875
};
831876

@@ -847,8 +892,7 @@ pub fn inlay_hint_call_argument_details<'db>(
847892
continue;
848893
};
849894

850-
let parameter_label_offset =
851-
call_signature_details.get_definition_parameter_range(db, param.name()?);
895+
let parameter_label_offset = resolved.get_definition_parameter_range(db, param.name()?);
852896

853897
// Only add hints for parameters that can be specified by name
854898
if !param.is_positional_only() && !param.is_variadic() && !param.is_keyword_variadic() {

0 commit comments

Comments
 (0)