Skip to content

Commit 2a810c8

Browse files
committed
[mypyc] Fixes to union simplification
Fix crash related to unions in loops. The crash was introduced in #14363. Flatten nested unions before simplifying unions.
1 parent 86dad8a commit 2a810c8

File tree

4 files changed

+116
-17
lines changed

4 files changed

+116
-17
lines changed

mypyc/ir/rtypes.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,30 @@ def __init__(self, items: list[RType]) -> None:
797797
self.items_set = frozenset(items)
798798
self._ctype = "PyObject *"
799799

800+
@staticmethod
def make_simplified_union(items: list[RType]) -> RType:
    """Build a normalized type that covers all of the given items.

    Nested unions are flattened first, and then duplicate items are
    dropped; the first occurrence of each item keeps its position.

    Overlapping items are *not* merged. For example, [object, str]
    is left as two separate items.

    If only one distinct item remains, that item is returned directly
    rather than being wrapped in an RUnion.
    """
    flattened = flatten_nested_unions(items)
    assert flattened

    # dict keys are unique and keep insertion order, so this is an
    # order-preserving de-duplication of the flattened items.
    unique = list(dict.fromkeys(flattened))
    if len(unique) > 1:
        return RUnion(unique)
    return unique[0]
823+
800824
def accept(self, visitor: RTypeVisitor[T]) -> T:
801825
return visitor.visit_runion(self)
802826

@@ -823,6 +847,19 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RUnion:
823847
return RUnion(types)
824848

825849

850+
def flatten_nested_unions(types: list[RType]) -> list[RType]:
    """Expand any RUnion entries in `types` into their member items.

    Expansion is recursive, so unions nested inside unions are fully
    flattened. The relative order of items is preserved.

    When no entry is an RUnion, the input list object itself is returned
    (no copy is made); otherwise a new list is returned.
    """
    if all(not isinstance(typ, RUnion) for typ in types):
        # Fast path: nothing to flatten, hand back the input as-is.
        return types

    result: list[RType] = []
    for typ in types:
        if not isinstance(typ, RUnion):
            result.append(typ)
            continue
        # Recurse so that unions nested at any depth are expanded.
        result.extend(flatten_nested_unions(typ.items))
    return result
861+
862+
826863
def optional_value_type(rtype: RType) -> RType | None:
827864
"""If rtype is the union of none_rprimitive and another type X, return X.
828865

mypyc/irbuild/builder.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
Type,
5454
TypeOfAny,
5555
UninhabitedType,
56+
UnionType,
5657
get_proper_type,
5758
)
5859
from mypy.util import split_target
@@ -85,6 +86,7 @@
8586
RInstance,
8687
RTuple,
8788
RType,
89+
RUnion,
8890
bitmap_rprimitive,
8991
c_int_rprimitive,
9092
c_pyssize_t_rprimitive,
@@ -864,8 +866,15 @@ def extract_int(self, e: Expression) -> int | None:
864866
return None
865867

866868
def get_sequence_type(self, expr: Expression) -> RType:
867-
target_type = get_proper_type(self.types[expr])
868-
assert isinstance(target_type, Instance)
869+
return self.get_sequence_type_from_type(self.types[expr])
870+
871+
def get_sequence_type_from_type(self, target_type: Type) -> RType:
872+
target_type = get_proper_type(target_type)
873+
if isinstance(target_type, UnionType):
874+
return RUnion.make_simplified_union(
875+
[self.get_sequence_type_from_type(item) for item in target_type.items]
876+
)
877+
assert isinstance(target_type, Instance), target_type
869878
if target_type.type.fullname == "builtins.str":
870879
return str_rprimitive
871880
else:

mypyc/irbuild/mapper.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,7 @@ def type_to_rtype(self, typ: Type | None) -> RType:
116116
elif isinstance(typ, NoneTyp):
117117
return none_rprimitive
118118
elif isinstance(typ, UnionType):
119-
# Remove redundant items using set + list to preserve item order
120-
seen = set()
121-
items = []
122-
for item in typ.items:
123-
rtype = self.type_to_rtype(item)
124-
if rtype not in seen:
125-
items.append(rtype)
126-
seen.add(rtype)
127-
if len(items) > 1:
128-
return RUnion(items)
129-
else:
130-
return items[0]
119+
return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items])
131120
elif isinstance(typ, AnyType):
132121
return object_rprimitive
133122
elif isinstance(typ, TypeType):

mypyc/test-data/irbuild-lists.test

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,14 +430,20 @@ L5:
430430
return 1
431431

432432
[case testSimplifyListUnion]
433-
from typing import List, Union
433+
from typing import List, Union, Optional
434434

435-
def f(a: Union[List[str], List[bytes], int]) -> int:
435+
def narrow(a: Union[List[str], List[bytes], int]) -> int:
436436
if isinstance(a, list):
437437
return len(a)
438438
return a
439+
def loop(a: Union[List[str], List[bytes]]) -> None:
440+
for x in a:
441+
pass
442+
def nested_union(a: Union[List[str], List[Optional[str]]]) -> None:
443+
for x in a:
444+
pass
439445
[out]
440-
def f(a):
446+
def narrow(a):
441447
a :: union[list, int]
442448
r0 :: object
443449
r1 :: int32
@@ -465,3 +471,61 @@ L1:
465471
L2:
466472
r8 = unbox(int, a)
467473
return r8
474+
def loop(a):
475+
a :: list
476+
r0 :: short_int
477+
r1 :: ptr
478+
r2 :: native_int
479+
r3 :: short_int
480+
r4 :: bit
481+
r5 :: object
482+
r6, x :: union[str, bytes]
483+
r7 :: short_int
484+
L0:
485+
r0 = 0
486+
L1:
487+
r1 = get_element_ptr a ob_size :: PyVarObject
488+
r2 = load_mem r1 :: native_int*
489+
keep_alive a
490+
r3 = r2 << 1
491+
r4 = r0 < r3 :: signed
492+
if r4 goto L2 else goto L4 :: bool
493+
L2:
494+
r5 = CPyList_GetItemUnsafe(a, r0)
495+
r6 = cast(union[str, bytes], r5)
496+
x = r6
497+
L3:
498+
r7 = r0 + 2
499+
r0 = r7
500+
goto L1
501+
L4:
502+
return 1
503+
def nested_union(a):
504+
a :: list
505+
r0 :: short_int
506+
r1 :: ptr
507+
r2 :: native_int
508+
r3 :: short_int
509+
r4 :: bit
510+
r5 :: object
511+
r6, x :: union[str, None]
512+
r7 :: short_int
513+
L0:
514+
r0 = 0
515+
L1:
516+
r1 = get_element_ptr a ob_size :: PyVarObject
517+
r2 = load_mem r1 :: native_int*
518+
keep_alive a
519+
r3 = r2 << 1
520+
r4 = r0 < r3 :: signed
521+
if r4 goto L2 else goto L4 :: bool
522+
L2:
523+
r5 = CPyList_GetItemUnsafe(a, r0)
524+
r6 = cast(union[str, None], r5)
525+
x = r6
526+
L3:
527+
r7 = r0 + 2
528+
r0 = r7
529+
goto L1
530+
L4:
531+
return 1

0 commit comments

Comments
 (0)