Skip to content

Commit bdbd932

Browse files
committed
Improve narrowing via more coercion
1 parent 73d3725 commit bdbd932

4 files changed

Lines changed: 32 additions & 35 deletions

File tree

mypy/checker.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6655,10 +6655,9 @@ def narrow_type_by_identity_equality(
66556655
continue # Handled later
66566656

66576657
expr_type = operand_types[i]
6658-
expanded_expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
6659-
66606658
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
66616659

6660+
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
66626661
for j in expr_indices:
66636662
if i == j:
66646663
continue
@@ -6684,21 +6683,15 @@ def narrow_type_by_identity_equality(
66846683
continue
66856684

66866685
target = TypeRange(target_type, is_upper_bound=False)
6687-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6686+
if_map, else_map = conditional_types_to_typemaps(
6687+
operands[i], *conditional_types(expr_type, [target])
6688+
)
66886689

6689-
if is_value_target:
6690-
if_map, else_map = conditional_types_to_typemaps(
6691-
operands[i], *conditional_types(expanded_expr_type, [target])
6692-
)
6693-
all_if_maps.append(if_map)
6694-
all_else_maps.append(else_map)
6695-
else:
6696-
if_map, else_map = conditional_types_to_typemaps(
6697-
operands[i], *conditional_types(expr_type, [target])
6698-
)
6690+
all_if_maps.append(if_map)
6691+
if is_target_for_value_narrowing(get_proper_type(target_type)):
66996692
# For type_targets, we cannot narrow in the negative case, so ignore else_map
67006693
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6701-
all_if_maps.append(if_map)
6694+
all_else_maps.append(else_map)
67026695

67036696
# Handle narrowing for operands with custom __eq__ methods specially
67046697
# In most cases, we won't be able to do any narrowing
@@ -6718,9 +6711,7 @@ def narrow_type_by_identity_equality(
67186711
if should_coerce_literals:
67196712
target_type = coerce_to_literal(target_type)
67206713
target = TypeRange(target_type, is_upper_bound=False)
6721-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6722-
6723-
if is_value_target:
6714+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67246715
if_map, else_map = conditional_types_to_typemaps(
67256716
operands[i], *conditional_types(expr_type, [target])
67266717
)
@@ -6737,23 +6728,19 @@ def narrow_type_by_identity_equality(
67376728
if has_custom_eq_checks(expr_type):
67386729
or_if_maps.append({operands[i]: expr_type})
67396730

6731+
expr_type = coerce_to_literal(try_expanding_sum_type_to_union(expr_type, None))
67406732
for j in expr_indices:
67416733
if j in custom_eq_indices:
67426734
continue
67436735
target_type = operand_types[j]
67446736
if should_coerce_literals:
67456737
target_type = coerce_to_literal(target_type)
67466738
target = TypeRange(target_type, is_upper_bound=False)
6747-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6748-
6749-
if is_value_target:
6750-
expr_type = coerce_to_literal(expr_type)
6751-
expr_type = try_expanding_sum_type_to_union(expr_type, None)
67526739
if_map, else_map = conditional_types_to_typemaps(
67536740
operands[i], *conditional_types(expr_type, [target], default=expr_type)
67546741
)
67556742
or_if_maps.append(if_map)
6756-
if is_value_target:
6743+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67576744
or_else_maps.append(else_map)
67586745

67596746
all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps))

test-data/unit/check-optional.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,11 +493,11 @@ from typing import Optional
493493

494494
def main(x: Optional[str]):
495495
if x == 0:
496-
reveal_type(x) # N: Revealed type is "builtins.str | None"
496+
reveal_type(x) # E: Statement is unreachable
497497
else:
498498
reveal_type(x) # N: Revealed type is "builtins.str | None"
499499
if x is 0:
500-
reveal_type(x) # N: Revealed type is "builtins.str | None"
500+
reveal_type(x) # E: Statement is unreachable
501501
else:
502502
reveal_type(x) # N: Revealed type is "builtins.str | None"
503503
[builtins fixtures/ops.pyi]

test-data/unit/check-python310.test

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ match m:
5656
-- Value Pattern --
5757

5858
[case testMatchValuePatternNarrows]
59+
# flags: --warn-unreachable
5960
import b
6061
m: object
6162

@@ -66,6 +67,7 @@ match m:
6667
b: int
6768

6869
[case testMatchValuePatternAlreadyNarrower]
70+
# flags: --warn-unreachable
6971
import b
7072
m: bool
7173

@@ -76,27 +78,28 @@ match m:
7678
b: int
7779

7880
[case testMatchValuePatternIntersect]
81+
# flags: --warn-unreachable
7982
import b
8083

8184
class A: ...
8285
m: A
8386

8487
match m:
8588
case b.b:
86-
reveal_type(m) # N: Revealed type is "__main__.A"
89+
reveal_type(m) # E: Statement is unreachable
8790
[file b.py]
8891
class B: ...
8992
b: B
9093

9194
[case testMatchValuePatternUnreachable]
92-
# primitives are needed because otherwise mypy doesn't see that int and str are incompatible
95+
# flags: --warn-unreachable
9396
import b
9497

9598
m: int
9699

97100
match m:
98101
case b.b:
99-
reveal_type(m) # N: Revealed type is "builtins.int"
102+
reveal_type(m) # E: Statement is unreachable
100103
[file b.py]
101104
b: str
102105
[builtins fixtures/primitives.pyi]
@@ -2766,7 +2769,7 @@ def x() -> tuple[Literal["test"]]: ...
27662769

27672770
match x():
27682771
case (x,) if x == "test": # E: Incompatible types in capture pattern (pattern captures type "Literal['test']", variable has type "Callable[[], tuple[Literal['test']]]")
2769-
reveal_type(x) # N: Revealed type is "def () -> tuple[Literal['test']]"
2772+
reveal_type(x) # E: Statement is unreachable
27702773
case foo:
27712774
foo
27722775

test-data/unit/check-tuples.test

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,17 +1533,24 @@ reveal_type(x) # N: Revealed type is "tuple[builtins.int, builtins.int]"
15331533
[builtins fixtures/tuple.pyi]
15341534

15351535
[case testTupleOverlapDifferentTuples]
1536+
# flags: --warn-unreachable
15361537
from typing import Optional, Tuple
15371538
class A: pass
15381539
class B: pass
15391540

1540-
possibles: Tuple[int, Tuple[A]]
1541-
x: Optional[Tuple[B]]
1541+
def f1(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[B]]):
1542+
if x in possibles:
1543+
reveal_type(x) # E: Statement is unreachable
1544+
else:
1545+
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
1546+
1547+
class AA(A): pass
15421548

1543-
if x in possibles:
1544-
reveal_type(x) # N: Revealed type is "tuple[__main__.B]"
1545-
else:
1546-
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
1549+
def f2(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[AA]]):
1550+
if x in possibles:
1551+
reveal_type(x) # N: Revealed type is "tuple[__main__.AA]"
1552+
else:
1553+
reveal_type(x) # N: Revealed type is "tuple[__main__.AA] | None"
15471554

15481555
[builtins fixtures/tuple.pyi]
15491556

0 commit comments

Comments
 (0)