Skip to content

Commit 2b6d26c

Browse files
committed
cleaner
1 parent 9599484 commit 2b6d26c

1 file changed

Lines changed: 51 additions & 39 deletions

File tree

mypy/checker.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7060,7 +7060,6 @@ def narrow_type_by_identity_equality(
70607060
return if_map, else_map
70617061

70627062
def broaden_equality_target_type(self, current_type: Type, target_type: Type) -> Type:
7063-
"""Include closed-domain peers when narrowing object equality."""
70647063
current_type = get_proper_type(current_type)
70657064
if not (
70667065
isinstance(current_type, Instance) and current_type.type.fullname == "builtins.object"
@@ -9677,18 +9676,21 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
96779676
VALUE_EQUALITY_DOMAINS: Final = {**OPEN_VALUE_EQUALITY_DOMAINS, **CLOSED_VALUE_EQUALITY_DOMAINS}
96789677

96799678

9679+
class EqualityDomainInfo(NamedTuple):
9680+
type_names: set[str]
9681+
enum_type_names: set[str]
9682+
9683+
96809684
class EqualityValueInfo(NamedTuple):
9681-
enum_types: set[str]
9682-
value_domains: dict[str, set[str]]
9683-
has_non_enum: bool
9685+
domains: dict[str, EqualityDomainInfo]
96849686
is_top: bool
96859687

96869688

96879689
def closed_equality_domain_type_names(info: EqualityValueInfo) -> list[str]:
96889690
return [
96899691
fullname
96909692
for fullname, domain in CLOSED_VALUE_EQUALITY_DOMAINS.items()
9691-
if domain in info.value_domains
9693+
if domain in info.domains
96929694
]
96939695

96949696

@@ -9727,30 +9729,39 @@ def is_equality_ambiguous_for_narrowing(left: Type, right: Type) -> bool:
97279729
right_info = equality_value_info(right)
97289730

97299731
if left_info.is_top or right_info.is_top:
9732+
# Only open-domain enum values can make a top-like type ambiguous.
9733+
# Closed domains can be narrowed to their complete known set instead.
97309734
other_info = right_info if left_info.is_top else left_info
9731-
return bool(
9732-
other_info.enum_types
9733-
and other_info.value_domains.keys() & OPEN_VALUE_EQUALITY_DOMAIN_NAMES
9735+
return any(
9736+
domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info.enum_type_names
9737+
for domain, domain_info in other_info.domains.items()
97349738
)
97359739

9736-
shared_domains = left_info.value_domains.keys() & right_info.value_domains.keys()
9740+
shared_domains = left_info.domains.keys() & right_info.domains.keys()
97379741
if not shared_domains:
97389742
return False
97399743

9740-
left_is_only_enum = bool(left_info.enum_types) and not left_info.has_non_enum
9741-
right_is_only_enum = bool(right_info.enum_types) and not right_info.has_non_enum
9742-
if left_is_only_enum and right_is_only_enum and left_info.enum_types == right_info.enum_types:
9743-
return False
9744-
9745-
if any(
9746-
left_info.value_domains[domain] != right_info.value_domains[domain]
9747-
for domain in shared_domains
9748-
):
9749-
return True
9750-
if not (left_info.enum_types or right_info.enum_types):
9751-
return False
9744+
for domain in shared_domains:
9745+
left_domain = left_info.domains[domain]
9746+
right_domain = right_info.domains[domain]
9747+
# Equality between two values from the same enum can still narrow by literal member.
9748+
if (
9749+
left_domain.enum_type_names
9750+
and left_domain.enum_type_names == right_domain.enum_type_names
9751+
and left_domain.type_names == left_domain.enum_type_names
9752+
and right_domain.type_names == right_domain.enum_type_names
9753+
):
9754+
continue
9755+
# Different domain-member types may compare equal, but nominal narrowing would
9756+
# otherwise treat them as disjoint.
9757+
if left_domain.type_names != right_domain.type_names:
9758+
return True
9759+
# Same domain-member types are only ambiguous if an enum value may compare equal to
9760+
# its underlying value type.
9761+
if left_domain.enum_type_names or right_domain.enum_type_names:
9762+
return True
97529763

9753-
return True
9764+
return False
97549765

97559766

97569767
def equality_value_info(t: Type) -> EqualityValueInfo:
@@ -9767,34 +9778,35 @@ def equality_value_info(t: Type) -> EqualityValueInfo:
97679778
return equality_value_info(t.fallback)
97689779
if isinstance(t, Instance):
97699780
if t.type.fullname == "builtins.object":
9770-
return EqualityValueInfo(set(), {}, has_non_enum=False, is_top=True)
9781+
return EqualityValueInfo({}, is_top=True)
97719782

9772-
value_domains = {}
9783+
enum_type_names = {t.type.fullname} if t.type.is_enum else set()
9784+
domains = {}
97739785
for base in t.type.mro:
97749786
if domain := VALUE_EQUALITY_DOMAINS.get(base.fullname):
9775-
value_domains[domain] = {t.type.fullname}
9787+
domains[domain] = EqualityDomainInfo({t.type.fullname}, enum_type_names)
97769788

9777-
enum_types = {t.type.fullname} if t.type.is_enum else set()
9778-
return EqualityValueInfo(
9779-
enum_types, value_domains, has_non_enum=not enum_types, is_top=False
9780-
)
9789+
return EqualityValueInfo(domains, is_top=False)
97819790
if isinstance(t, AnyType):
9782-
return EqualityValueInfo(set(), {}, has_non_enum=False, is_top=True)
9783-
return EqualityValueInfo(set(), {}, has_non_enum=False, is_top=False)
9791+
return EqualityValueInfo({}, is_top=True)
9792+
return EqualityValueInfo({}, is_top=False)
97849793

97859794

97869795
def combine_equality_value_info(infos: Iterable[EqualityValueInfo]) -> EqualityValueInfo:
9787-
enum_types: set[str] = set()
9788-
value_domains: dict[str, set[str]] = defaultdict(set)
9789-
has_non_enum = False
9796+
domains: dict[str, EqualityDomainInfo] = {}
97909797
is_top = False
97919798
for info in infos:
9792-
enum_types.update(info.enum_types)
9793-
for domain, type_names in info.value_domains.items():
9794-
value_domains[domain].update(type_names)
9795-
has_non_enum = has_non_enum or info.has_non_enum
9799+
for domain, domain_info in info.domains.items():
9800+
existing_domain_info = domains.get(domain)
9801+
if existing_domain_info is None:
9802+
domains[domain] = EqualityDomainInfo(
9803+
set(domain_info.type_names), set(domain_info.enum_type_names)
9804+
)
9805+
else:
9806+
existing_domain_info.type_names.update(domain_info.type_names)
9807+
existing_domain_info.enum_type_names.update(domain_info.enum_type_names)
97969808
is_top = is_top or info.is_top
9797-
return EqualityValueInfo(enum_types, dict(value_domains), has_non_enum, is_top)
9809+
return EqualityValueInfo(domains, is_top)
97989810

97999811

98009812
def is_typeddict_type_context(lvalue_type: Type) -> bool:

0 commit comments

Comments
 (0)