Skip to content

Commit 6e8c42f

Browse files
committed
Better narrowing with custom equality
Co-authored-by: A5rocks
1 parent 32a35cd commit 6e8c42f

2 files changed

Lines changed: 162 additions & 77 deletions

File tree

mypy/checker.py

Lines changed: 69 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6633,7 +6633,7 @@ def narrow_type_by_identity_equality(
66336633
if operator in {"is", "is not"}:
66346634
is_target_for_value_narrowing = is_singleton_identity_type
66356635
should_coerce_literals = True
6636-
should_narrow_by_identity_equality = True
6636+
custom_eq_indices = set()
66376637
enum_comparison_is_ambiguous = False
66386638

66396639
elif operator in {"==", "!="}:
@@ -6646,19 +6646,11 @@ def narrow_type_by_identity_equality(
66466646
should_coerce_literals = True
66476647
break
66486648

6649-
expr_types = [operand_types[i] for i in expr_indices]
6650-
should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types))
6649+
custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])}
66516650
enum_comparison_is_ambiguous = True
66526651
else:
66536652
raise AssertionError
66546653

6655-
if not should_narrow_by_identity_equality:
6656-
# This is a bit of a legacy code path that might be a little unsound since it ignores
6657-
# custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
6658-
return self.refine_away_none_in_comparison(
6659-
operands, operand_types, expr_indices, narrowable_indices
6660-
)
6661-
66626654
value_targets = []
66636655
type_targets = []
66646656
for i in expr_indices:
@@ -6670,6 +6662,10 @@ def narrow_type_by_identity_equality(
66706662
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
66716663
# See testMatchEnumSingleChoice
66726664
expr_type = coerce_to_literal(expr_type)
6665+
if i in custom_eq_indices:
6666+
# We can't use types with custom __eq__ as targets for narrowing
6667+
# E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None
6668+
continue
66736669
if is_target_for_value_narrowing(get_proper_type(expr_type)):
66746670
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
66756671
else:
@@ -6681,7 +6677,11 @@ def narrow_type_by_identity_equality(
66816677
for i in expr_indices:
66826678
if i not in narrowable_indices:
66836679
continue
6684-
expr_type = coerce_to_literal(operand_types[i])
6680+
if i in custom_eq_indices:
6681+
# Handled later
6682+
continue
6683+
expr_type = operand_types[i]
6684+
expr_type = coerce_to_literal(expr_type)
66856685
expr_type = try_expanding_sum_type_to_union(expr_type, None)
66866686
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
66876687
for j, target in value_targets:
@@ -6702,6 +6702,9 @@ def narrow_type_by_identity_equality(
67026702
for i in expr_indices:
67036703
if i not in narrowable_indices:
67046704
continue
6705+
if i in custom_eq_indices:
6706+
# Handled later
6707+
continue
67056708
expr_type = operand_types[i]
67066709
for j, target in type_targets:
67076710
if i == j:
@@ -6710,9 +6713,63 @@ def narrow_type_by_identity_equality(
67106713
operands[i], *conditional_types(expr_type, [target])
67116714
)
67126715
if if_map:
6713-
else_map = {} # this is the big difference compared to the above
6716+
# For type_targets, we cannot narrow in the negative case
6717+
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6718+
else_map = {}
67146719
partial_type_maps.append((if_map, else_map))
67156720

6721+
for i in custom_eq_indices:
6722+
if i not in narrowable_indices:
6723+
continue
6724+
union_expr_type = operand_types[i]
6725+
if not isinstance(union_expr_type, UnionType):
6726+
expr_type = union_expr_type
6727+
for j, target in value_targets:
6728+
_if_map, else_map = conditional_types_to_typemaps(
6729+
operands[i], *conditional_types(expr_type, [target])
6730+
)
6731+
if else_map:
6732+
partial_type_maps.append(({}, else_map))
6733+
continue
6734+
6735+
or_if_maps: list[TypeMap] = []
6736+
or_else_maps: list[TypeMap] = []
6737+
for expr_type in union_expr_type.items:
6738+
if has_custom_eq_checks(expr_type):
6739+
or_if_maps.append({operands[i]: expr_type})
6740+
6741+
for j in expr_indices:
6742+
if j in custom_eq_indices:
6743+
continue
6744+
target_type = operand_types[j]
6745+
if should_coerce_literals:
6746+
target_type = coerce_to_literal(target_type)
6747+
target = TypeRange(target_type, is_upper_bound=False)
6748+
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6749+
6750+
if is_value_target:
6751+
expr_type = coerce_to_literal(expr_type)
6752+
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6753+
if_map, else_map = conditional_types_to_typemaps(
6754+
operands[i], *conditional_types(expr_type, [target], default=expr_type)
6755+
)
6756+
or_if_maps.append(if_map)
6757+
if is_value_target:
6758+
or_else_maps.append(else_map)
6759+
6760+
final_if_map = {}
6761+
final_else_map = {}
6762+
if or_if_maps:
6763+
final_if_map = or_if_maps[0]
6764+
for if_map in or_if_maps[1:]:
6765+
final_if_map = or_conditional_maps(final_if_map, if_map)
6766+
if or_else_maps:
6767+
final_else_map = or_else_maps[0]
6768+
for else_map in or_else_maps[1:]:
6769+
final_else_map = or_conditional_maps(final_else_map, else_map)
6770+
6771+
partial_type_maps.append((final_if_map, final_else_map))
6772+
67166773
for i in expr_indices:
67176774
type_expr = operands[i]
67186775
if (
@@ -6934,49 +6991,6 @@ def _propagate_walrus_assignments(
69346991
return parent_expr
69356992
return expr
69366993

6937-
def refine_away_none_in_comparison(
6938-
self,
6939-
operands: list[Expression],
6940-
operand_types: list[Type],
6941-
chain_indices: list[int],
6942-
narrowable_operand_indices: AbstractSet[int],
6943-
) -> tuple[TypeMap, TypeMap]:
6944-
"""Produces conditional type maps refining away None in an identity/equality chain.
6945-
6946-
For more details about what the different arguments mean, see the
6947-
docstring of 'narrow_type_by_identity_equality' up above.
6948-
"""
6949-
6950-
non_optional_types = []
6951-
for i in chain_indices:
6952-
typ = operand_types[i]
6953-
if not is_overlapping_none(typ):
6954-
non_optional_types.append(typ)
6955-
6956-
if_map, else_map = {}, {}
6957-
6958-
if not non_optional_types or (len(non_optional_types) != len(chain_indices)):
6959-
6960-
# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
6961-
# convenient but is strictly not type-safe):
6962-
for i in narrowable_operand_indices:
6963-
expr_type = operand_types[i]
6964-
if not is_overlapping_none(expr_type):
6965-
continue
6966-
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
6967-
if_map[operands[i]] = remove_optional(expr_type)
6968-
6969-
# Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
6970-
# so type-safe but less convenient, because e.g. `Optional[A] == None` still results
6971-
# in `Optional[A]`):
6972-
if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types):
6973-
for i in narrowable_operand_indices:
6974-
expr_type = operand_types[i]
6975-
if is_overlapping_none(expr_type):
6976-
else_map[operands[i]] = remove_optional(expr_type)
6977-
6978-
return if_map, else_map
6979-
69806994
def is_len_of_tuple(self, expr: Expression) -> bool:
69816995
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
69826996
if not isinstance(expr, CallExpr):

test-data/unit/check-narrowing.test

Lines changed: 93 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -823,9 +823,8 @@ def bar(x: Union[SingletonFoo, Foo], y: SingletonFoo) -> None:
823823
reveal_type(x) # N: Revealed type is "Literal[__main__.SingletonFoo.A]"
824824
[builtins fixtures/primitives.pyi]
825825

826-
[case testNarrowingEqualityDisabledForCustomEquality]
826+
[case testNarrowingEqualityCustomEqualityDisabled]
827827
from typing import Literal, Union
828-
from enum import Enum
829828

830829
class Custom:
831830
def __eq__(self, other: object) -> bool: return True
@@ -834,15 +833,20 @@ class Default: pass
834833

835834
x1: Union[Custom, Literal[1], Literal[2]]
836835
if x1 == 1:
837-
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1] | Literal[2]"
836+
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1]"
838837
else:
839-
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1] | Literal[2]"
838+
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[2]"
840839

841840
x2: Union[Default, Literal[1], Literal[2]]
842841
if x2 == 1:
843842
reveal_type(x2) # N: Revealed type is "Literal[1]"
844843
else:
845844
reveal_type(x2) # N: Revealed type is "__main__.Default | Literal[2]"
845+
[builtins fixtures/primitives.pyi]
846+
847+
[case testNarrowingEqualityCustomEqualityEnum]
848+
from typing import Literal, Union
849+
from enum import Enum
846850

847851
class CustomEnum(Enum):
848852
A = 1
@@ -855,7 +859,7 @@ key: Literal[CustomEnum.A]
855859
if x3 == key:
856860
reveal_type(x3) # N: Revealed type is "__main__.CustomEnum"
857861
else:
858-
reveal_type(x3) # N: Revealed type is "__main__.CustomEnum"
862+
reveal_type(x3) # N: Revealed type is "Literal[__main__.CustomEnum.B]"
859863

860864
# For comparison, this narrows since we bypass __eq__
861865
if x3 is key:
@@ -864,7 +868,7 @@ else:
864868
reveal_type(x3) # N: Revealed type is "Literal[__main__.CustomEnum.B]"
865869
[builtins fixtures/primitives.pyi]
866870

867-
[case testNarrowingEqualityDisabledForCustomEqualityChain]
871+
[case testNarrowingEqualityCustomEqualityDisabledChainedComparison]
868872
# flags: --strict-equality --warn-unreachable
869873
from typing import Literal, Union
870874

@@ -877,21 +881,13 @@ x: Literal[1, 2, None]
877881
y: Custom
878882
z: Default
879883

880-
# We could maybe try doing something clever, but for simplicity we
881-
# treat the whole chain as contaminated and mostly disable narrowing.
882-
#
883-
# The only exception is that we do at least strip away the 'None'. We
884-
# (perhaps optimistically) assume no custom class would be pathological
885-
# enough to declare itself to be equal to None and so permit this narrowing,
886-
# since it's often convenient in practice.
887884
if 1 == x == y:
888-
reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2]"
885+
reveal_type(x) # N: Revealed type is "Literal[1]"
889886
reveal_type(y) # N: Revealed type is "__main__.Custom"
890887
else:
891-
reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None"
888+
reveal_type(x) # N: Revealed type is "Literal[2] | None"
892889
reveal_type(y) # N: Revealed type is "__main__.Custom"
893890

894-
# No contamination here
895891
if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default")
896892
reveal_type(x) # E: Statement is unreachable
897893
reveal_type(z)
@@ -900,6 +896,75 @@ else:
900896
reveal_type(z) # N: Revealed type is "__main__.Default"
901897
[builtins fixtures/primitives.pyi]
902898

899+
[case testNarrowingCustomEqualityLiteralElseBranch]
900+
# flags: --strict-equality --warn-unreachable
901+
from __future__ import annotations
902+
from typing import Literal
903+
904+
class Custom:
905+
def __eq__(self, other: object) -> bool:
906+
raise
907+
908+
def f(v: Custom | Literal["text"]) -> Custom | None:
909+
if v == "text":
910+
reveal_type(v) # N: Revealed type is "__main__.Custom | Literal['text']"
911+
return None
912+
else:
913+
reveal_type(v) # N: Revealed type is "__main__.Custom"
914+
return v
915+
916+
def g(v: Custom | Literal["text"]) -> Custom | None:
917+
if v != "text":
918+
reveal_type(v) # N: Revealed type is "__main__.Custom"
919+
return None
920+
else:
921+
reveal_type(v) # N: Revealed type is "__main__.Custom | Literal['text']"
922+
return v # E: Incompatible return value type (got "Custom | Literal['text']", expected "Custom | None")
923+
[builtins fixtures/primitives.pyi]
924+
925+
[case testNarrowingCustomEqualityUnion]
926+
# flags: --strict-equality --warn-unreachable
927+
from __future__ import annotations
928+
from typing import Any
929+
930+
def realistic(x: dict[str, Any]):
931+
val = x.get("hey")
932+
if val == 12:
933+
reveal_type(val) # N: Revealed type is "Any | Literal[12]?"
934+
935+
def f1(x: Any | None):
936+
if x == 12:
937+
reveal_type(x) # N: Revealed type is "Any | Literal[12]?"
938+
939+
class Custom:
940+
def __eq__(self, other: object) -> bool:
941+
raise
942+
943+
def f2(x: Custom | None):
944+
if x == 12:
945+
reveal_type(x) # N: Revealed type is "__main__.Custom"
946+
else:
947+
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
948+
[builtins fixtures/dict.pyi]
949+
950+
[case testNarrowingCustomEqualityUnionTypeTarget]
951+
# flags: --strict-equality --warn-unreachable
952+
from __future__ import annotations
953+
from typing import Any
954+
955+
class Custom:
956+
def __eq__(self, other: object) -> bool:
957+
raise
958+
959+
def f(x: Custom | None, y: int | None):
960+
if x == y:
961+
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
962+
reveal_type(y) # N: Revealed type is "builtins.int | None"
963+
else:
964+
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
965+
reveal_type(y) # N: Revealed type is "builtins.int | None"
966+
[builtins fixtures/primitives.pyi]
967+
903968
[case testNarrowingUnreachableCases]
904969
# flags: --strict-equality --warn-unreachable
905970
from typing import Literal, Union
@@ -2157,7 +2222,7 @@ def f3(x: object) -> None:
21572222

21582223
def f4(x: int | Any) -> None:
21592224
if x == IE.X:
2160-
reveal_type(x) # N: Revealed type is "builtins.int | Any"
2225+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any"
21612226
else:
21622227
reveal_type(x) # N: Revealed type is "builtins.int | Any"
21632228

@@ -2232,9 +2297,9 @@ def f5(x: E | str | int) -> None:
22322297

22332298
def f6(x: IE | Any) -> None:
22342299
if x == IE.X:
2235-
reveal_type(x) # N: Revealed type is "__main__.IE | Any"
2300+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any"
22362301
else:
2237-
reveal_type(x) # N: Revealed type is "__main__.IE | Any"
2302+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y] | Any"
22382303

22392304
def f7(x: IE | None) -> None:
22402305
if x == IE.X:
@@ -2316,7 +2381,7 @@ def f(x: str | int) -> None:
23162381
z = y
23172382
[builtins fixtures/primitives.pyi]
23182383

2319-
[case testConsistentNarrowingInWithCustomEq]
2384+
[case testConsistentNarrowingEqAndInWithCustomEq]
23202385
# flags: --python-version 3.10
23212386

23222387
# https://github.com/python/mypy/issues/17864
@@ -2334,11 +2399,17 @@ class C:
23342399
class D(C):
23352400
pass
23362401

2337-
def f(x: C) -> None:
2402+
def f1(x: C) -> None:
23382403
if x in [D(5)]:
23392404
reveal_type(x) # D # N: Revealed type is "__main__.C"
23402405

2341-
f(C(5))
2406+
f1(C(5))
2407+
2408+
def f2(x: C) -> None:
2409+
if x == D(5):
2410+
reveal_type(x) # D # N: Revealed type is "__main__.C"
2411+
2412+
f2(C(5))
23422413
[builtins fixtures/primitives.pyi]
23432414

23442415
[case testNarrowingTypeVarNone]

0 commit comments

Comments
 (0)