Skip to content

Commit 1a82e76

Browse files
committed
binder experiment
1 parent 8e3c99a commit 1a82e76

4 files changed

Lines changed: 100 additions & 51 deletions

File tree

mypy/binder.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
ProperType,
2929
TupleType,
3030
Type,
31-
TypeOfAny,
3231
TypeType,
3332
TypeVarType,
3433
UnionType,
@@ -73,6 +72,7 @@ class Frame:
7372
def __init__(self, id: int, conditional_frame: bool = False) -> None:
7473
self.id = id
7574
self.types: dict[Key, CurrentType] = {}
75+
self.conditionally_narrowed_keys: set[Key] = set()
7676
self.unreachable = False
7777
self.conditional_frame = conditional_frame
7878
self.suppress_unreachable_warnings = False
@@ -301,8 +301,6 @@ def update_from_options(self, frames: list[Frame]) -> bool:
301301
"""Update the frame to reflect that each key will be updated
302302
as in one of the frames. Return whether any item changes.
303303
304-
If a key is declared as AnyType, only update it if all the
305-
options are the same.
306304
"""
307305
all_reachable = all(not f.unreachable for f in frames)
308306
if not all_reachable:
@@ -328,9 +326,9 @@ def update_from_options(self, frames: list[Frame]) -> bool:
328326
# know anything about key in at least one possible frame.
329327
continue
330328

331-
resulting_values = [x for x in resulting_values if x is not None]
329+
filtered_resulting_values = [x for x in resulting_values if x is not None]
332330

333-
if all_reachable and all(not x.from_assignment for x in resulting_values):
331+
if all_reachable and all(not x.from_assignment for x in filtered_resulting_values):
334332
# Do not synthesize a new type if we encountered a conditional block
335333
# (if, while or match-case) without assignments.
336334
# See check-isinstance.test::testNoneCheckDoesNotMakeTypeVarOptional
@@ -342,57 +340,43 @@ def update_from_options(self, frames: list[Frame]) -> bool:
342340
# a micro-optimization for --allow-redefinition-new.
343341
seen_types = set()
344342
resulting_types = []
345-
for rv in resulting_values:
346-
assert rv is not None
343+
for rv in filtered_resulting_values:
347344
if rv.type in seen_types:
348345
continue
349346
resulting_types.append(rv.type)
350347
seen_types.add(rv.type)
351348

352-
type = resulting_types[0]
353-
declaration_type = get_proper_type(self.declarations.get(key))
354-
if isinstance(declaration_type, AnyType):
355-
# At this point resulting values can't contain None, see continue above
356-
if not all(is_same_type(type, t) for t in resulting_types[1:]):
357-
type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type)
349+
declaration_type = self.declarations.get(key)
350+
if len(resulting_types) == 1:
351+
# This is to avoid calling get_proper_type() unless needed, as this may
352+
# interfere with our (hacky) TypeGuard support.
353+
type = resulting_types[0]
358354
else:
359-
possible_types = []
360-
for t in resulting_types:
361-
assert t is not None
362-
possible_types.append(t)
363-
if len(possible_types) == 1:
364-
# This is to avoid calling get_proper_type() unless needed, as this may
365-
# interfere with our (hacky) TypeGuard support.
366-
type = possible_types[0]
367-
else:
368-
type = make_simplified_union(possible_types)
369-
# Legacy guard for corner case when the original type is TypeVarType.
370-
if isinstance(declaration_type, TypeVarType) and not is_subtype(
371-
type, declaration_type
372-
):
373-
type = declaration_type
374-
# Try simplifying resulting type for unions involving variadic tuples.
375-
# Technically, everything is still valid without this step, but if we do
376-
# not do this, this may create long unions after exiting an if check like:
377-
# x: tuple[int, ...]
378-
# if len(x) < 10:
379-
# ...
380-
# We want the type of x to be tuple[int, ...] after this block (if it is
381-
# still equivalent to such type).
382-
if isinstance(type, UnionType):
383-
type = collapse_variadic_union(type)
384-
if (
385-
old_semantics
386-
and isinstance(type, ProperType)
387-
and isinstance(type, UnionType)
388-
):
389-
# Simplify away any extra Any's that were added to the declared
390-
# type when popping a frame.
391-
simplified = UnionType.make_union(
392-
[t for t in type.items if not isinstance(get_proper_type(t), AnyType)]
393-
)
394-
if simplified == self.declarations[key]:
395-
type = simplified
355+
type = make_simplified_union(resulting_types)
356+
# Legacy guard for corner case when the original type is TypeVarType.
357+
proper_declaration_type = get_proper_type(declaration_type)
358+
if isinstance(proper_declaration_type, TypeVarType) and not is_subtype(
359+
type, proper_declaration_type
360+
):
361+
type = proper_declaration_type
362+
# Try simplifying resulting type for unions involving variadic tuples.
363+
# Technically, everything is still valid without this step, but if we do
364+
# not do this, this may create long unions after exiting an if check like:
365+
# x: tuple[int, ...]
366+
# if len(x) < 10:
367+
# ...
368+
# We want the type of x to be tuple[int, ...] after this block (if it is
369+
# still equivalent to such type).
370+
if isinstance(type, UnionType):
371+
type = collapse_variadic_union(type)
372+
if old_semantics and isinstance(type, ProperType) and isinstance(type, UnionType):
373+
# Simplify away any extra Any's that were added to the declared
374+
# type when popping a frame.
375+
simplified = UnionType.make_union(
376+
[t for t in type.items if not isinstance(get_proper_type(t), AnyType)]
377+
)
378+
if simplified == self.declarations[key]:
379+
type = simplified
396380
if (
397381
current_value is None
398382
or not is_same_type(type, current_value.type)
@@ -482,6 +466,8 @@ def assign_type(self, expr: Expression, type: Type, declared_type: Type | None)
482466
return
483467
if not literal(expr):
484468
return
469+
key = literal_hash(expr)
470+
assert key is not None
485471
self.invalidate_dependencies(expr)
486472

487473
if declared_type is None:
@@ -530,6 +516,8 @@ def assign_type(self, expr: Expression, type: Type, declared_type: Type | None)
530516
# has an explicit `Any` type annotation.
531517
if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.is_inferred:
532518
self.put(expr, type)
519+
elif any(key in f.conditionally_narrowed_keys for f in self.frames):
520+
self.put(expr, type)
533521
else:
534522
self.put(expr, declared_type)
535523
else:
@@ -553,6 +541,16 @@ def invalidate_dependencies(self, expr: BindableExpression) -> None:
553541
for dep in self.dependencies.get(key, set()):
554542
self._cleanse_key(dep)
555543

544+
def record_conditional_type_map(self, type_map: dict[Expression, Type] | None) -> None:
545+
"""Record expressions that are mentioned by the active conditional."""
546+
if type_map is None:
547+
return
548+
for expr in type_map:
549+
if self.can_put_directly(expr):
550+
key = literal_hash(expr)
551+
assert key is not None
552+
self.frames[-1].conditionally_narrowed_keys.add(key)
553+
556554
def allow_jump(self, index: int) -> None:
557555
# self.frames and self.options_on_return have different lengths
558556
# so make sure the index is positive

mypy/checker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5172,6 +5172,8 @@ def visit_if_stmt(self, s: IfStmt) -> None:
51725172
self.msg.deleted_as_rvalue(t, s)
51735173

51745174
if_map, else_map = self.find_isinstance_check(e)
5175+
self.binder.record_conditional_type_map(if_map)
5176+
self.binder.record_conditional_type_map(else_map)
51755177

51765178
s.unreachable_else = is_unreachable_map(else_map)
51775179

test-data/unit/check-isinstance.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2102,7 +2102,7 @@ def narrow_any_to_str_then_reassign_to_int() -> None:
21022102
if isinstance(v, str):
21032103
reveal_type(v) # N: Revealed type is "builtins.str"
21042104
v = 2
2105-
reveal_type(v) # N: Revealed type is "Any"
2105+
reveal_type(v) # N: Revealed type is "builtins.int"
21062106
[builtins fixtures/isinstance.pyi]
21072107

21082108
[case testIsinstanceNarrowAnyImplicit]

test-data/unit/check-narrowing.test

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3131,6 +3131,55 @@ def main(key: str):
31313131
reveal_type(existing_value_type(Box2("str"))) # N: Revealed type is "__main__.TupleLike[builtins.str] | Any"
31323132
[builtins fixtures/tuple.pyi]
31333133

3134+
[case testNarrowingDynamicAnyMemberAssignment]
3135+
from enum import Enum
3136+
from typing import Any
3137+
3138+
class Fruit(Enum):
3139+
APPLE = 1
3140+
ORANGE = 2
3141+
3142+
def f(a: Any, x: Any, y: Any) -> None:
3143+
if x is not list:
3144+
x = []
3145+
reveal_type(x) # N: Revealed type is "builtins.list[Any] | Overload(def [T] () -> builtins.list[T`1], def [T] (x: typing.Iterable[T`1]) -> builtins.list[T`1])"
3146+
3147+
if y is list:
3148+
pass
3149+
else:
3150+
y = []
3151+
reveal_type(y) # N: Revealed type is "Overload(def [T] () -> builtins.list[T`1], def [T] (x: typing.Iterable[T`1]) -> builtins.list[T`1]) | builtins.list[Any]"
3152+
3153+
if a.foo is not list:
3154+
a.foo = []
3155+
reveal_type(a.foo) # N: Revealed type is "builtins.list[Any] | Overload(def [T] () -> builtins.list[T`1], def [T] (x: typing.Iterable[T`1]) -> builtins.list[T`1])"
3156+
a.foo.append(x) # E: Missing positional argument "self" in call to "append" of "list"
3157+
3158+
if a.baz is list:
3159+
pass
3160+
else:
3161+
a.baz = []
3162+
reveal_type(a.baz) # N: Revealed type is "Overload(def [T] () -> builtins.list[T`1], def [T] (x: typing.Iterable[T`1]) -> builtins.list[T`1]) | builtins.list[Any]"
3163+
3164+
if a.bar is not Fruit.APPLE:
3165+
a.bar = []
3166+
reveal_type(a.bar) # N: Revealed type is "builtins.list[Any] | Literal[__main__.Fruit.APPLE]"
3167+
[builtins fixtures/list.pyi]
3168+
3169+
[case testNarrowingExplicitAnyMemberAssignment]
3170+
from typing import Any
3171+
3172+
class A:
3173+
foo: Any
3174+
3175+
def f(a: A) -> None:
3176+
if a.foo is not list:
3177+
a.foo = []
3178+
reveal_type(a.foo) # N: Revealed type is "builtins.list[Any] | Overload(def [T] () -> builtins.list[T`1], def [T] (x: typing.Iterable[T`1]) -> builtins.list[T`1])"
3179+
a.foo.append(1) # E: Missing positional argument "x" in call to "append" of "list" \
3180+
# E: Argument 1 to "append" of "list" has incompatible type "int"; expected "list[Never]"
3181+
[builtins fixtures/list.pyi]
3182+
31343183
[case testNarrowingCollections]
31353184
# flags: --strict-equality --warn-unreachable
31363185
from typing import cast

0 commit comments

Comments
 (0)