Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4850,6 +4850,17 @@ def intersect_instances(
curr_module = self.scope.stack[0]
assert isinstance(curr_module, MypyFile)

# First, retry narrowing while allowing promotions (they are disabled by default
# for isinstance() checks, etc). This way we will still type-check branches like
# x: complex = 1
# if isinstance(x, int):
# ...
left, right = instances
if is_proper_subtype(left, right, ignore_promotions=False):
return left
if is_proper_subtype(right, left, ignore_promotions=False):
return right

def _get_base_classes(instances_: tuple[Instance, Instance]) -> list[Instance]:
base_classes_ = []
for inst in instances_:
Expand Down
56 changes: 56 additions & 0 deletions test-data/unit/check-type-promotion.test
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,59 @@ def f(x: Union[SupportsFloat, T]) -> Union[SupportsFloat, T]: pass
f(0) # should not crash
[builtins fixtures/primitives.pyi]
[out]

[case testIntersectionUsingPromotion]
# flags: --warn-unreachable
x: complex = 1
reveal_type(x) # N: Revealed type is "builtins.complex"
if isinstance(x, int):
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "builtins.complex"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion2]
# flags: --warn-unreachable
x: complex = 1
reveal_type(x) # N: Revealed type is "builtins.complex"
if isinstance(x, (int, float)):
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float]"
else:
reveal_type(x) # N: Revealed type is "builtins.complex"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion3]
# flags: --warn-unreachable
x: object
if isinstance(x, int) and isinstance(x, complex):
reveal_type(x) # N: Revealed type is "builtins.int"
if isinstance(x, complex) and isinstance(x, int):
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion4]
# flags: --warn-unreachable
x: object
if isinstance(x, int):
if isinstance(x, complex):
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "builtins.int"
if isinstance(x, complex):
if isinstance(x, int):
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "builtins.complex"
[builtins fixtures/primitives.pyi]

[case testIntersectionUsingPromotion5]
# flags: --warn-unreachable
from typing import Union

x: Union[float, complex]
if isinstance(x, int):
# Most likely this was an error, but we still type-check this branch
reveal_type(x) # N: Revealed type is "<nothing>"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, I'd expected this to also give int. What happens if x is a union with an unrelated type, like complex | str?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be also <nothing>. It is not easy to change. I however discovered couple more bugs while playing with this, so I will try to figure out something.

else:
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.complex]"
[builtins fixtures/primitives.pyi]
8 changes: 6 additions & 2 deletions test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# builtins stub with non-generic primitive types
from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable, overload
from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable, Tuple, Union

T = TypeVar('T')
V = TypeVar('V')
Expand All @@ -20,7 +20,9 @@ class int:
def __rmul__(self, x: int) -> int: pass
class float:
def __float__(self) -> float: pass
class complex: pass
def __add__(self, x: float) -> float: ...
class complex:
def __add__(self, x: complex) -> complex: ...
class bool(int): pass
class str(Sequence[str]):
def __add__(self, s: str) -> str: pass
Expand Down Expand Up @@ -63,3 +65,5 @@ class range(Sequence[int]):
def __getitem__(self, i: int) -> int: pass
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass

def isinstance(x: object, t: Union[type, Tuple]) -> bool: pass