Skip to content

Commit 0042529

Browse files
committed
[red-knot] Type inference for str.startswith
1 parent 1aad180 commit 0042529

5 files changed

Lines changed: 105 additions & 6 deletions

File tree

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# `str.startswith`
2+
3+
We special-case `str.startswith` to allow inference of precise Boolean literal types, because those
4+
are used in [`sys.platform` checks].
5+
6+
```py
7+
reveal_type("abc".startswith("")) # revealed: Literal[True]
8+
reveal_type("abc".startswith("a")) # revealed: Literal[True]
9+
reveal_type("abc".startswith("ab")) # revealed: Literal[True]
10+
reveal_type("abc".startswith("abc")) # revealed: Literal[True]
11+
12+
reveal_type("abc".startswith("abcd")) # revealed: Literal[False]
13+
reveal_type("abc".startswith("bc")) # revealed: Literal[False]
14+
15+
reveal_type("AbC".startswith("")) # revealed: Literal[True]
16+
reveal_type("AbC".startswith("A")) # revealed: Literal[True]
17+
reveal_type("AbC".startswith("Ab")) # revealed: Literal[True]
18+
reveal_type("AbC".startswith("AbC")) # revealed: Literal[True]
19+
20+
reveal_type("AbC".startswith("a")) # revealed: Literal[False]
21+
reveal_type("AbC".startswith("aB")) # revealed: Literal[False]
22+
23+
reveal_type("".startswith("")) # revealed: Literal[True]
24+
25+
reveal_type("".startswith(" ")) # revealed: Literal[False]
26+
```
27+
28+
Make sure that we fall back to `bool` for more complex cases:
29+
30+
```py
31+
reveal_type("abc".startswith("b", 1)) # revealed: bool
32+
reveal_type("abc".startswith("bc", 1, 3)) # revealed: bool
33+
34+
reveal_type("abc".startswith(("a", "x"))) # revealed: bool
35+
```
36+
37+
And similiarly, we should still infer `bool` if the instance or the prefix are not string literals:
38+
39+
```py
40+
from typing_extensions import LiteralString
41+
42+
def _(string_instance: str, literalstring: LiteralString):
43+
reveal_type(string_instance.startswith("a")) # revealed: bool
44+
reveal_type(literalstring.startswith("a")) # revealed: bool
45+
46+
reveal_type("a".startswith(string_instance)) # revealed: bool
47+
reveal_type("a".startswith(literalstring)) # revealed: bool
48+
```
49+
50+
[`sys.platform` checks]: https://docs.python.org/3/library/sys.html#sys.platform

crates/red_knot_python_semantic/resources/mdtest/sys_platform.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ reveal_type(sys.platform) # revealed: Literal["linux"]
3131

3232
## Testing for a specific platform
3333

34-
### Exact comparison
35-
3634
```toml
3735
[environment]
3836
python-platform = "freebsd8"
3937
```
4038

39+
### Exact comparison
40+
4141
```py
4242
import sys
4343

@@ -48,11 +48,11 @@ reveal_type(sys.platform == "linux") # revealed: Literal[False]
4848
### Substring comparison
4949

5050
It is [recommended](https://docs.python.org/3/library/sys.html#sys.platform) to use
51-
`sys.platform.startswith(...)` for platform checks. This is not yet supported in type inference:
51+
`sys.platform.startswith(...)` for platform checks:
5252

5353
```py
5454
import sys
5555

56-
reveal_type(sys.platform.startswith("freebsd")) # revealed: bool
57-
reveal_type(sys.platform.startswith("linux")) # revealed: bool
56+
reveal_type(sys.platform.startswith("freebsd")) # revealed: Literal[True]
57+
reveal_type(sys.platform.startswith("linux")) # revealed: Literal[False]
5858
```

crates/red_knot_python_semantic/src/types.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,7 @@ impl<'db> Type<'db> {
20002000
))
20012001
.into(),
20022002
),
2003+
20032004
// TODO:
20042005
// We currently hard-code the knowledge that the following known classes are not
20052006
// descriptors, i.e. that they have no `__get__` method. This is not wrong and
@@ -2508,6 +2509,10 @@ impl<'db> Type<'db> {
25082509
Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)),
25092510
)
25102511
.into(),
2512+
Type::StringLiteral(literal) if name == "startswith" => Symbol::bound(
2513+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(literal)),
2514+
)
2515+
.into(),
25112516

25122517
Type::ClassLiteral(class)
25132518
if name == "__get__" && class.is_known(db, KnownClass::FunctionType) =>
@@ -3043,7 +3048,7 @@ impl<'db> Type<'db> {
30433048
WrapperDescriptorKind::PropertyDunderGet => {
30443049
KnownClass::Property.to_instance(db)
30453050
}
3046-
WrapperDescriptorKind::PropertyDunderSet => {
3051+
_ => {
30473052
unreachable!("Not part of outer match pattern")
30483053
}
30493054
};
@@ -3112,6 +3117,34 @@ impl<'db> Type<'db> {
31123117
))
31133118
}
31143119

3120+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) => {
3121+
Signatures::single(CallableSignature::single(
3122+
self,
3123+
Signature::new(
3124+
Parameters::new([
3125+
Parameter::positional_only(Some(Name::new_static("prefix")))
3126+
.with_annotated_type(UnionType::from_elements(
3127+
db,
3128+
[
3129+
KnownClass::Str.to_instance(db),
3130+
// TODO: tuple[str, ...]
3131+
KnownClass::Tuple.to_instance(db),
3132+
],
3133+
)),
3134+
Parameter::positional_only(Some(Name::new_static("start")))
3135+
// TODO: SupportsIndex | None
3136+
.with_annotated_type(Type::object(db))
3137+
.with_default_type(Type::none(db)),
3138+
Parameter::positional_only(Some(Name::new_static("end")))
3139+
// TODO: SupportsIndex | None
3140+
.with_annotated_type(Type::object(db))
3141+
.with_default_type(Type::none(db)),
3142+
]),
3143+
Some(KnownClass::Bool.to_instance(db)),
3144+
),
3145+
))
3146+
}
3147+
31153148
Type::FunctionLiteral(function_type) => match function_type.known(db) {
31163149
Some(
31173150
KnownFunction::IsEquivalentTo
@@ -4238,6 +4271,7 @@ impl<'db> Type<'db> {
42384271
| Type::AlwaysTruthy
42394272
| Type::AlwaysFalsy
42404273
| Type::WrapperDescriptor(_)
4274+
| Type::MethodWrapper(MethodWrapperKind::StrStartswith(_))
42414275
| Type::ModuleLiteral(_)
42424276
// A non-generic class never needs to be specialized. A generic class is specialized
42434277
// explicitly (via a subscript expression) or implicitly (via a call), and not because
@@ -6155,6 +6189,8 @@ pub enum MethodWrapperKind<'db> {
61556189
PropertyDunderGet(PropertyInstanceType<'db>),
61566190
/// Method wrapper for `some_property.__set__`
61576191
PropertyDunderSet(PropertyInstanceType<'db>),
6192+
/// Method wrapper for `str.startswith`
6193+
StrStartswith(StringLiteralType<'db>),
61586194
}
61596195

61606196
/// Represents a specific instance of `types.WrapperDescriptorType`

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,16 @@ impl<'db> Bindings<'db> {
395395
}
396396
}
397397

398+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(literal)) => {
399+
if let [Some(Type::StringLiteral(prefix)), None, None] =
400+
overload.parameter_types()
401+
{
402+
overload.set_return_type(Type::BooleanLiteral(
403+
literal.value(db).starts_with(&**prefix.value(db)),
404+
));
405+
}
406+
}
407+
398408
Type::BoundMethod(bound_method)
399409
if bound_method.self_instance(db).is_property_instance() =>
400410
{

crates/red_knot_python_semantic/src/types/display.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ impl Display for DisplayRepresentation<'_> {
145145
Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(_)) => {
146146
write!(f, "<method-wrapper `__set__` of `property` object>",)
147147
}
148+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) => {
149+
write!(f, "<method-wrapper `startswith` of `str` object>",)
150+
}
148151
Type::WrapperDescriptor(kind) => {
149152
let (method, object) = match kind {
150153
WrapperDescriptorKind::FunctionTypeDunderGet => ("__get__", "function"),

0 commit comments

Comments
 (0)