Skip to content

Commit 3cc298d

Browse files
committed
[ty] tuple syntax and fix flag/intenum/etc routing
astral-sh/ty#876
1 parent d41b7b9 commit 3cc298d

4 files changed

Lines changed: 133 additions & 10 deletions

File tree

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,13 +1288,13 @@ from ty_extensions import enum_members
12881288

12891289
Color = Enum("Color", [("RED", 1), ("GREEN", 2), ("BLUE", 3)])
12901290

1291-
# TODO: this should be supported
1292-
reveal_type(enum_members(Color)) # revealed: Unknown
1291+
# revealed: tuple[Literal["RED"], Literal["GREEN"], Literal["BLUE"]]
1292+
reveal_type(enum_members(Color))
12931293

12941294
Color = Enum("Color", (("RED", 1), ("GREEN", 2), ("BLUE", 3)))
12951295

1296-
# TODO: this should be supported
1297-
reveal_type(enum_members(Color)) # revealed: Unknown
1296+
# revealed: tuple[Literal["RED"], Literal["GREEN"], Literal["BLUE"]]
1297+
reveal_type(enum_members(Color))
12981298
```
12991299

13001300
### List of strings
@@ -1387,6 +1387,42 @@ reveal_type(Http.OK.value) # revealed: Literal[1]
13871387
reveal_type(Http.NOT_FOUND.value) # revealed: Literal[2]
13881388
```
13891389

1390+
### IntEnum function syntax
1391+
1392+
```py
1393+
from enum import IntEnum
1394+
from ty_extensions import enum_members
1395+
1396+
Color = IntEnum("Color", "RED GREEN BLUE")
1397+
1398+
# revealed: tuple[Literal["RED"], Literal["GREEN"], Literal["BLUE"]]
1399+
reveal_type(enum_members(Color))
1400+
```
1401+
1402+
### Flag function syntax
1403+
1404+
```py
1405+
from enum import Flag
1406+
from ty_extensions import enum_members
1407+
1408+
Perm = Flag("Perm", "READ WRITE EXECUTE")
1409+
1410+
# revealed: tuple[Literal["READ"], Literal["WRITE"], Literal["EXECUTE"]]
1411+
reveal_type(enum_members(Perm))
1412+
```
1413+
1414+
### IntFlag function syntax
1415+
1416+
```py
1417+
from enum import IntFlag
1418+
from ty_extensions import enum_members
1419+
1420+
Perm = IntFlag("Perm", "READ WRITE EXECUTE")
1421+
1422+
# revealed: tuple[Literal["READ"], Literal["WRITE"], Literal["EXECUTE"]]
1423+
reveal_type(enum_members(Perm))
1424+
```
1425+
13901426
## Exhaustiveness checking
13911427

13921428
## `if` statements

crates/ty_python_semantic/src/types/class/known.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ pub enum KnownClass {
6868
Member,
6969
Nonmember,
7070
StrEnum,
71+
IntEnum,
72+
Flag,
73+
IntFlag,
7174
// abc
7275
ABCMeta,
7376
// Types
@@ -225,6 +228,9 @@ impl KnownClass {
225228
| Self::Member
226229
| Self::Nonmember
227230
| Self::StrEnum
231+
| Self::IntEnum
232+
| Self::Flag
233+
| Self::IntFlag
228234
| Self::ABCMeta
229235
| Self::Iterable
230236
| Self::Iterator
@@ -294,6 +300,9 @@ impl KnownClass {
294300
| KnownClass::Member
295301
| KnownClass::Nonmember
296302
| KnownClass::StrEnum
303+
| KnownClass::IntEnum
304+
| KnownClass::Flag
305+
| KnownClass::IntFlag
297306
| KnownClass::ABCMeta
298307
| KnownClass::GenericAlias
299308
| KnownClass::ModuleType
@@ -385,6 +394,9 @@ impl KnownClass {
385394
| KnownClass::Member
386395
| KnownClass::Nonmember
387396
| KnownClass::StrEnum
397+
| KnownClass::IntEnum
398+
| KnownClass::Flag
399+
| KnownClass::IntFlag
388400
| KnownClass::ABCMeta
389401
| KnownClass::GenericAlias
390402
| KnownClass::ModuleType
@@ -476,6 +488,9 @@ impl KnownClass {
476488
| KnownClass::Member
477489
| KnownClass::Nonmember
478490
| KnownClass::StrEnum
491+
| KnownClass::IntEnum
492+
| KnownClass::Flag
493+
| KnownClass::IntFlag
479494
| KnownClass::ABCMeta
480495
| KnownClass::GenericAlias
481496
| KnownClass::ModuleType
@@ -608,6 +623,9 @@ impl KnownClass {
608623
| Self::Member
609624
| Self::Nonmember
610625
| Self::StrEnum
626+
| Self::IntEnum
627+
| Self::Flag
628+
| Self::IntFlag
611629
| Self::ABCMeta
612630
| Self::Super
613631
| Self::StdlibAlias
@@ -669,6 +687,9 @@ impl KnownClass {
669687
| KnownClass::Member
670688
| KnownClass::Nonmember
671689
| KnownClass::StrEnum
690+
| KnownClass::IntEnum
691+
| KnownClass::Flag
692+
| KnownClass::IntFlag
672693
| KnownClass::ABCMeta
673694
| KnownClass::GenericAlias
674695
| KnownClass::ModuleType
@@ -796,6 +817,9 @@ impl KnownClass {
796817
Self::Member => "member",
797818
Self::Nonmember => "nonmember",
798819
Self::StrEnum => "StrEnum",
820+
Self::IntEnum => "IntEnum",
821+
Self::Flag => "Flag",
822+
Self::IntFlag => "IntFlag",
799823
Self::ABCMeta => "ABCMeta",
800824
Self::Super => "super",
801825
Self::Iterable => "Iterable",
@@ -1126,7 +1150,10 @@ impl KnownClass {
11261150
| Self::Auto
11271151
| Self::Member
11281152
| Self::Nonmember
1129-
| Self::StrEnum => KnownModule::Enum,
1153+
| Self::StrEnum
1154+
| Self::IntEnum
1155+
| Self::Flag
1156+
| Self::IntFlag => KnownModule::Enum,
11301157
Self::GenericAlias
11311158
| Self::ModuleType
11321159
| Self::FunctionType
@@ -1282,6 +1309,9 @@ impl KnownClass {
12821309
| Self::Member
12831310
| Self::Nonmember
12841311
| Self::StrEnum
1312+
| Self::IntEnum
1313+
| Self::Flag
1314+
| Self::IntFlag
12851315
| Self::ABCMeta
12861316
| Self::Super
12871317
| Self::NewType
@@ -1377,6 +1407,9 @@ impl KnownClass {
13771407
| Self::Member
13781408
| Self::Nonmember
13791409
| Self::StrEnum
1410+
| Self::IntEnum
1411+
| Self::Flag
1412+
| Self::IntFlag
13801413
| Self::ABCMeta
13811414
| Self::Super
13821415
| Self::UnionType
@@ -1478,6 +1511,9 @@ impl KnownClass {
14781511
"StrEnum" if Program::get(db).python_version(db) >= PythonVersion::PY311 => {
14791512
&[Self::StrEnum]
14801513
}
1514+
"IntEnum" => &[Self::IntEnum],
1515+
"Flag" => &[Self::Flag],
1516+
"IntFlag" => &[Self::IntFlag],
14811517
"auto" => &[Self::Auto],
14821518
"member" => &[Self::Member],
14831519
"nonmember" => &[Self::Nonmember],
@@ -1564,6 +1600,9 @@ impl KnownClass {
15641600
| Self::Member
15651601
| Self::Nonmember
15661602
| Self::StrEnum
1603+
| Self::IntEnum
1604+
| Self::Flag
1605+
| Self::IntFlag
15671606
| Self::ABCMeta
15681607
| Self::Super
15691608
| Self::NotImplementedType

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,9 @@ pub(crate) fn enum_metadata<'db>(
424424
})
425425
}
426426

427-
/// Iterates over parent enum classes in the MRO, skipping known classes
428-
/// (like `Enum`, `StrEnum`, etc.) that we handle specially.
427+
/// Iterates over parent enum classes in the MRO, skipping the base `Enum`
428+
/// class (which is handled specially) but including known subclasses like
429+
/// `IntEnum`, `StrEnum`, `Flag`, and `IntFlag` that may declare `_value_`.
429430
fn iter_parent_enum_classes<'db>(
430431
db: &'db dyn Db,
431432
class: StaticClassLiteral<'db>,
@@ -436,7 +437,13 @@ fn iter_parent_enum_classes<'db>(
436437
.filter_map(ClassBase::into_class)
437438
.filter_map(move |class_type| {
438439
let base = class_type.class_literal(db).as_static()?;
439-
(base.known(db).is_none() && is_enum_class_by_inheritance(db, base)).then_some(base)
440+
let dominated_by_enum_infra = base.known(db).is_some_and(|k| {
441+
!matches!(
442+
k,
443+
KnownClass::IntEnum | KnownClass::Flag | KnownClass::IntFlag
444+
)
445+
});
446+
(!dominated_by_enum_infra && is_enum_class_by_inheritance(db, base)).then_some(base)
440447
})
441448
}
442449

crates/ty_python_semantic/src/types/infer/builder/enum_call.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@ pub(crate) fn enum_functional_call_base<'db>(db: &'db dyn Db, ty: Type<'db>) ->
1717
let ClassLiteral::Static(cls) = ty.as_class_literal()? else {
1818
return None;
1919
};
20-
cls.known(db)
21-
.filter(|k| matches!(k, KnownClass::Enum | KnownClass::StrEnum))
20+
cls.known(db).filter(|k| {
21+
matches!(
22+
k,
23+
KnownClass::Enum
24+
| KnownClass::StrEnum
25+
| KnownClass::IntEnum
26+
| KnownClass::Flag
27+
| KnownClass::IntFlag
28+
)
29+
})
2230
}
2331

2432
impl<'db> TypeInferenceBuilder<'db, '_> {
@@ -32,6 +40,19 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
3240
let args = &call_expr.arguments.args;
3341
let keywords = &call_expr.arguments.keywords;
3442

43+
// bail out on unknown keywords so normal overload resolution can diagnose them
44+
let has_unknown_keyword = keywords.iter().any(|kw| {
45+
kw.arg.as_ref().is_some_and(|name| {
46+
!matches!(
47+
name.as_str(),
48+
"value" | "names" | "start" | "type" | "module" | "qualname" | "boundary"
49+
)
50+
})
51+
});
52+
if has_unknown_keyword {
53+
return None;
54+
}
55+
3556
let value_kw = call_expr.arguments.find_keyword("value");
3657
let names_kw = call_expr.arguments.find_keyword("names");
3758
let start_kw = call_expr.arguments.find_keyword("start");
@@ -172,13 +193,33 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
172193
for (i, elt) in elts.iter().enumerate() {
173194
let ty = self.expression_type(elt);
174195
if let Some(string_lit) = ty.as_string_literal() {
196+
// string element: auto-assign incrementing int value
175197
let member_name = Name::new(string_lit.value(db));
176198
let v = Type::int_literal(start + i64::try_from(i).unwrap_or(0));
177199
members.push((member_name, v));
200+
} else if let Some((name, value)) = self.extract_enum_tuple_entry(elt) {
201+
members.push((name, value));
178202
} else {
179203
return (vec![], false);
180204
}
181205
}
182206
(members, true)
183207
}
208+
209+
/// Extract a `(name, value)` pair from a tuple element like `("RED", 1)`.
210+
fn extract_enum_tuple_entry(&mut self, elt: &ast::Expr) -> Option<(Name, Type<'db>)> {
211+
let pair = match elt {
212+
ast::Expr::Tuple(tup) => &tup.elts,
213+
ast::Expr::List(list) => &list.elts,
214+
_ => return None,
215+
};
216+
let [name_expr, value_expr] = &**pair else {
217+
return None;
218+
};
219+
let db = self.db();
220+
let name_ty = self.expression_type(name_expr);
221+
let name = Name::new(name_ty.as_string_literal()?.value(db));
222+
let value = self.expression_type(value_expr);
223+
Some((name, value))
224+
}
184225
}

0 commit comments

Comments
 (0)