@@ -829,19 +829,19 @@ from typing import Literal, Union
829829class Custom:
830830 def __eq__(self, other: object) -> bool: return True
831831
832- class Default: pass
832+ def f1(x: Union[Custom, Literal[1], Literal[2]]):
833+ if x == 1:
834+ reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[1]"
835+ else:
836+ reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[2]"
833837
834- x1: Union[Custom, Literal[1], Literal[2]]
835- if x1 == 1:
836- reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1]"
837- else:
838- reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[2]"
838+ class Default: pass
839839
840- x2 : Union[Default, Literal[1], Literal[2]]
841- if x2 == 1:
842- reveal_type(x2 ) # N: Revealed type is "Literal[1]"
843- else:
844- reveal_type(x2 ) # N: Revealed type is "__main__.Default | Literal[2]"
840+ def f2(x : Union[Default, Literal[1], Literal[2]]):
841+ if x == 1:
842+ reveal_type(x ) # N: Revealed type is "Literal[1]"
843+ else:
844+ reveal_type(x ) # N: Revealed type is "__main__.Default | Literal[2]"
845845[builtins fixtures/primitives.pyi]
846846
847847[case testNarrowingEqualityCustomEqualityEnum]
@@ -875,25 +875,23 @@ from typing import Literal, Union
875875class Custom:
876876 def __eq__(self, other: object) -> bool: return True
877877
878- class Default: pass
879-
880- x: Literal[1, 2, None]
881- y: Custom
882- z: Default
878+ def f1(x: Literal[1, 2, None], y: Custom):
879+ if 1 == x == y:
880+ reveal_type(x) # N: Revealed type is "Literal[1]"
881+ reveal_type(y) # N: Revealed type is "__main__.Custom"
882+ else:
883+ reveal_type(x) # N: Revealed type is "Literal[2] | None"
884+ reveal_type(y) # N: Revealed type is "__main__.Custom"
883885
884- if 1 == x == y:
885- reveal_type(x) # N: Revealed type is "Literal[1]"
886- reveal_type(y) # N: Revealed type is "__main__.Custom"
887- else:
888- reveal_type(x) # N: Revealed type is "Literal[2] | None"
889- reveal_type(y) # N: Revealed type is "__main__.Custom"
886+ class Default: pass
890887
891- if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default")
892- reveal_type(x) # E: Statement is unreachable
893- reveal_type(z)
894- else:
895- reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None"
896- reveal_type(z) # N: Revealed type is "__main__.Default"
888+ def f2(x: Literal[1, 2, None], z: Default):
889+ if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default")
890+ reveal_type(x) # E: Statement is unreachable
891+ reveal_type(z)
892+ else:
893+ reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None"
894+ reveal_type(z) # N: Revealed type is "__main__.Default"
897895[builtins fixtures/primitives.pyi]
898896
899897[case testNarrowingCustomEqualityLiteralElseBranch]
@@ -1445,19 +1443,23 @@ if val not in (None,):
14451443 reveal_type(val) # N: Revealed type is "__main__.A"
14461444else:
14471445 reveal_type(val) # N: Revealed type is "None"
1446+ [builtins fixtures/primitives.pyi]
14481447
1449- class Hmm:
1448+ [case testNarrowingCustomEqualityOptionalEqualsNone]
1449+ # flags: --strict-equality --warn-unreachable
1450+ from __future__ import annotations
1451+ class Custom:
14501452 def __eq__(self, other) -> bool: ...
14511453
1452- hmm: Optional[Hmm]
1453- if hmm == None:
1454- reveal_type(hmm ) # N: Revealed type is "__main__.Hmm | None"
1455- else:
1456- reveal_type(hmm ) # N: Revealed type is "__main__.Hmm "
1457- if hmm != None:
1458- reveal_type(hmm ) # N: Revealed type is "__main__.Hmm "
1459- else:
1460- reveal_type(hmm ) # N: Revealed type is "__main__.Hmm | None"
1454+ def f(x: Custom | None):
1455+ if x == None:
1456+ reveal_type(x ) # N: Revealed type is "__main__.Custom | None"
1457+ else:
1458+ reveal_type(x ) # N: Revealed type is "__main__.Custom "
1459+ if x != None:
1460+ reveal_type(x ) # N: Revealed type is "__main__.Custom "
1461+ else:
1462+ reveal_type(x ) # N: Revealed type is "__main__.Custom | None"
14611463[builtins fixtures/primitives.pyi]
14621464
14631465[case testNarrowingWithTupleOfTypes]
@@ -2992,6 +2994,62 @@ def f2(x: Any) -> None:
29922994 reveal_type(x) # N: Revealed type is "Any"
29932995[builtins fixtures/tuple.pyi]
29942996
2997+ [case testNarrowTypeObject]
2998+ # flags: --strict-equality --warn-unreachable
2999+ from typing import Any
3000+
3001+ # https://github.com/python/mypy/issues/13704
3002+
3003+ def f1(cls: type):
3004+ if cls is str:
3005+ reveal_type(cls) # N: Revealed type is "def (o: builtins.object =) -> builtins.str"
3006+ reveal_type(cls(5)) # N: Revealed type is "builtins.str"
3007+
3008+ if issubclass(cls, int):
3009+ pass
3010+ elif cls is str:
3011+ reveal_type(cls) # N: Revealed type is "type[builtins.object]"
3012+ reveal_type(cls(5)) # E: Too many arguments for "object" \
3013+ # N: Revealed type is "builtins.object"
3014+
3015+ def f2(cls: type[object]):
3016+ if cls is str:
3017+ reveal_type(cls) # N: Revealed type is "type[builtins.object]"
3018+ reveal_type(cls(5)) # E: Too many arguments for "object" \
3019+ # N: Revealed type is "builtins.object"
3020+
3021+ if issubclass(cls, int):
3022+ pass
3023+ elif cls is str:
3024+ reveal_type(cls) # N: Revealed type is "type[builtins.object]"
3025+ reveal_type(cls(5)) # E: Too many arguments for "object" \
3026+ # N: Revealed type is "builtins.object"
3027+
3028+ def f3(cls: type[Any]):
3029+ if cls is str:
3030+ reveal_type(cls) # N: Revealed type is "type[Any]"
3031+ reveal_type(cls(5)) # N: Revealed type is "Any"
3032+
3033+ if issubclass(cls, int):
3034+ pass
3035+ elif cls is str:
3036+ reveal_type(cls) # N: Revealed type is "type[Any]"
3037+ reveal_type(cls(5)) # N: Revealed type is "Any"
3038+ [builtins fixtures/isinstance.pyi]
3039+
3040+ [case testNarrowTypeObjectUnion]
3041+ # flags: --strict-equality --warn-unreachable
3042+ from __future__ import annotations
3043+
3044+ def f4(cls: type[str | int]):
3045+ reveal_type(cls) # N: Revealed type is "type[builtins.str] | type[builtins.int]"
3046+
3047+ if cls is int:
3048+ reveal_type(cls) # N: Revealed type is "type[builtins.int]"
3049+ if cls == int:
3050+ reveal_type(cls) # N: Revealed type is "type[builtins.int]"
3051+ [builtins fixtures/primitives.pyi]
3052+
29953053[case testTypeEqualsCheck]
29963054# flags: --strict-equality --warn-unreachable
29973055from typing import Any
0 commit comments