Skip to content

Commit f35c255

Browse files
committed
[red-knot] Type inference for str.startswith
1 parent 1aad180 commit f35c255

4 files changed

Lines changed: 72 additions & 6 deletions

File tree

crates/red_knot_python_semantic/resources/mdtest/sys_platform.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ reveal_type(sys.platform) # revealed: Literal["linux"]
3131

3232
## Testing for a specific platform
3333

34-
### Exact comparison
35-
3634
```toml
3735
[environment]
3836
python-platform = "freebsd8"
3937
```
4038

39+
### Exact comparison
40+
4141
```py
4242
import sys
4343

@@ -48,11 +48,11 @@ reveal_type(sys.platform == "linux") # revealed: Literal[False]
4848
### Substring comparison
4949

5050
It is [recommended](https://docs.python.org/3/library/sys.html#sys.platform) to use
51-
`sys.platform.startswith(...)` for platform checks. This is not yet supported in type inference:
51+
`sys.platform.startswith(...)` for platform checks:
5252

5353
```py
5454
import sys
5555

56-
reveal_type(sys.platform.startswith("freebsd")) # revealed: bool
57-
reveal_type(sys.platform.startswith("linux")) # revealed: bool
56+
reveal_type(sys.platform.startswith("freebsd")) # revealed: Literal[True]
57+
reveal_type(sys.platform.startswith("linux")) # revealed: Literal[False]
5858
```

crates/red_knot_python_semantic/src/types.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,12 @@ impl<'db> Type<'db> {
20002000
))
20012001
.into(),
20022002
),
2003+
(Some(KnownClass::Str), "startswith") => Some(
2004+
Symbol::bound(Type::WrapperDescriptor(
2005+
WrapperDescriptorKind::StrStartswith,
2006+
))
2007+
.into(),
2008+
),
20032009
// TODO:
20042010
// We currently hard-code the knowledge that the following known classes are not
20052011
// descriptors, i.e. that they have no `__get__` method. This is not wrong and
@@ -2508,6 +2514,10 @@ impl<'db> Type<'db> {
25082514
Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)),
25092515
)
25102516
.into(),
2517+
Type::StringLiteral(literal) if name == "startswith" => Symbol::bound(
2518+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(literal)),
2519+
)
2520+
.into(),
25112521

25122522
Type::ClassLiteral(class)
25132523
if name == "__get__" && class.is_known(db, KnownClass::FunctionType) =>
@@ -2533,6 +2543,14 @@ impl<'db> Type<'db> {
25332543
))
25342544
.into()
25352545
}
2546+
Type::ClassLiteral(class)
2547+
if name == "startswith" && class.is_known(db, KnownClass::Str) =>
2548+
{
2549+
Symbol::bound(Type::WrapperDescriptor(
2550+
WrapperDescriptorKind::StrStartswith,
2551+
))
2552+
.into()
2553+
}
25362554
Type::BoundMethod(bound_method) => match name_str {
25372555
"__self__" => Symbol::bound(bound_method.self_instance(db)).into(),
25382556
"__func__" => {
@@ -3043,7 +3061,7 @@ impl<'db> Type<'db> {
30433061
WrapperDescriptorKind::PropertyDunderGet => {
30443062
KnownClass::Property.to_instance(db)
30453063
}
3046-
WrapperDescriptorKind::PropertyDunderSet => {
3064+
_ => {
30473065
unreachable!("Not part of outer match pattern")
30483066
}
30493067
};
@@ -3112,6 +3130,34 @@ impl<'db> Type<'db> {
31123130
))
31133131
}
31143132

3133+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) => {
3134+
Signatures::single(CallableSignature::single(
3135+
self,
3136+
Signature::new(
3137+
Parameters::new([
3138+
Parameter::positional_only(Some(Name::new_static("prefix")))
3139+
.with_annotated_type(UnionType::from_elements(
3140+
db,
3141+
[
3142+
KnownClass::Str.to_instance(db),
3143+
// TODO: tuple[str, ...]
3144+
KnownClass::Tuple.to_instance(db),
3145+
],
3146+
)),
3147+
Parameter::positional_only(Some(Name::new_static("start")))
3148+
// TODO: SupportsIndex | None
3149+
.with_annotated_type(Type::object(db))
3150+
.with_default_type(Type::none(db)),
3151+
Parameter::positional_only(Some(Name::new_static("end")))
3152+
// TODO: SupportsIndex | None
3153+
.with_annotated_type(Type::object(db))
3154+
.with_default_type(Type::none(db)),
3155+
]),
3156+
Some(KnownClass::Bool.to_instance(db)),
3157+
),
3158+
))
3159+
}
3160+
31153161
Type::FunctionLiteral(function_type) => match function_type.known(db) {
31163162
Some(
31173163
KnownFunction::IsEquivalentTo
@@ -4238,6 +4284,7 @@ impl<'db> Type<'db> {
42384284
| Type::AlwaysTruthy
42394285
| Type::AlwaysFalsy
42404286
| Type::WrapperDescriptor(_)
4287+
| Type::MethodWrapper(MethodWrapperKind::StrStartswith(_))
42414288
| Type::ModuleLiteral(_)
42424289
// A non-generic class never needs to be specialized. A generic class is specialized
42434290
// explicitly (via a subscript expression) or implicitly (via a call), and not because
@@ -6155,6 +6202,8 @@ pub enum MethodWrapperKind<'db> {
61556202
PropertyDunderGet(PropertyInstanceType<'db>),
61566203
/// Method wrapper for `some_property.__set__`
61576204
PropertyDunderSet(PropertyInstanceType<'db>),
6205+
/// Method wrapper for `str.startswith`
6206+
StrStartswith(StringLiteralType<'db>),
61586207
}
61596208

61606209
/// Represents a specific instance of `types.WrapperDescriptorType`
@@ -6166,6 +6215,8 @@ pub enum WrapperDescriptorKind {
61666215
PropertyDunderGet,
61676216
/// `property.__set__`
61686217
PropertyDunderSet,
6218+
/// `str.startswith`
6219+
StrStartswith,
61696220
}
61706221

61716222
#[salsa::interned(debug)]

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,16 @@ impl<'db> Bindings<'db> {
395395
}
396396
}
397397

398+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(literal)) => {
399+
if let [Some(Type::StringLiteral(prefix)), None, None] =
400+
overload.parameter_types()
401+
{
402+
overload.set_return_type(Type::BooleanLiteral(
403+
literal.value(db).starts_with(&**prefix.value(db)),
404+
));
405+
}
406+
}
407+
398408
Type::BoundMethod(bound_method)
399409
if bound_method.self_instance(db).is_property_instance() =>
400410
{
@@ -1088,6 +1098,7 @@ impl<'db> CallableDescription<'db> {
10881098
WrapperDescriptorKind::FunctionTypeDunderGet => "FunctionType.__get__",
10891099
WrapperDescriptorKind::PropertyDunderGet => "property.__get__",
10901100
WrapperDescriptorKind::PropertyDunderSet => "property.__set__",
1101+
WrapperDescriptorKind::StrStartswith => "str.startswith",
10911102
},
10921103
}),
10931104
_ => None,

crates/red_knot_python_semantic/src/types/display.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,15 @@ impl Display for DisplayRepresentation<'_> {
145145
Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(_)) => {
146146
write!(f, "<method-wrapper `__set__` of `property` object>",)
147147
}
148+
Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) => {
149+
write!(f, "<method-wrapper `startswith` of `str` object>",)
150+
}
148151
Type::WrapperDescriptor(kind) => {
149152
let (method, object) = match kind {
150153
WrapperDescriptorKind::FunctionTypeDunderGet => ("__get__", "function"),
151154
WrapperDescriptorKind::PropertyDunderGet => ("__get__", "property"),
152155
WrapperDescriptorKind::PropertyDunderSet => ("__set__", "property"),
156+
WrapperDescriptorKind::StrStartswith => ("startswith", "str"),
153157
};
154158
write!(f, "<wrapper-descriptor `{method}` of `{object}` objects>")
155159
}

0 commit comments

Comments
 (0)