Skip to content

Commit 2c68057

Browse files
[ty] Preserve argument signature in @total_ordering (#22496)
## Summary Closes astral-sh/ty#2435.
1 parent 8e29be9 commit 2c68057

2 files changed

Lines changed: 296 additions & 24 deletions

File tree

crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,125 @@ reveal_type(s1 > s2) # revealed: bool
3838
reveal_type(s1 >= s2) # revealed: bool
3939
```
4040

41+
## Signature derived from source ordering method
42+
43+
When the source ordering method accepts a broader type (like `object`) for its `other` parameter,
44+
the synthesized comparison methods should use the same signature. This allows comparisons with types
45+
other than the class itself:
46+
47+
```py
48+
from functools import total_ordering
49+
50+
@total_ordering
51+
class Comparable:
52+
def __init__(self, value: int):
53+
self.value = value
54+
55+
def __eq__(self, other: object) -> bool:
56+
if isinstance(other, Comparable):
57+
return self.value == other.value
58+
if isinstance(other, int):
59+
return self.value == other
60+
return NotImplemented
61+
62+
def __lt__(self, other: object) -> bool:
63+
if isinstance(other, Comparable):
64+
return self.value < other.value
65+
if isinstance(other, int):
66+
return self.value < other
67+
return NotImplemented
68+
69+
a = Comparable(10)
70+
b = Comparable(20)
71+
72+
# Comparisons with the same type work.
73+
reveal_type(a <= b) # revealed: bool
74+
reveal_type(a >= b) # revealed: bool
75+
76+
# Comparisons with `int` also work because `__lt__` accepts `object`.
77+
reveal_type(a <= 15) # revealed: bool
78+
reveal_type(a >= 5) # revealed: bool
79+
```
80+
81+
## Multiple ordering methods with different signatures
82+
83+
When multiple ordering methods are defined with different signatures, the decorator selects a "root"
84+
method using the priority order: `__lt__` > `__le__` > `__gt__` > `__ge__`. Synthesized methods use
85+
the signature from the highest-priority method. Methods that are explicitly defined are not
86+
overridden.
87+
88+
```py
89+
from functools import total_ordering
90+
91+
@total_ordering
92+
class MultiSig:
93+
def __init__(self, value: int):
94+
self.value = value
95+
96+
def __eq__(self, other: object) -> bool:
97+
return True
98+
# __lt__ accepts `object` (highest priority, used as root)
99+
def __lt__(self, other: object) -> bool:
100+
return True
101+
# __gt__ only accepts `MultiSig` (not overridden by decorator)
102+
def __gt__(self, other: "MultiSig") -> bool:
103+
return True
104+
105+
a = MultiSig(10)
106+
b = MultiSig(20)
107+
108+
# __le__ and __ge__ are synthesized with __lt__'s signature (accepts `object`)
109+
reveal_type(a <= b) # revealed: bool
110+
reveal_type(a <= 15) # revealed: bool
111+
reveal_type(a >= b) # revealed: bool
112+
reveal_type(a >= 15) # revealed: bool
113+
114+
# __gt__ keeps its original signature (only accepts MultiSig)
115+
reveal_type(a > b) # revealed: bool
116+
a > 15 # error: [unsupported-operator]
117+
```
118+
119+
## Overloaded ordering method
120+
121+
When the source ordering method is overloaded, the synthesized comparison methods should preserve
122+
all overloads:
123+
124+
```py
125+
from functools import total_ordering
126+
from typing import overload
127+
128+
@total_ordering
129+
class Flexible:
130+
def __init__(self, value: int):
131+
self.value = value
132+
133+
def __eq__(self, other: object) -> bool:
134+
return True
135+
136+
@overload
137+
def __lt__(self, other: "Flexible") -> bool: ...
138+
@overload
139+
def __lt__(self, other: int) -> bool: ...
140+
def __lt__(self, other: "Flexible | int") -> bool:
141+
if isinstance(other, Flexible):
142+
return self.value < other.value
143+
return self.value < other
144+
145+
a = Flexible(10)
146+
b = Flexible(20)
147+
148+
# Synthesized __le__ preserves overloads from __lt__
149+
reveal_type(a <= b) # revealed: bool
150+
reveal_type(a <= 15) # revealed: bool
151+
152+
# Synthesized __ge__ also preserves overloads
153+
reveal_type(a >= b) # revealed: bool
154+
reveal_type(a >= 15) # revealed: bool
155+
156+
# But comparison with an unsupported type should still error
157+
a <= "string" # error: [unsupported-operator]
158+
```
159+
41160
## Using `__gt__` as the root comparison method
42161

43162
When a class defines `__eq__` and `__gt__`, the decorator synthesizes `__lt__`, `__le__`, and
@@ -127,6 +246,41 @@ reveal_type(c1 > c2) # revealed: bool
127246
reveal_type(c1 >= c2) # revealed: bool
128247
```
129248

249+
## Method precedence with inheritance
250+
251+
The decorator always prefers `__lt__` > `__le__` > `__gt__` > `__ge__`, regardless of whether the
252+
method is defined locally or inherited. In this example, the inherited `__lt__` takes precedence
253+
over the locally-defined `__gt__`:
254+
255+
```py
256+
from functools import total_ordering
257+
from typing import Literal
258+
259+
class Base:
260+
def __lt__(self, other: "Base") -> Literal[True]:
261+
return True
262+
263+
@total_ordering
264+
class Child(Base):
265+
# __gt__ is defined locally, but __lt__ (inherited) takes precedence
266+
def __gt__(self, other: "Child") -> Literal[False]:
267+
return False
268+
269+
c1 = Child()
270+
c2 = Child()
271+
272+
# __lt__ is inherited from Base
273+
reveal_type(c1 < c2) # revealed: Literal[True]
274+
275+
# __gt__ is defined locally on Child
276+
reveal_type(c1 > c2) # revealed: Literal[False]
277+
278+
# __le__ and __ge__ are synthesized from __lt__ (the highest-priority method),
279+
# even though __gt__ is defined locally on the class itself
280+
reveal_type(c1 <= c2) # revealed: bool
281+
reveal_type(c1 >= c2) # revealed: bool
282+
```
283+
130284
## Explicitly-defined methods are not overridden
131285

132286
When a class explicitly defines multiple comparison methods, the decorator does not override them.
@@ -245,6 +399,79 @@ n1 <= n2 # error: [unsupported-operator]
245399
n1 >= n2 # error: [unsupported-operator]
246400
```
247401

402+
## Non-bool return type
403+
404+
When the root ordering method returns a non-bool type (like `int`), the synthesized methods return a
405+
union of that type and `bool`. This is because `@total_ordering` generates methods like:
406+
407+
```python
408+
def __le__(self, other):
409+
return self < other or self == other
410+
```
411+
412+
If `__lt__` returns `int`, then the synthesized `__le__` could return either `int` (from
413+
`self < other`) or `bool` (from `self == other`). Since `bool` is a subtype of `int`, the union
414+
simplifies to `int`:
415+
416+
```py
417+
from functools import total_ordering
418+
419+
@total_ordering
420+
class IntReturn:
421+
def __init__(self, value: int):
422+
self.value = value
423+
424+
def __eq__(self, other: object) -> bool:
425+
if not isinstance(other, IntReturn):
426+
return NotImplemented
427+
return self.value == other.value
428+
429+
def __lt__(self, other: "IntReturn") -> int:
430+
return self.value - other.value
431+
432+
a = IntReturn(10)
433+
b = IntReturn(20)
434+
435+
# User-defined __lt__ returns int.
436+
reveal_type(a < b) # revealed: int
437+
438+
# Synthesized methods return int (the union int | bool simplifies to int
439+
# because bool is a subtype of int in Python).
440+
reveal_type(a <= b) # revealed: int
441+
reveal_type(a > b) # revealed: int
442+
reveal_type(a >= b) # revealed: int
443+
```
444+
445+
When the root method returns a type that is not a supertype of `bool`, the union is preserved:
446+
447+
```py
448+
from functools import total_ordering
449+
450+
@total_ordering
451+
class StrReturn:
452+
def __init__(self, value: str):
453+
self.value = value
454+
455+
def __eq__(self, other: object) -> bool:
456+
if not isinstance(other, StrReturn):
457+
return NotImplemented
458+
return self.value == other.value
459+
460+
def __lt__(self, other: "StrReturn") -> str:
461+
return self.value
462+
463+
a = StrReturn("a")
464+
b = StrReturn("b")
465+
466+
# User-defined __lt__ returns str.
467+
reveal_type(a < b) # revealed: str
468+
469+
# Synthesized methods return str | bool.
470+
reveal_type(a <= b) # revealed: str | bool
471+
reveal_type(a > b) # revealed: str | bool
472+
reveal_type(a >= b) # revealed: str | bool
473+
```
474+
248475
## Function call form
249476

250477
When `total_ordering` is called as a function (not as a decorator), the same validation is

crates/ty_python_semantic/src/types/class.rs

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,17 +1586,44 @@ impl<'db> ClassLiteral<'db> {
15861586
}
15871587

15881588
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
1589-
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation and
1590-
/// for synthesizing comparison methods.
1589+
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation.
15911590
pub(super) fn has_ordering_method_in_mro(
15921591
self,
15931592
db: &'db dyn Db,
15941593
specialization: Option<Specialization<'db>>,
15951594
) -> bool {
1596-
self.iter_mro(db, specialization)
1597-
.filter_map(ClassBase::into_class)
1598-
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
1599-
.any(|class| class.class_literal(db).0.has_own_ordering_method(db))
1595+
self.total_ordering_root_method(db, specialization)
1596+
.is_some()
1597+
}
1598+
1599+
/// Returns the type of the ordering method used by `@total_ordering`, if any.
1600+
///
1601+
/// Following `functools.total_ordering` precedence, we prefer `__lt__` > `__le__` > `__gt__` >
1602+
/// `__ge__`, regardless of whether the method is defined locally or inherited.
1603+
pub(super) fn total_ordering_root_method(
1604+
self,
1605+
db: &'db dyn Db,
1606+
specialization: Option<Specialization<'db>>,
1607+
) -> Option<Type<'db>> {
1608+
const ORDERING_METHODS: [&str; 4] = ["__lt__", "__le__", "__gt__", "__ge__"];
1609+
1610+
for name in ORDERING_METHODS {
1611+
for base in self.iter_mro(db, specialization) {
1612+
let Some(base_class) = base.into_class() else {
1613+
continue;
1614+
};
1615+
let (base_literal, base_specialization) = base_class.class_literal(db);
1616+
if base_literal.is_known(db, KnownClass::Object) {
1617+
continue;
1618+
}
1619+
let member = class_member(db, base_literal.body_scope(db), name);
1620+
if let Some(ty) = member.ignore_possibly_undefined() {
1621+
return Some(ty.apply_optional_specialization(db, base_specialization));
1622+
}
1623+
}
1624+
}
1625+
1626+
None
16001627
}
16011628

16021629
pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
@@ -2448,26 +2475,44 @@ impl<'db> ClassLiteral<'db> {
24482475
// ordering method. The decorator requires at least one of __lt__,
24492476
// __le__, __gt__, or __ge__ to be defined (either in this class or
24502477
// inherited from a superclass, excluding `object`).
2451-
if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
2452-
if self.has_ordering_method_in_mro(db, specialization) {
2453-
let instance_ty =
2454-
Type::instance(db, self.apply_optional_specialization(db, specialization));
2455-
2456-
let signature = Signature::new(
2457-
Parameters::new(
2458-
db,
2459-
[
2460-
Parameter::positional_or_keyword(Name::new_static("self"))
2461-
.with_annotated_type(instance_ty),
2462-
Parameter::positional_or_keyword(Name::new_static("other"))
2463-
.with_annotated_type(instance_ty),
2464-
],
2465-
),
2466-
KnownClass::Bool.to_instance(db),
2478+
//
2479+
// Only synthesize methods that are not already defined in the MRO.
2480+
if self.total_ordering(db)
2481+
&& matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__")
2482+
&& !self
2483+
.iter_mro(db, specialization)
2484+
.filter_map(ClassBase::into_class)
2485+
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
2486+
.any(|class| {
2487+
class_member(db, class.class_literal(db).0.body_scope(db), name)
2488+
.ignore_possibly_undefined()
2489+
.is_some()
2490+
})
2491+
&& self.has_ordering_method_in_mro(db, specialization)
2492+
&& let Some(root_method_ty) = self.total_ordering_root_method(db, specialization)
2493+
&& let Some(callables) = root_method_ty.try_upcast_to_callable(db)
2494+
{
2495+
let bool_ty = KnownClass::Bool.to_instance(db);
2496+
let synthesized_callables = callables.map(|callable| {
2497+
let signatures = CallableSignature::from_overloads(
2498+
callable.signatures(db).iter().map(|signature| {
2499+
// The generated methods return a union of the root method's return type
2500+
// and `bool`. This is because `@total_ordering` synthesizes methods like:
2501+
// def __gt__(self, other): return not (self == other or self < other)
2502+
// If `__lt__` returns `int`, then `__gt__` could return `int | bool`.
2503+
let return_ty =
2504+
UnionType::from_elements(db, [signature.return_ty, bool_ty]);
2505+
Signature::new_generic(
2506+
signature.generic_context,
2507+
signature.parameters().clone(),
2508+
return_ty,
2509+
)
2510+
}),
24672511
);
2512+
CallableType::new(db, signatures, CallableTypeKind::FunctionLike)
2513+
});
24682514

2469-
return Some(Type::function_like_callable(db, signature));
2470-
}
2515+
return Some(synthesized_callables.into_type(db));
24712516
}
24722517

24732518
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;

0 commit comments

Comments
 (0)