Skip to content

Commit 5f11c39

Browse files
authored
[ty] fix inferring type variable from string literal argument (#23326)
## Summary Fixes astral-sh/ty#2821 Allow a string literal argument to match against an iterable parameter in typevar inference. ## Test Plan Added mdtest.
1 parent 987a7dd commit 5f11c39

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,3 +880,31 @@ def f(x: T) -> T:
880880
def g(x: S) -> S:
881881
return f(x) # error: [invalid-argument-type]
882882
```
883+
884+
## Inferring typevars in iterable parameters from literal string and bytes arguments
885+
886+
```py
887+
from typing import Iterable, TypeVar
888+
from typing_extensions import LiteralString
889+
890+
FlatT = TypeVar("FlatT")
891+
892+
def flatten(*iterables: Iterable[FlatT]) -> list[FlatT]:
893+
return [x for iterable in iterables for x in iterable]
894+
895+
def flatten_covariant(*iterables: Iterable[FlatT]) -> tuple[FlatT, ...]:
896+
return tuple(x for iterable in iterables for x in iterable)
897+
898+
reveal_type(flatten("abc", (1, 2, 3))) # revealed: list[str | int]
899+
# TODO: we could have `Literal["a", "b", "c"]` instead of `str` here
900+
reveal_type(flatten_covariant("abc", (1, 2, 3))) # revealed: tuple[str | Literal[1, 2, 3], ...]
901+
902+
def literal_string_case(literal_string: LiteralString):
903+
reveal_type(flatten(literal_string, (1, 2, 3))) # revealed: list[str | int]
904+
905+
reveal_type(flatten(b"abc")) # revealed: list[int]
906+
reveal_type(flatten(b"abc", ("x",))) # revealed: list[int | str]
907+
# TODO: we could have `Literal[97, 98, 99]` instead of `int` in the next two lines
908+
reveal_type(flatten_covariant(b"abc")) # revealed: tuple[int, ...]
909+
reveal_type(flatten_covariant(b"abc", ("x",))) # revealed: tuple[int | Literal["x"], ...]
910+
```

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,6 +2254,19 @@ impl<'db> SpecializationBuilder<'db> {
22542254
}
22552255
}
22562256

2257+
(
2258+
formal @ (Type::NominalInstance(_) | Type::ProtocolInstance(_)),
2259+
actual_literal @ (Type::StringLiteral(_)
2260+
| Type::LiteralString
2261+
| Type::BytesLiteral(_)),
2262+
) => {
2263+
// Retry specialization with the literal's fallback instance (`str` / `bytes`)
2264+
// so literal iterables can contribute to generic inference.
2265+
if let Some(actual_instance) = actual_literal.literal_fallback_instance(self.db) {
2266+
return self.infer_map_impl(formal, actual_instance, polarity, f, seen);
2267+
}
2268+
}
2269+
22572270
(formal, Type::ProtocolInstance(actual_protocol)) => {
22582271
// TODO: This will only handle protocol classes that explicit inherit
22592272
// from other generic protocol classes by listing it as a base class.

0 commit comments

Comments
 (0)