Skip to content

Commit 88d1eec

Browse files
authored
[ty] Ensure a type[] type T is always considered assignable to a union that holds T and a protocol P (#23901)
1 parent 37cdd61 commit 88d1eec

5 files changed

Lines changed: 120 additions & 58 deletions

File tree

crates/ty_python_semantic/resources/mdtest/type_of/basic.md

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -395,24 +395,29 @@ def _[T]():
395395
static_assert(not is_disjoint_from(type[InvSub[T]], type[Inv[Any]]))
396396
```
397397

398-
## `type[]` types in unions with `Callable` types
398+
## `type[]` types in unions with `Callable` types and callback protocols
399399

400400
`type[Foo]` is assignable to `Callable[[], Foo]` here:
401401

402402
```py
403-
from typing import Callable
403+
from typing import Callable, Protocol
404404
from ty_extensions import is_assignable_to, is_subtype_of, static_assert, TypeOf, Top
405405

406406
class Foo:
407407
def __init__(self): ...
408408

409+
class CustomCallback(Protocol):
410+
def __call__(self, /) -> Foo: ...
411+
409412
static_assert(is_assignable_to(type[Foo], Callable[[], Foo]))
413+
static_assert(is_assignable_to(type[Foo], CustomCallback))
410414
```
411415

412416
but it is not a subtype of `Callable[[], Foo]` or redundant with `Callable[[], Foo]`:
413417

414418
```py
415419
static_assert(not is_subtype_of(type[Foo], Callable[[], Foo]))
420+
static_assert(not is_subtype_of(type[Foo], CustomCallback))
416421
```
417422

418423
and the reason for that is that constructor signatures are not checked for Liskov violations,
@@ -428,15 +433,17 @@ so if `type[Foo]` were considered a subtype of `Callable[[], Foo]`, then this un
428433
incorrectly simplified to `Callable[[], Foo]`:
429434

430435
```py
431-
def test(x: type[Foo] | Callable[[], Foo]):
432-
# this remains unsimplified!
436+
def test(x: type[Foo] | Callable[[], Foo], y: type[Foo] | CustomCallback):
437+
# these remain unsimplified!
433438
reveal_type(x) # revealed: type[Foo] | (() -> Foo)
439+
reveal_type(y) # revealed: type[Foo] | CustomCallback
434440
```
435441

436-
which means that this assertion would fail:
442+
which means that these assertions would fail:
437443

438444
```py
439445
static_assert(is_subtype_of(type[Bar], type[Foo] | Callable[[], Foo]))
446+
static_assert(is_subtype_of(type[Bar], type[Foo] | CustomCallback))
440447
```
441448

442449
despite the fact that this would still pass!
@@ -466,17 +473,26 @@ constructor signatures:
466473

467474
```py
468475
static_assert(is_subtype_of(TypeOf[Foo], Callable[[], Foo]))
476+
static_assert(is_subtype_of(TypeOf[Foo], CustomCallback))
469477
static_assert(is_subtype_of(TypeOf[Bar], Callable[[int], Bar]))
470478
static_assert(not is_subtype_of(TypeOf[Bar], TypeOf[Foo]))
471479
static_assert(is_subtype_of(TypeOf[Foo], TypeOf[Foo] | Callable[[], Foo]))
480+
static_assert(is_subtype_of(TypeOf[Foo], TypeOf[Foo] | CustomCallback))
472481
static_assert(is_subtype_of(TypeOf[Bar], type[Foo] | Callable[[], Foo]))
482+
static_assert(is_subtype_of(TypeOf[Bar], type[Foo] | CustomCallback))
473483
static_assert(is_subtype_of(TypeOf[Bar], TypeOf[Bar] | Callable[[], Bar]))
474484
static_assert(is_subtype_of(TypeOf[Bar], type[Bar] | Callable[[], Bar]))
475485
static_assert(is_subtype_of(TypeOf[Bar], TypeOf[Bar] | Callable[[int], Bar]))
476486
static_assert(is_subtype_of(TypeOf[Bar], type[Bar] | Callable[[int], Bar]))
477487

478-
def f(x: TypeOf[Foo] | Callable[[], Foo], y: TypeOf[Bar] | Callable[[int], Bar], z: TypeOf[Bar] | Callable[[], Bar]):
479-
reveal_type(x) # revealed: () -> Foo
480-
reveal_type(y) # revealed: (int, /) -> Bar
481-
reveal_type(z) # revealed: <class 'Bar'> | (() -> Bar)
488+
def f(
489+
a: TypeOf[Foo] | Callable[[], Foo],
490+
b: TypeOf[Bar] | Callable[[int], Bar],
491+
c: TypeOf[Bar] | Callable[[], Bar],
492+
d: TypeOf[Foo] | CustomCallback,
493+
):
494+
reveal_type(a) # revealed: () -> Foo
495+
reveal_type(b) # revealed: (int, /) -> Bar
496+
reveal_type(c) # revealed: <class 'Bar'> | (() -> Bar)
497+
reveal_type(d) # revealed: CustomCallback
482498
```

crates/ty_python_semantic/src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use ruff_text_size::Ranged;
1919
use smallvec::smallvec_inline;
2020
use ty_module_resolver::{KnownModule, Module, ModuleName, resolve_module};
2121

22+
pub(crate) use self::callable::UpcastPolicy;
2223
pub use self::cyclic::CycleDetector;
2324
pub(crate) use self::cyclic::TypeTransformer;
2425
pub(crate) use self::diagnostic::register_lints;

crates/ty_python_semantic/src/types/callable.rs

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ impl<'db> Type<'db> {
4141
}
4242

4343
pub(crate) fn try_upcast_to_callable(self, db: &'db dyn Db) -> Option<CallableTypes<'db>> {
44+
self.try_upcast_to_callable_with_policy(db, UpcastPolicy::default())
45+
}
46+
47+
pub(crate) fn try_upcast_to_callable_with_policy(
48+
self,
49+
db: &'db dyn Db,
50+
policy: UpcastPolicy,
51+
) -> Option<CallableTypes<'db>> {
4452
match self {
4553
Type::Callable(callable) => Some(CallableTypes::one(callable)),
4654

@@ -68,7 +76,7 @@ impl<'db> Type<'db> {
6876
if let Place::Defined(place) = call_symbol
6977
&& place.is_definitely_defined()
7078
{
71-
place.ty.try_upcast_to_callable(db)
79+
place.ty.try_upcast_to_callable_with_policy(db, policy)
7280
} else {
7381
None
7482
}
@@ -79,16 +87,25 @@ impl<'db> Type<'db> {
7987

8088
Type::GenericAlias(alias) => Some(ClassType::Generic(alias).into_callable(db)),
8189

82-
Type::NewTypeInstance(newtype) => {
83-
newtype.concrete_base_type(db).try_upcast_to_callable(db)
90+
Type::NewTypeInstance(newtype) => newtype
91+
.concrete_base_type(db)
92+
.try_upcast_to_callable_with_policy(db, policy),
93+
94+
Type::SubclassOf(subclass_of_ty) if policy == UpcastPolicy::Sound => {
95+
Some(CallableTypes::one(CallableType::function_like(
96+
db,
97+
Signature::new(Parameters::top(), subclass_of_ty.to_instance(db)),
98+
)))
8499
}
85100

86101
// TODO: This is unsound so in future we can consider an opt-in option to disable it.
87102
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
88103
SubclassOfInner::Class(class) => Some(class.into_callable(db)),
89104
SubclassOfInner::TypeVar(tvar) => match tvar.typevar(db).bound_or_constraints(db) {
90105
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
91-
let upcast_callables = bound.to_meta_type(db).try_upcast_to_callable(db)?;
106+
let upcast_callables = bound
107+
.to_meta_type(db)
108+
.try_upcast_to_callable_with_policy(db, policy)?;
92109
Some(upcast_callables.map(|callable| {
93110
let signatures = callable
94111
.signatures(db)
@@ -104,8 +121,9 @@ impl<'db> Type<'db> {
104121
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
105122
let mut callables = SmallVec::new();
106123
for constraint in constraints.elements(db) {
107-
let element_upcast =
108-
constraint.to_meta_type(db).try_upcast_to_callable(db)?;
124+
let element_upcast = constraint
125+
.to_meta_type(db)
126+
.try_upcast_to_callable_with_policy(db, policy)?;
109127
for callable in element_upcast.into_inner() {
110128
let signatures = callable
111129
.signatures(db)
@@ -134,7 +152,8 @@ impl<'db> Type<'db> {
134152
Type::Union(union) => {
135153
let mut callables = SmallVec::new();
136154
for element in union.elements(db) {
137-
let element_callable = element.try_upcast_to_callable(db)?;
155+
let element_callable =
156+
element.try_upcast_to_callable_with_policy(db, policy)?;
138157
callables.extend(element_callable.into_inner());
139158
}
140159
Some(CallableTypes::new(callables))
@@ -143,11 +162,13 @@ impl<'db> Type<'db> {
143162
Type::LiteralValue(literal) => match literal.kind() {
144163
LiteralValueTypeKind::Enum(enum_literal) => enum_literal
145164
.enum_class_instance(db)
146-
.try_upcast_to_callable(db),
165+
.try_upcast_to_callable_with_policy(db, policy),
147166
_ => None,
148167
},
149168

150-
Type::TypeAlias(alias) => alias.value_type(db).try_upcast_to_callable(db),
169+
Type::TypeAlias(alias) => alias
170+
.value_type(db)
171+
.try_upcast_to_callable_with_policy(db, policy),
151172

152173
Type::KnownBoundMethod(method) => Some(CallableTypes::one(CallableType::new(
153174
db,
@@ -220,6 +241,50 @@ pub enum CallableTypeKind {
220241
ParamSpecValue,
221242
}
222243

244+
/// A "policy" enum that describes how `type[]` types should be upcast
245+
/// to `Callable` types.
246+
///
247+
/// `type[T]` is generally considered assignable to
248+
/// `Callable[<constructor signature of T>, T]` in Python, and most
249+
/// type-checking in Python uses assignability rather than subtyping
250+
/// when determining whether to emit errors on code, so -- despite its
251+
/// scary name -- [`UpcastPolicy::Unsound`] is actually the policy that
252+
/// you probably want in most situations. We *have* to use
253+
/// [`UpcastPolicy::Sound`], however, when doing subtyping or redundancy
254+
/// checks, because constructor signatures in subclasses are not checked
255+
/// for Liskov substitutability: `type[S]` may not be a subtype of
256+
/// `Callable[<constructor signature of T>, T]` even if `S` is a subtype
257+
/// of `T`. If this unsoundness leaked into our union simplification or
258+
/// subtyping checks, it would ead to nontransitivity of subtyping,
259+
/// breaking fundamental assumptions in our model.
260+
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
261+
pub(crate) enum UpcastPolicy {
262+
/// Only upcast types to callables in a sound fashion.
263+
///
264+
/// This means that `type[T]` is upcast to `Top[Callable[..., T]]`
265+
/// rather than `Callable[<constructor signature of T>, T]`,
266+
/// since the former is sound while the latter is not.
267+
Sound,
268+
269+
/// Allow unsound upcasts to callables, such as treating `type[T]` as
270+
/// `Callable[<constructor signature of T>, T`.
271+
#[default]
272+
Unsound,
273+
}
274+
275+
impl From<TypeRelation> for UpcastPolicy {
276+
fn from(relation: TypeRelation) -> Self {
277+
match relation {
278+
TypeRelation::Subtyping
279+
| TypeRelation::Redundancy { .. }
280+
| TypeRelation::SubtypingAssuming => UpcastPolicy::Sound,
281+
TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => {
282+
UpcastPolicy::Unsound
283+
}
284+
}
285+
}
286+
}
287+
223288
/// This type represents the set of all callable objects with a certain, possibly overloaded,
224289
/// signature.
225290
///

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ use itertools::Itertools;
66
use ruff_python_ast::name::Name;
77
use rustc_hash::FxHashMap;
88

9-
use crate::types::TypeContext;
109
use crate::types::callable::CallableTypeKind;
1110
use crate::types::relation::{HasRelationToVisitor, IsDisjointVisitor, TypeRelation};
11+
use crate::types::{TypeContext, UpcastPolicy};
1212
use crate::{
1313
Db, FxOrderSet,
1414
place::{
@@ -780,10 +780,9 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
780780
// With the new solver, we should be to replace all of this with an additional
781781
// constraint that enforces what `Self` can specialize to.
782782
let fallback_other = other.literal_fallback_instance(db).unwrap_or(other);
783-
attribute_type.try_upcast_to_callable(db).when_some_and(
784-
db,
785-
constraints,
786-
|callables| {
783+
attribute_type
784+
.try_upcast_to_callable_with_policy(db, UpcastPolicy::from(relation))
785+
.when_some_and(db, constraints, |callables| {
787786
callables
788787
.map(|callable| callable.apply_self(db, fallback_other))
789788
.has_relation_to_impl(
@@ -795,8 +794,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
795794
relation_visitor,
796795
disjointness_visitor,
797796
)
798-
},
799-
)
797+
})
800798
}
801799
// TODO: consider the types of the attribute on `other` for property members
802800
ProtocolMemberKind::Property(_) => ConstraintSet::from_bool(

crates/ty_python_semantic/src/types/relation.rs

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use crate::types::cyclic::PairVisitor;
1010
use crate::types::enums::is_single_member_enum;
1111
use crate::types::set_theoretic::RecursivelyDefined;
1212
use crate::types::{
13-
CallableType, CallableTypes, ClassBase, ClassType, CycleDetector, DynamicType, KnownClass,
14-
KnownInstanceType, LiteralValueTypeKind, MemberLookupPolicy, Parameters, ProtocolInstanceType,
15-
Signature, SubclassOfInner, TypeVarBoundOrConstraints, UnionType,
13+
CallableType, ClassBase, ClassType, CycleDetector, DynamicType, KnownClass, KnownInstanceType,
14+
LiteralValueTypeKind, MemberLookupPolicy, ProtocolInstanceType, SubclassOfInner,
15+
TypeVarBoundOrConstraints, UnionType, UpcastPolicy,
1616
};
1717
use crate::{
1818
Db,
@@ -1185,37 +1185,19 @@ impl<'db> Type<'db> {
11851185
}),
11861186

11871187
(_, Type::Callable(other_callable)) => {
1188-
// Special-case: upcasting a subclass-of to its `Callable` "supertype" is unsound,
1189-
// because we don't do Liskov checks for constructor signatures.
1190-
let upcasted = if let Type::SubclassOf(inner) = self {
1191-
match relation {
1192-
TypeRelation::Subtyping
1193-
| TypeRelation::SubtypingAssuming
1194-
| TypeRelation::Redundancy { .. } => {
1195-
Some(CallableTypes::one(CallableType::function_like(
1196-
db,
1197-
Signature::new(Parameters::top(), inner.to_instance(db)),
1198-
)))
1199-
}
1200-
TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => {
1201-
self.try_upcast_to_callable(db)
1202-
}
1203-
}
1204-
} else {
1205-
self.try_upcast_to_callable(db)
1206-
};
12071188
relation_visitor.visit((self, target, relation), || {
1208-
upcasted.when_some_and(db, constraints, |callables| {
1209-
callables.has_relation_to_impl(
1210-
db,
1211-
other_callable,
1212-
constraints,
1213-
inferable,
1214-
relation,
1215-
relation_visitor,
1216-
disjointness_visitor,
1217-
)
1218-
})
1189+
self.try_upcast_to_callable_with_policy(db, UpcastPolicy::from(relation))
1190+
.when_some_and(db, constraints, |callables| {
1191+
callables.has_relation_to_impl(
1192+
db,
1193+
other_callable,
1194+
constraints,
1195+
inferable,
1196+
relation,
1197+
relation_visitor,
1198+
disjointness_visitor,
1199+
)
1200+
})
12191201
})
12201202
}
12211203

0 commit comments

Comments
 (0)