Skip to content

Commit 19ed6c3

Browse files
committed
use argument type for target overload when filtering overloads with gradual types
1 parent b734146 commit 19ed6c3

3 files changed

Lines changed: 83 additions & 70 deletions

File tree

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,17 @@ def list_or_set2[T, U](x: T, y: U) -> list[T] | set[U]:
319319
# error: [no-matching-overload]
320320
x8 = f6(reveal_type(list_or_set2(1, 1))) # revealed: list[int] | set[int]
321321
reveal_type(x8) # revealed: Unknown
322+
323+
@overload
324+
def f7(y: list[int | str]) -> list[int | str]: ...
325+
@overload
326+
def f7[T](y: list[T]) -> list[T]: ...
327+
def f7(y: object) -> object:
328+
raise NotImplementedError
329+
330+
# TODO: We should reveal `list[int | str]` here.
331+
x9 = f7(reveal_type(["Sheet1"])) # revealed: list[str]
332+
reveal_type(x9) # revealed: list[int | str]
322333
```
323334

324335
## Class constructor parameters

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ pub(crate) struct CallArguments<'a, 'db> {
5353
/// with type context across multiple bindings.
5454
#[derive(Clone, Debug, Default)]
5555
pub(crate) struct CallArgumentTypes<'db> {
56-
fallback_ty: Option<Type<'db>>,
56+
fallback_type: Option<Type<'db>>,
5757
types: FxHashMap<Type<'db>, Type<'db>>,
5858
}
5959

6060
impl<'db> CallArgumentTypes<'db> {
6161
pub(crate) fn new(fallback_ty: Option<Type<'db>>) -> Self {
6262
Self {
63-
fallback_ty,
63+
fallback_type: fallback_ty,
6464
types: FxHashMap::default(),
6565
}
6666
}
@@ -78,7 +78,7 @@ impl<'db> CallArgumentTypes<'db> {
7878
return Some(*exact_ty);
7979
}
8080

81-
self.fallback_ty
81+
self.fallback_type
8282
}
8383

8484
/// Returns the type of this argument when inferred against the provided declared type.
@@ -93,7 +93,7 @@ impl<'db> CallArgumentTypes<'db> {
9393
/// Insert the type of this argument when inferred with the provided type context.
9494
pub(crate) fn insert(&mut self, tcx: impl Into<TypeContext<'db>>, ty: Type<'db>) {
9595
match tcx.into().annotation {
96-
None => self.fallback_ty = Some(ty),
96+
None => self.fallback_type = Some(ty),
9797
Some(tcx) => {
9898
self.types.insert(tcx, ty);
9999
}
@@ -104,7 +104,7 @@ impl<'db> CallArgumentTypes<'db> {
104104
self.types
105105
.iter()
106106
.map(|(tcx, ty)| (TypeContext::new(Some(*tcx)), *ty))
107-
.chain(self.fallback_ty.map(|ty| (TypeContext::default(), ty)))
107+
.chain(self.fallback_type.map(|ty| (TypeContext::default(), ty)))
108108
}
109109
}
110110

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 67 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,71 +2590,6 @@ impl<'db> CallableBinding<'db> {
25902590
}
25912591
}
25922592

2593-
let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db))
2594-
.take(max_parameter_count)
2595-
.collect::<Vec<_>>();
2596-
2597-
// The following loop is trying to construct a tuple of argument types that correspond to
2598-
// the participating parameter indexes. Considering the following example:
2599-
//
2600-
// ```python
2601-
// @overload
2602-
// def f(x: Literal[1], y: Literal[2]) -> tuple[int, int]: ...
2603-
// @overload
2604-
// def f(*args: Any) -> tuple[Any, ...]: ...
2605-
//
2606-
// f(1, 2)
2607-
// ```
2608-
//
2609-
// Here, only the first parameter participates in the filtering process because only one
2610-
// overload has the second parameter. So, while going through the argument types, the
2611-
// second argument needs to be skipped but for the second overload both arguments map to
2612-
// the first parameter and that parameter is considered for the filtering process. This
2613-
// flag is to handle that special case of many-to-one mapping from arguments to parameters.
2614-
let mut variadic_parameter_handled = false;
2615-
2616-
for (argument_index, argument_types) in arguments.types().iter().enumerate() {
2617-
if variadic_parameter_handled {
2618-
continue;
2619-
}
2620-
for overload_index in matching_overload_indexes {
2621-
let overload = &self.overloads[*overload_index];
2622-
for (parameter_index, variadic_argument_type) in
2623-
overload.argument_matches[argument_index].iter()
2624-
{
2625-
let parameter = &overload.signature.parameters()[parameter_index];
2626-
if parameter.is_variadic() {
2627-
variadic_parameter_handled = true;
2628-
}
2629-
if !participating_parameter_indexes.contains(&parameter_index) {
2630-
continue;
2631-
}
2632-
let argument_type =
2633-
argument_types.get_for_declared_type(parameter.annotated_type());
2634-
union_argument_type_builders[parameter_index].add_in_place(
2635-
variadic_argument_type
2636-
.unwrap_or(argument_type)
2637-
.top_materialization(db),
2638-
);
2639-
}
2640-
}
2641-
}
2642-
2643-
// These only contain the top materialized argument types for the corresponding
2644-
// participating parameter indexes.
2645-
let top_materialized_argument_type = Type::heterogeneous_tuple(
2646-
db,
2647-
union_argument_type_builders
2648-
.into_iter()
2649-
.filter_map(|builder| {
2650-
if builder.is_empty() {
2651-
None
2652-
} else {
2653-
Some(builder.build())
2654-
}
2655-
}),
2656-
);
2657-
26582593
// A flag to indicate whether we've found the overload that makes the remaining overloads
26592594
// unmatched for the given argument types.
26602595
let mut filter_remaining_overloads = false;
@@ -2665,6 +2600,73 @@ impl<'db> CallableBinding<'db> {
26652600
continue;
26662601
}
26672602

2603+
let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db))
2604+
.take(max_parameter_count)
2605+
.collect::<Vec<_>>();
2606+
2607+
// The following loop is trying to construct a tuple of argument types that correspond to
2608+
// the participating parameter indexes. Considering the following example:
2609+
//
2610+
// ```python
2611+
// @overload
2612+
// def f(x: Literal[1], y: Literal[2]) -> tuple[int, int]: ...
2613+
// @overload
2614+
// def f(*args: Any) -> tuple[Any, ...]: ...
2615+
//
2616+
// f(1, 2)
2617+
// ```
2618+
//
2619+
// Here, only the first parameter participates in the filtering process because only one
2620+
// overload has the second parameter. So, while going through the argument types, the
2621+
// second argument needs to be skipped but for the second overload both arguments map to
2622+
// the first parameter and that parameter is considered for the filtering process. This
2623+
// flag is to handle that special case of many-to-one mapping from arguments to parameters.
2624+
let mut variadic_parameter_handled = false;
2625+
2626+
for (argument_index, argument_types) in arguments.types().iter().enumerate() {
2627+
if variadic_parameter_handled {
2628+
continue;
2629+
}
2630+
for overload_index in matching_overload_indexes {
2631+
let overload = &self.overloads[*overload_index];
2632+
for (parameter_index, variadic_argument_type) in
2633+
overload.argument_matches[argument_index].iter()
2634+
{
2635+
let parameter = &overload.signature.parameters()[parameter_index];
2636+
if parameter.is_variadic() {
2637+
variadic_parameter_handled = true;
2638+
}
2639+
if !participating_parameter_indexes.contains(&parameter_index) {
2640+
continue;
2641+
}
2642+
let current_parameter =
2643+
&self.overloads[*current_index].signature.parameters()[parameter_index];
2644+
let argument_type = argument_types
2645+
.get_for_declared_type(current_parameter.annotated_type());
2646+
union_argument_type_builders[parameter_index].add_in_place(
2647+
variadic_argument_type
2648+
.unwrap_or(argument_type)
2649+
.top_materialization(db),
2650+
);
2651+
}
2652+
}
2653+
}
2654+
2655+
// These only contain the top materialized argument types for the corresponding
2656+
// participating parameter indexes.
2657+
let top_materialized_argument_type = Type::heterogeneous_tuple(
2658+
db,
2659+
union_argument_type_builders
2660+
.into_iter()
2661+
.filter_map(|builder| {
2662+
if builder.is_empty() {
2663+
None
2664+
} else {
2665+
Some(builder.build())
2666+
}
2667+
}),
2668+
);
2669+
26682670
let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db))
26692671
.take(max_parameter_count)
26702672
.collect::<Vec<_>>();

0 commit comments

Comments
 (0)