Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class Sub(Base[Sub]): ...
reveal_type(Sub) # revealed: Literal[Sub]
```

A similar case can work in a non-stub file, if forward references are stringified:

`string_annotation.py`:

```py
Expand All @@ -174,6 +176,8 @@ class Sub(Base["Sub"]): ...
reveal_type(Sub) # revealed: Literal[Sub]
```

In a non-stub file, without stringified forward references, this raises a `NameError`:

`bare_annotation.py`:

```py
Expand All @@ -184,5 +188,13 @@ class Base[T]: ...
class Sub(Base[Sub]): ...
```

## Another cyclic case

```pyi
# TODO no error (generics)
# error: [invalid-base]
class Derived[T](list[Derived[T]]): ...
```

[crtp]: https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
[f-bound]: https://en.wikipedia.org/wiki/Bounded_quantification#F-bounded_quantification
8 changes: 8 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,7 @@ impl<'db> Type<'db> {
/// of union and intersection types.
#[salsa::tracked]
fn class_member(self, db: &'db dyn Db, name: Name) -> SymbolAndQualifiers<'db> {
tracing::trace!("class_member: {}.{}", self.display(db), name);
match self {
Type::Union(union) => union
.map_with_boundness_and_qualifiers(db, |elem| elem.class_member(db, name.clone())),
Expand Down Expand Up @@ -1673,6 +1674,12 @@ impl<'db> Type<'db> {
instance: Type<'db>,
owner: Type<'db>,
) -> Option<(Type<'db>, AttributeKind)> {
tracing::trace!(
"try_call_dunder_get: {}, {}, {}",
self.display(db),
instance.display(db),
owner.display(db)
);
let descr_get = self.class_member(db, "__get__".into()).symbol;

if let Symbol::Type(descr_get, descr_get_boundness) = descr_get {
Expand Down Expand Up @@ -1905,6 +1912,7 @@ impl<'db> Type<'db> {
name: Name,
policy: MemberLookupPolicy,
) -> SymbolAndQualifiers<'db> {
tracing::trace!("member_lookup_with_policy: {}.{}", self.display(db), name);
if name == "__class__" {
return Symbol::bound(self.to_meta_type(db)).into();
}
Expand Down
57 changes: 54 additions & 3 deletions crates/red_knot_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,50 @@ pub struct Class<'db> {
pub(crate) known: Option<KnownClass>,
}

fn explicit_bases_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &[Type<'db>],
_count: u32,
_self: Class<'db>,
) -> salsa::CycleRecoveryAction<Box<[Type<'db>]>> {
salsa::CycleRecoveryAction::Iterate
}

fn explicit_bases_cycle_initial<'db>(_db: &'db dyn Db, _self: Class<'db>) -> Box<[Type<'db>]> {
Box::default()
}

fn try_mro_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Result<Mro<'db>, MroError<'db>>,
_count: u32,
_self: Class<'db>,
) -> salsa::CycleRecoveryAction<Result<Mro<'db>, MroError<'db>>> {
salsa::CycleRecoveryAction::Iterate
}

#[allow(clippy::unnecessary_wraps)]
fn try_mro_cycle_initial<'db>(
db: &'db dyn Db,
self_: Class<'db>,
) -> Result<Mro<'db>, MroError<'db>> {
Ok(Mro::from_error(db, self_))
}

#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
fn inheritance_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Option<InheritanceCycle>,
_count: u32,
_self: Class<'db>,
) -> salsa::CycleRecoveryAction<Option<InheritanceCycle>> {
salsa::CycleRecoveryAction::Iterate
}

fn inheritance_cycle_initial<'db>(_db: &'db dyn Db, _self: Class<'db>) -> Option<InheritanceCycle> {
None
}

#[salsa::tracked]
impl<'db> Class<'db> {
/// Return `true` if this class represents `known_class`
Expand Down Expand Up @@ -81,8 +125,9 @@ impl<'db> Class<'db> {
.map(|ClassLiteralType { class }| class)
}

#[salsa::tracked(return_ref)]
#[salsa::tracked(return_ref, cycle_fn=explicit_bases_cycle_recover, cycle_initial=explicit_bases_cycle_initial)]
fn explicit_bases_query(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
tracing::trace!("Class::explicit_bases_query: {}", self.name(db));
let class_stmt = self.node(db);

let class_definition = semantic_index(db, self.file(db)).definition(class_stmt);
Expand Down Expand Up @@ -110,6 +155,7 @@ impl<'db> Class<'db> {
/// Return the types of the decorators on this class
#[salsa::tracked(return_ref)]
fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
tracing::trace!("Class::decorators: {}", self.name(db));
let class_stmt = self.node(db);
if class_stmt.decorator_list.is_empty() {
return Box::new([]);
Expand Down Expand Up @@ -141,8 +187,9 @@ impl<'db> Class<'db> {
/// attribute on a class at runtime.
///
/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order
#[salsa::tracked(return_ref)]
#[salsa::tracked(return_ref, cycle_fn=try_mro_cycle_recover, cycle_initial=try_mro_cycle_initial)]
pub(super) fn try_mro(self, db: &'db dyn Db) -> Result<Mro<'db>, MroError<'db>> {
tracing::trace!("Class::try_mro: {}", self.name(db));
Mro::of_class(db, self)
}

Expand Down Expand Up @@ -199,6 +246,8 @@ impl<'db> Class<'db> {
/// Return the metaclass of this class, or an error if the metaclass cannot be inferred.
#[salsa::tracked]
pub(super) fn try_metaclass(self, db: &'db dyn Db) -> Result<Type<'db>, MetaclassError<'db>> {
tracing::trace!("Class::try_metaclass: {}", self.name(db));

// Identify the class's own metaclass (or take the first base class's metaclass).
let mut base_classes = self.fully_static_explicit_bases(db).peekable();

Expand Down Expand Up @@ -662,7 +711,7 @@ impl<'db> Class<'db> {
///
/// A class definition like this will fail at runtime,
/// but we must be resilient to it or we could panic.
#[salsa::tracked]
#[salsa::tracked(cycle_fn=inheritance_cycle_recover, cycle_initial=inheritance_cycle_initial)]
pub(super) fn inheritance_cycle(self, db: &'db dyn Db) -> Option<InheritanceCycle> {
/// Return `true` if the class is cyclically defined.
///
Expand Down Expand Up @@ -694,6 +743,8 @@ impl<'db> Class<'db> {
result
}

tracing::trace!("Class::inheritance_cycle: {}", self.name(db));

let visited_classes = &mut IndexSet::new();
if !is_cyclically_defined_recursive(db, self, &mut IndexSet::new(), visited_classes) {
None
Expand Down
Loading