Skip to content

Commit 4ff86c7

Browse files
authored
[ty] Support multi-inference through type aliases (#25245)
Resolves astral-sh/ty#3487.
1 parent 04e5894 commit 4ff86c7

3 files changed

Lines changed: 68 additions & 23 deletions

File tree

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ python-version = "3.12"
435435
```
436436

437437
```py
438-
from typing import Literal, Sequence
438+
from typing import Literal, Mapping, Sequence
439439

440440
def f[T](x: T) -> list[T]:
441441
return [x]
@@ -482,9 +482,20 @@ reveal_type(x12) # revealed: list[str | None]
482482
x13: dict[str, list[int | None]] | dict[str, list[str | None]] = {"a": ["b"]}
483483
reveal_type(x13) # revealed: dict[str, list[str | None]]
484484

485-
x14 = [{"a": [1], "b": 1}, {"a": [1]}]
486-
x14.append(reveal_type({"b": 1})) # revealed: dict[str, list[int] | int]
487-
reveal_type(x14) # revealed: list[dict[str, list[int] | int] | dict[str, list[int]]]
485+
x14: Mapping[str, list[int | None]] | Mapping[str, list[str | None]] = {"a": ["b"]}
486+
reveal_type(x14) # revealed: dict[str, list[str | None]]
487+
488+
x15 = [{"a": [1], "b": 1}, {"a": [1]}]
489+
x15.append(reveal_type({"b": 1})) # revealed: dict[str, list[int] | int]
490+
reveal_type(x15) # revealed: list[dict[str, list[int] | int] | dict[str, list[int]]]
491+
492+
type EitherList = list[int | str] | list[int | None]
493+
494+
x16: EitherList = [None, None]
495+
reveal_type(x16) # revealed: list[int | None]
496+
497+
x17: EitherList = ["1", "2", "3"]
498+
reveal_type(x17) # revealed: list[int | str]
488499
```
489500

490501
## Annotations influence generic call argument inference
@@ -855,7 +866,7 @@ x19: dict[int, dict[str, int]] = defaultdict(dict)
855866
reveal_type(x19) # revealed: defaultdict[int, dict[str, int]]
856867
```
857868

858-
## Narrow generic unions
869+
## Narrow union declared type for generic calls
859870

860871
```toml
861872
[environment]
@@ -907,6 +918,37 @@ def _(target: TargetWithTD):
907918
reveal_type(target) # revealed: (TD, /) -> None
908919
```
909920

921+
```py
922+
from typing import Mapping, Sequence
923+
924+
x1: list[int | str] | list[int | None] = list((1, 2, 3))
925+
reveal_type(x1) # revealed: list[int | str]
926+
927+
x2: Sequence[int | str] | Sequence[int | None] = list((1, 2, 3))
928+
reveal_type(x2) # revealed: list[int]
929+
930+
x3: list[int] | list[int | None] | list[str | None] = list(("1", "2"))
931+
reveal_type(x3) # revealed: list[str | None]
932+
933+
x4: dict[str, list[int | None]] | dict[str, list[str | None]] = dict([("a", ["b"])])
934+
reveal_type(x4) # revealed: dict[str, list[str | None]]
935+
936+
x5: Mapping[str, list[int | None]] | Mapping[str, list[str | None]] = dict([("a", ["b"])])
937+
reveal_type(x5) # revealed: dict[str, list[str | None]]
938+
939+
x6 = [dict([("a", list((1,))), ("b", 1)]), dict([("a", list((1,)))])]
940+
x6.append(reveal_type(dict([("b", 1)]))) # revealed: dict[str, list[int] | int]
941+
reveal_type(x6) # revealed: list[dict[str, list[int] | int] | dict[str, list[int]]]
942+
943+
type EitherList = list[int | str] | list[int | None]
944+
945+
x7: EitherList = list((None, None))
946+
reveal_type(x7) # revealed: list[int | None]
947+
948+
x8: EitherList = list(("1", "2", "3"))
949+
reveal_type(x8) # revealed: list[int | str]
950+
```
951+
910952
## Prefer the inferred type of non-generic classes
911953

912954
```toml

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,8 @@ impl<'db> TypeContext<'db> {
590590
Self { annotation }
591591
}
592592

593-
// If the type annotation is a specialized instance of the given `KnownClass`, returns the
594-
// specialization.
593+
/// If the type annotation is a specialized instance of the given `KnownClass`, returns the
594+
/// specialization.
595595
fn known_specialization(
596596
&self,
597597
db: &'db dyn Db,
@@ -611,6 +611,16 @@ impl<'db> TypeContext<'db> {
611611
self.annotation
612612
.is_some_and(|ty| ty.is_typealias_special_form())
613613
}
614+
615+
/// If the type annotation is a union, returns the target elements that it can be narrowed to.
616+
pub(crate) fn narrow_targets(&self, db: &'db dyn Db) -> Option<&[Type<'db>]> {
617+
self.annotation
618+
.and_then(|ty| ty.as_union_like(db))
619+
// TODO: We could theoretically attempt to narrow to every element of
620+
// the power set of this union. However, this leads to an exponential
621+
// explosion of inference attempts, and is rarely needed in practice.
622+
.map(|union| union.elements(db))
623+
}
614624
}
615625

616626
impl<'db> From<Type<'db>> for TypeContext<'db> {

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5181,16 +5181,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
51815181
.any(|overload| overload.signature.generic_context.is_some());
51825182

51835183
// If the type context is a union, attempt to narrow to a specific element.
5184-
let narrow_targets: &[_] = match call_expression_tcx.annotation {
5185-
// TODO: We could theoretically attempt to narrow to every element of
5186-
// the power set of this union. However, this leads to an exponential
5187-
// explosion of inference attempts, and is rarely needed in practice.
5188-
//
5184+
let narrow_targets = call_expression_tcx
5185+
.narrow_targets(db)
51895186
// We only need to attempt narrowing on generic calls, otherwise the type
51905187
// context has no effect.
5191-
Some(Type::Union(union)) if has_generic_context => union.elements(db),
5192-
_ => &[],
5193-
};
5188+
.filter(|_| has_generic_context)
5189+
.unwrap_or_default();
51945190

51955191
let mut try_narrow = |narrowed_ty: Type<'db>| {
51965192
// Short-circuit if there is no overload with a matching return type.
@@ -6266,12 +6262,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62666262
) -> Option<Type<'db>> {
62676263
let db = self.db();
62686264

6269-
// If the type context is a union, attempt to narrow to a specific element.
6270-
let narrow_targets: &[_] = match tcx.annotation {
6271-
Some(Type::Union(union)) => union.elements(db),
6272-
_ => &[],
6273-
};
6274-
62756265
let mut try_narrow = |narrowed_ty| {
62766266
let mut speculative_builder = self.speculate();
62776267

@@ -6293,8 +6283,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62936283
Some(inferred_ty)
62946284
};
62956285

6296-
for narrowed_ty in narrow_targets
6297-
.iter()
6286+
// If the type context is a union, attempt to narrow to a specific element.
6287+
for narrowed_ty in tcx
6288+
.narrow_targets(db)
6289+
.into_iter()
6290+
.flatten()
62986291
.filter(|ty| ty.class_specialization(db).is_some())
62996292
{
63006293
if let Some(result) = try_narrow(*narrowed_ty) {

0 commit comments

Comments
 (0)