Skip to content

Commit ee28183

Browse files
committed
Improve narrowing via more coercion
1 parent cc696ff commit ee28183

4 files changed

Lines changed: 32 additions & 35 deletions

File tree

mypy/checker.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6657,10 +6657,9 @@ def narrow_type_by_identity_equality(
66576657
continue # Handled later
66586658

66596659
expr_type = operand_types[i]
6660-
expanded_expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
6661-
66626660
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
66636661

6662+
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
66646663
for j in expr_indices:
66656664
if i == j:
66666665
continue
@@ -6686,21 +6685,15 @@ def narrow_type_by_identity_equality(
66866685
continue
66876686

66886687
target = TypeRange(target_type, is_upper_bound=False)
6689-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6688+
if_map, else_map = conditional_types_to_typemaps(
6689+
operands[i], *conditional_types(expr_type, [target])
6690+
)
66906691

6691-
if is_value_target:
6692-
if_map, else_map = conditional_types_to_typemaps(
6693-
operands[i], *conditional_types(expanded_expr_type, [target])
6694-
)
6695-
all_if_maps.append(if_map)
6696-
all_else_maps.append(else_map)
6697-
else:
6698-
if_map, else_map = conditional_types_to_typemaps(
6699-
operands[i], *conditional_types(expr_type, [target])
6700-
)
6692+
all_if_maps.append(if_map)
6693+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67016694
# For type_targets, we cannot narrow in the negative case, so ignore else_map
67026695
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6703-
all_if_maps.append(if_map)
6696+
all_else_maps.append(else_map)
67046697

67056698
# Handle narrowing for operands with custom __eq__ methods specially
67066699
# In most cases, we won't be able to do any narrowing
@@ -6720,9 +6713,7 @@ def narrow_type_by_identity_equality(
67206713
if should_coerce_literals:
67216714
target_type = coerce_to_literal(target_type)
67226715
target = TypeRange(target_type, is_upper_bound=False)
6723-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6724-
6725-
if is_value_target:
6716+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67266717
if_map, else_map = conditional_types_to_typemaps(
67276718
operands[i], *conditional_types(expr_type, [target])
67286719
)
@@ -6739,23 +6730,19 @@ def narrow_type_by_identity_equality(
67396730
if has_custom_eq_checks(expr_type):
67406731
or_if_maps.append({operands[i]: expr_type})
67416732

6733+
expr_type = coerce_to_literal(try_expanding_sum_type_to_union(expr_type, None))
67426734
for j in expr_indices:
67436735
if j in custom_eq_indices:
67446736
continue
67456737
target_type = operand_types[j]
67466738
if should_coerce_literals:
67476739
target_type = coerce_to_literal(target_type)
67486740
target = TypeRange(target_type, is_upper_bound=False)
6749-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6750-
6751-
if is_value_target:
6752-
expr_type = coerce_to_literal(expr_type)
6753-
expr_type = try_expanding_sum_type_to_union(expr_type, None)
67546741
if_map, else_map = conditional_types_to_typemaps(
67556742
operands[i], *conditional_types(expr_type, [target], default=expr_type)
67566743
)
67576744
or_if_maps.append(if_map)
6758-
if is_value_target:
6745+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67596746
or_else_maps.append(else_map)
67606747

67616748
all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps))

test-data/unit/check-optional.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,11 +493,11 @@ from typing import Optional
493493

494494
def main(x: Optional[str]):
495495
if x == 0:
496-
reveal_type(x) # N: Revealed type is "builtins.str | None"
496+
reveal_type(x) # E: Statement is unreachable
497497
else:
498498
reveal_type(x) # N: Revealed type is "builtins.str | None"
499499
if x is 0:
500-
reveal_type(x) # N: Revealed type is "builtins.str | None"
500+
reveal_type(x) # E: Statement is unreachable
501501
else:
502502
reveal_type(x) # N: Revealed type is "builtins.str | None"
503503
[builtins fixtures/ops.pyi]

test-data/unit/check-python310.test

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ match m:
5656
-- Value Pattern --
5757

5858
[case testMatchValuePatternNarrows]
59+
# flags: --warn-unreachable
5960
import b
6061
m: object
6162

@@ -66,6 +67,7 @@ match m:
6667
b: int
6768

6869
[case testMatchValuePatternAlreadyNarrower]
70+
# flags: --warn-unreachable
6971
import b
7072
m: bool
7173

@@ -76,27 +78,28 @@ match m:
7678
b: int
7779

7880
[case testMatchValuePatternIntersect]
81+
# flags: --warn-unreachable
7982
import b
8083

8184
class A: ...
8285
m: A
8386

8487
match m:
8588
case b.b:
86-
reveal_type(m) # N: Revealed type is "__main__.A"
89+
reveal_type(m) # E: Statement is unreachable
8790
[file b.py]
8891
class B: ...
8992
b: B
9093

9194
[case testMatchValuePatternUnreachable]
92-
# primitives are needed because otherwise mypy doesn't see that int and str are incompatible
95+
# flags: --warn-unreachable
9396
import b
9497

9598
m: int
9699

97100
match m:
98101
case b.b:
99-
reveal_type(m) # N: Revealed type is "builtins.int"
102+
reveal_type(m) # E: Statement is unreachable
100103
[file b.py]
101104
b: str
102105
[builtins fixtures/primitives.pyi]
@@ -2766,7 +2769,7 @@ def x() -> tuple[Literal["test"]]: ...
27662769

27672770
match x():
27682771
case (x,) if x == "test": # E: Incompatible types in capture pattern (pattern captures type "Literal['test']", variable has type "Callable[[], tuple[Literal['test']]]")
2769-
reveal_type(x) # N: Revealed type is "def () -> tuple[Literal['test']]"
2772+
reveal_type(x) # E: Statement is unreachable
27702773
case foo:
27712774
foo
27722775

test-data/unit/check-tuples.test

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,17 +1533,24 @@ reveal_type(x) # N: Revealed type is "tuple[builtins.int, builtins.int]"
15331533
[builtins fixtures/tuple.pyi]
15341534

15351535
[case testTupleOverlapDifferentTuples]
1536+
# flags: --warn-unreachable
15361537
from typing import Optional, Tuple
15371538
class A: pass
15381539
class B: pass
15391540

1540-
possibles: Tuple[int, Tuple[A]]
1541-
x: Optional[Tuple[B]]
1541+
def f1(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[B]]):
1542+
if x in possibles:
1543+
reveal_type(x) # E: Statement is unreachable
1544+
else:
1545+
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
1546+
1547+
class AA(A): pass
15421548

1543-
if x in possibles:
1544-
reveal_type(x) # N: Revealed type is "tuple[__main__.B]"
1545-
else:
1546-
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
1549+
def f2(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[AA]]):
1550+
if x in possibles:
1551+
reveal_type(x) # N: Revealed type is "tuple[__main__.AA]"
1552+
else:
1553+
reveal_type(x) # N: Revealed type is "tuple[__main__.AA] | None"
15471554

15481555
[builtins fixtures/tuple.pyi]
15491556

0 commit comments

Comments
 (0)