@@ -7060,7 +7060,6 @@ def narrow_type_by_identity_equality(
70607060 return if_map , else_map
70617061
70627062 def broaden_equality_target_type (self , current_type : Type , target_type : Type ) -> Type :
7063- """Include closed-domain peers when narrowing object equality."""
70647063 current_type = get_proper_type (current_type )
70657064 if not (
70667065 isinstance (current_type , Instance ) and current_type .type .fullname == "builtins.object"
@@ -9677,18 +9676,21 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
96779676VALUE_EQUALITY_DOMAINS : Final = {** OPEN_VALUE_EQUALITY_DOMAINS , ** CLOSED_VALUE_EQUALITY_DOMAINS }
96789677
96799678
9679+ class EqualityDomainInfo (NamedTuple ):
9680+ type_names : set [str ]
9681+ enum_type_names : set [str ]
9682+
9683+
96809684class EqualityValueInfo (NamedTuple ):
9681- enum_types : set [str ]
9682- value_domains : dict [str , set [str ]]
9683- has_non_enum : bool
9685+ domains : dict [str , EqualityDomainInfo ]
96849686 is_top : bool
96859687
96869688
96879689def closed_equality_domain_type_names (info : EqualityValueInfo ) -> list [str ]:
96889690 return [
96899691 fullname
96909692 for fullname , domain in CLOSED_VALUE_EQUALITY_DOMAINS .items ()
9691- if domain in info .value_domains
9693+ if domain in info .domains
96929694 ]
96939695
96949696
@@ -9727,30 +9729,39 @@ def is_equality_ambiguous_for_narrowing(left: Type, right: Type) -> bool:
97279729 right_info = equality_value_info (right )
97289730
97299731 if left_info .is_top or right_info .is_top :
9732+ # Only open-domain enum values can make a top-like type ambiguous.
9733+ # Closed domains can be narrowed to their complete known set instead.
97309734 other_info = right_info if left_info .is_top else left_info
9731- return bool (
9732- other_info . enum_types
9733- and other_info .value_domains . keys () & OPEN_VALUE_EQUALITY_DOMAIN_NAMES
9735+ return any (
9736+ domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info . enum_type_names
9737+ for domain , domain_info in other_info .domains . items ()
97349738 )
97359739
9736- shared_domains = left_info .value_domains .keys () & right_info .value_domains .keys ()
9740+ shared_domains = left_info .domains .keys () & right_info .domains .keys ()
97379741 if not shared_domains :
97389742 return False
97399743
9740- left_is_only_enum = bool (left_info .enum_types ) and not left_info .has_non_enum
9741- right_is_only_enum = bool (right_info .enum_types ) and not right_info .has_non_enum
9742- if left_is_only_enum and right_is_only_enum and left_info .enum_types == right_info .enum_types :
9743- return False
9744-
9745- if any (
9746- left_info .value_domains [domain ] != right_info .value_domains [domain ]
9747- for domain in shared_domains
9748- ):
9749- return True
9750- if not (left_info .enum_types or right_info .enum_types ):
9751- return False
9744+ for domain in shared_domains :
9745+ left_domain = left_info .domains [domain ]
9746+ right_domain = right_info .domains [domain ]
9747+ # Equality between two values from the same enum can still narrow by literal member.
9748+ if (
9749+ left_domain .enum_type_names
9750+ and left_domain .enum_type_names == right_domain .enum_type_names
9751+ and left_domain .type_names == left_domain .enum_type_names
9752+ and right_domain .type_names == right_domain .enum_type_names
9753+ ):
9754+ continue
9755+ # Different domain-member types may compare equal, but nominal narrowing would
9756+ # otherwise treat them as disjoint.
9757+ if left_domain .type_names != right_domain .type_names :
9758+ return True
9759+ # Same domain-member types are only ambiguous if an enum value may compare equal to
9760+ # its underlying value type.
9761+ if left_domain .enum_type_names or right_domain .enum_type_names :
9762+ return True
97529763
9753- return True
9764+ return False
97549765
97559766
97569767def equality_value_info (t : Type ) -> EqualityValueInfo :
@@ -9767,34 +9778,35 @@ def equality_value_info(t: Type) -> EqualityValueInfo:
97679778 return equality_value_info (t .fallback )
97689779 if isinstance (t , Instance ):
97699780 if t .type .fullname == "builtins.object" :
9770- return EqualityValueInfo (set (), {}, has_non_enum = False , is_top = True )
9781+ return EqualityValueInfo ({} , is_top = True )
97719782
9772- value_domains = {}
9783+ enum_type_names = {t .type .fullname } if t .type .is_enum else set ()
9784+ domains = {}
97739785 for base in t .type .mro :
97749786 if domain := VALUE_EQUALITY_DOMAINS .get (base .fullname ):
9775- value_domains [domain ] = {t .type .fullname }
9787+ domains [domain ] = EqualityDomainInfo ( {t .type .fullname }, enum_type_names )
97769788
9777- enum_types = {t .type .fullname } if t .type .is_enum else set ()
9778- return EqualityValueInfo (
9779- enum_types , value_domains , has_non_enum = not enum_types , is_top = False
9780- )
9789+ return EqualityValueInfo (domains , is_top = False )
97819790 if isinstance (t , AnyType ):
9782- return EqualityValueInfo (set (), {}, has_non_enum = False , is_top = True )
9783- return EqualityValueInfo (set (), {}, has_non_enum = False , is_top = False )
9791+ return EqualityValueInfo ({} , is_top = True )
9792+ return EqualityValueInfo ({} , is_top = False )
97849793
97859794
97869795def combine_equality_value_info (infos : Iterable [EqualityValueInfo ]) -> EqualityValueInfo :
9787- enum_types : set [str ] = set ()
9788- value_domains : dict [str , set [str ]] = defaultdict (set )
9789- has_non_enum = False
9796+ domains : dict [str , EqualityDomainInfo ] = {}
97909797 is_top = False
97919798 for info in infos :
9792- enum_types .update (info .enum_types )
9793- for domain , type_names in info .value_domains .items ():
9794- value_domains [domain ].update (type_names )
9795- has_non_enum = has_non_enum or info .has_non_enum
9799+ for domain , domain_info in info .domains .items ():
9800+ existing_domain_info = domains .get (domain )
9801+ if existing_domain_info is None :
9802+ domains [domain ] = EqualityDomainInfo (
9803+ set (domain_info .type_names ), set (domain_info .enum_type_names )
9804+ )
9805+ else :
9806+ existing_domain_info .type_names .update (domain_info .type_names )
9807+ existing_domain_info .enum_type_names .update (domain_info .enum_type_names )
97969808 is_top = is_top or info .is_top
9797- return EqualityValueInfo (enum_types , dict ( value_domains ), has_non_enum , is_top )
9809+ return EqualityValueInfo (domains , is_top )
97989810
97999811
98009812def is_typeddict_type_context (lvalue_type : Type ) -> bool :
0 commit comments