@@ -6633,7 +6633,7 @@ def narrow_type_by_identity_equality(
66336633 if operator in {"is" , "is not" }:
66346634 is_target_for_value_narrowing = is_singleton_identity_type
66356635 should_coerce_literals = True
6636- should_narrow_by_identity_equality = True
6636+ custom_eq_indices = set ()
66376637 enum_comparison_is_ambiguous = False
66386638
66396639 elif operator in {"==" , "!=" }:
@@ -6646,19 +6646,11 @@ def narrow_type_by_identity_equality(
66466646 should_coerce_literals = True
66476647 break
66486648
6649- expr_types = [operand_types [i ] for i in expr_indices ]
6650- should_narrow_by_identity_equality = not any (map (has_custom_eq_checks , expr_types ))
6649+ custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks (operand_types [i ])}
66516650 enum_comparison_is_ambiguous = True
66526651 else :
66536652 raise AssertionError
66546653
6655- if not should_narrow_by_identity_equality :
6656- # This is a bit of a legacy code path that might be a little unsound since it ignores
6657- # custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
6658- return self .refine_away_none_in_comparison (
6659- operands , operand_types , expr_indices , narrowable_indices
6660- )
6661-
66626654 value_targets = []
66636655 type_targets = []
66646656 for i in expr_indices :
@@ -6670,6 +6662,10 @@ def narrow_type_by_identity_equality(
66706662 # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
66716663 # See testMatchEnumSingleChoice
66726664 expr_type = coerce_to_literal (expr_type )
6665+ if i in custom_eq_indices :
6666+ # We can't use types with custom __eq__ as targets for narrowing
6667+ # E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None
6668+ continue
66736669 if is_target_for_value_narrowing (get_proper_type (expr_type )):
66746670 value_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
66756671 else :
@@ -6681,7 +6677,11 @@ def narrow_type_by_identity_equality(
66816677 for i in expr_indices :
66826678 if i not in narrowable_indices :
66836679 continue
6684- expr_type = coerce_to_literal (operand_types [i ])
6680+ if i in custom_eq_indices :
6681+ # Handled later
6682+ continue
6683+ expr_type = operand_types [i ]
6684+ expr_type = coerce_to_literal (expr_type )
66856685 expr_type = try_expanding_sum_type_to_union (expr_type , None )
66866686 expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
66876687 for j , target in value_targets :
@@ -6702,6 +6702,9 @@ def narrow_type_by_identity_equality(
67026702 for i in expr_indices :
67036703 if i not in narrowable_indices :
67046704 continue
6705+ if i in custom_eq_indices :
6706+ # Handled later
6707+ continue
67056708 expr_type = operand_types [i ]
67066709 for j , target in type_targets :
67076710 if i == j :
@@ -6710,9 +6713,63 @@ def narrow_type_by_identity_equality(
67106713 operands [i ], * conditional_types (expr_type , [target ])
67116714 )
67126715 if if_map :
6713- else_map = {} # this is the big difference compared to the above
6716+ # For type_targets, we cannot narrow in the negative case
6717+ # e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6718+ else_map = {}
67146719 partial_type_maps .append ((if_map , else_map ))
67156720
6721+ for i in custom_eq_indices :
6722+ if i not in narrowable_indices :
6723+ continue
6724+ union_expr_type = operand_types [i ]
6725+ if not isinstance (union_expr_type , UnionType ):
6726+ expr_type = union_expr_type
6727+ for j , target in value_targets :
6728+ _if_map , else_map = conditional_types_to_typemaps (
6729+ operands [i ], * conditional_types (expr_type , [target ])
6730+ )
6731+ if else_map :
6732+ partial_type_maps .append (({}, else_map ))
6733+ continue
6734+
6735+ or_if_maps : list [TypeMap ] = []
6736+ or_else_maps : list [TypeMap ] = []
6737+ for expr_type in union_expr_type .items :
6738+ if has_custom_eq_checks (expr_type ):
6739+ or_if_maps .append ({operands [i ]: expr_type })
6740+
6741+ for j in expr_indices :
6742+ if j in custom_eq_indices :
6743+ continue
6744+ target_type = operand_types [j ]
6745+ if should_coerce_literals :
6746+ target_type = coerce_to_literal (target_type )
6747+ target = TypeRange (target_type , is_upper_bound = False )
6748+ is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
6749+
6750+ if is_value_target :
6751+ expr_type = coerce_to_literal (expr_type )
6752+ expr_type = try_expanding_sum_type_to_union (expr_type , None )
6753+ if_map , else_map = conditional_types_to_typemaps (
6754+ operands [i ], * conditional_types (expr_type , [target ], default = expr_type )
6755+ )
6756+ or_if_maps .append (if_map )
6757+ if is_value_target :
6758+ or_else_maps .append (else_map )
6759+
6760+ final_if_map = {}
6761+ final_else_map = {}
6762+ if or_if_maps :
6763+ final_if_map = or_if_maps [0 ]
6764+ for if_map in or_if_maps [1 :]:
6765+ final_if_map = or_conditional_maps (final_if_map , if_map )
6766+ if or_else_maps :
6767+ final_else_map = or_else_maps [0 ]
6768+ for else_map in or_else_maps [1 :]:
6769+ final_else_map = or_conditional_maps (final_else_map , else_map )
6770+
6771+ partial_type_maps .append ((final_if_map , final_else_map ))
6772+
67166773 for i in expr_indices :
67176774 type_expr = operands [i ]
67186775 if (
@@ -6934,49 +6991,6 @@ def _propagate_walrus_assignments(
69346991 return parent_expr
69356992 return expr
69366993
6937- def refine_away_none_in_comparison (
6938- self ,
6939- operands : list [Expression ],
6940- operand_types : list [Type ],
6941- chain_indices : list [int ],
6942- narrowable_operand_indices : AbstractSet [int ],
6943- ) -> tuple [TypeMap , TypeMap ]:
6944- """Produces conditional type maps refining away None in an identity/equality chain.
6945-
6946- For more details about what the different arguments mean, see the
6947- docstring of 'narrow_type_by_identity_equality' up above.
6948- """
6949-
6950- non_optional_types = []
6951- for i in chain_indices :
6952- typ = operand_types [i ]
6953- if not is_overlapping_none (typ ):
6954- non_optional_types .append (typ )
6955-
6956- if_map , else_map = {}, {}
6957-
6958- if not non_optional_types or (len (non_optional_types ) != len (chain_indices )):
6959-
6960- # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
6961- # convenient but is strictly not type-safe):
6962- for i in narrowable_operand_indices :
6963- expr_type = operand_types [i ]
6964- if not is_overlapping_none (expr_type ):
6965- continue
6966- if any (is_overlapping_erased_types (expr_type , t ) for t in non_optional_types ):
6967- if_map [operands [i ]] = remove_optional (expr_type )
6968-
6969- # Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
6970- # so type-safe but less convenient, because e.g. `Optional[A] == None` still results
6971- # in `Optional[A]`):
6972- if any (isinstance (get_proper_type (ot ), NoneType ) for ot in operand_types ):
6973- for i in narrowable_operand_indices :
6974- expr_type = operand_types [i ]
6975- if is_overlapping_none (expr_type ):
6976- else_map [operands [i ]] = remove_optional (expr_type )
6977-
6978- return if_map , else_map
6979-
69806994 def is_len_of_tuple (self , expr : Expression ) -> bool :
69816995 """Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
69826996 if not isinstance (expr , CallExpr ):
0 commit comments