Skip to content

Commit 7237831

Browse files
authored
[mypyc] (Re-)Support iterating over a Union of dicts (#14713)
An optimization that makes iterating over dict.keys(), dict.values() and dict.items() faster caused mypyc to crash while compiling code that iterates over a Union of dictionaries. This commit fixes the optimization helpers so that they handle unions properly. irbuild.Builder.get_dict_base_type() now returns a list[Instance] containing the union items; in the common case, where the type is not a union, it returns a single-element list. get_dict_key_type() and get_dict_value_type() now build a simplified RUnion as needed. Fixes mypyc/mypyc#965 and probably #14694.
1 parent 0bbeab8 commit 7237831

File tree

3 files changed

+83
-12
lines changed

3 files changed

+83
-12
lines changed

mypyc/codegen/literals.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, FrozenSet, List, Tuple, Union, cast
3+
from typing import Any, FrozenSet, List, Tuple, Union, cast
44
from typing_extensions import Final
55

66
# Supported Python literal types. All tuple / frozenset items must have supported
@@ -151,8 +151,7 @@ def _encode_collection_values(
151151
<length of the second collection>
152152
...
153153
"""
154-
# FIXME: https://github.com/mypyc/mypyc/issues/965
155-
value_by_index = {index: value for value, index in cast(Dict[Any, int], values).items()}
154+
value_by_index = {index: value for value, index in values.items()}
156155
result = []
157156
count = len(values)
158157
result.append(str(count))

mypyc/irbuild/builder.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -879,23 +879,39 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
879879
else:
880880
return self.type_to_rtype(target_type.args[0])
881881

882-
def get_dict_base_type(self, expr: Expression) -> Instance:
882+
def get_dict_base_type(self, expr: Expression) -> list[Instance]:
883883
"""Find dict type of a dict-like expression.
884884
885885
This is useful for dict subclasses like SymbolTable.
886886
"""
887887
target_type = get_proper_type(self.types[expr])
888-
assert isinstance(target_type, Instance), target_type
889-
dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict")
890-
return map_instance_to_supertype(target_type, dict_base)
888+
if isinstance(target_type, UnionType):
889+
types = [get_proper_type(item) for item in target_type.items]
890+
else:
891+
types = [target_type]
892+
893+
dict_types = []
894+
for t in types:
895+
assert isinstance(t, Instance), t
896+
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
897+
dict_types.append(map_instance_to_supertype(t, dict_base))
898+
return dict_types
891899

892900
def get_dict_key_type(self, expr: Expression) -> RType:
893-
dict_base_type = self.get_dict_base_type(expr)
894-
return self.type_to_rtype(dict_base_type.args[0])
901+
dict_base_types = self.get_dict_base_type(expr)
902+
if len(dict_base_types) == 1:
903+
return self.type_to_rtype(dict_base_types[0].args[0])
904+
else:
905+
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
906+
return RUnion.make_simplified_union(rtypes)
895907

896908
def get_dict_value_type(self, expr: Expression) -> RType:
897-
dict_base_type = self.get_dict_base_type(expr)
898-
return self.type_to_rtype(dict_base_type.args[1])
909+
dict_base_types = self.get_dict_base_type(expr)
910+
if len(dict_base_types) == 1:
911+
return self.type_to_rtype(dict_base_types[0].args[1])
912+
else:
913+
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
914+
return RUnion.make_simplified_union(rtypes)
899915

900916
def get_dict_item_type(self, expr: Expression) -> RType:
901917
key_type = self.get_dict_key_type(expr)

mypyc/test-data/irbuild-dict.test

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,17 @@ L0:
218218
return r2
219219

220220
[case testDictIterationMethods]
221-
from typing import Dict
221+
from typing import Dict, Union
222222
def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None:
223223
for v in d1.values():
224224
if v in d2:
225225
return
226226
for k, v in d2.items():
227227
d2[k] += v
228+
def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None:
229+
new = {}
230+
for k, v in d.items():
231+
new[k] = int(v)
228232
[out]
229233
def print_dict_methods(d1, d2):
230234
d1, d2 :: dict
@@ -314,6 +318,58 @@ L11:
314318
r34 = CPy_NoErrOccured()
315319
L12:
316320
return 1
321+
def union_of_dicts(d):
322+
d, r0, new :: dict
323+
r1 :: short_int
324+
r2 :: native_int
325+
r3 :: short_int
326+
r4 :: object
327+
r5 :: tuple[bool, short_int, object, object]
328+
r6 :: short_int
329+
r7 :: bool
330+
r8, r9 :: object
331+
r10 :: str
332+
r11 :: union[int, str]
333+
k :: str
334+
v :: union[int, str]
335+
r12, r13 :: object
336+
r14 :: int
337+
r15 :: object
338+
r16 :: int32
339+
r17, r18, r19 :: bit
340+
L0:
341+
r0 = PyDict_New()
342+
new = r0
343+
r1 = 0
344+
r2 = PyDict_Size(d)
345+
r3 = r2 << 1
346+
r4 = CPyDict_GetItemsIter(d)
347+
L1:
348+
r5 = CPyDict_NextItem(r4, r1)
349+
r6 = r5[1]
350+
r1 = r6
351+
r7 = r5[0]
352+
if r7 goto L2 else goto L4 :: bool
353+
L2:
354+
r8 = r5[2]
355+
r9 = r5[3]
356+
r10 = cast(str, r8)
357+
r11 = cast(union[int, str], r9)
358+
k = r10
359+
v = r11
360+
r12 = load_address PyLong_Type
361+
r13 = PyObject_CallFunctionObjArgs(r12, v, 0)
362+
r14 = unbox(int, r13)
363+
r15 = box(int, r14)
364+
r16 = CPyDict_SetItem(new, k, r15)
365+
r17 = r16 >= 0 :: signed
366+
L3:
367+
r18 = CPyDict_CheckSize(d, r3)
368+
goto L1
369+
L4:
370+
r19 = CPy_NoErrOccured()
371+
L5:
372+
return 1
317373

318374
[case testDictLoadAddress]
319375
def f() -> None:

0 commit comments

Comments
 (0)