|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import itertools |
5 | 6 | from collections import defaultdict |
6 | 7 | from typing import Final, NamedTuple |
7 | 8 |
|
@@ -242,79 +243,92 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: |
242 | 243 | if star_position is not None: |
243 | 244 | required_patterns -= 1 |
244 | 245 |
|
245 | | - # |
246 | | - # get inner types of original type |
247 | | - # 1. Go through all possible types and filter to only those which are sequences that could match that number of items |
248 | | - # 2. If there is exactly one tuple left with an unpack, then use that type and the unpack index |
249 | | - # 3. Otherwise, take the product of the item types so that each index can have a unique type. For tuples with unpack |
250 | | - # fallback to merging all of their types for each index since we can't handle multiple unpacked items at once yet. |
| 246 | + # 1. Go through all possible types and filter to only those which are sequences that |
| 247 | + # could match that number of items |
| 248 | + # 2. If there is exactly one tuple left with an unpack, then use that type |
| 249 | + # and the unpack index |
| 250 | + # 3. Otherwise, take the product of the item types so that each index can have a |
| 251 | + # unique type. For tuples with unpack fallback to merging all of their types |
| 252 | + # for each index, since we can't handle multiple unpacked items at once yet. |
251 | 253 |
|
252 | 254 | # Whether we have encountered a type that we don't know how to handle in the union |
253 | 255 | unknown_type = False |
254 | 256 | # A list of types that could match any of the items in the sequence. |
255 | 257 | sequence_types: list[Type] = [] |
256 | 258 | # A list of tuple types that could match the sequence, per index |
257 | 259 | tuple_types: list[list[Type]] = [] |
258 | | - # A list of all the unpack tuple types that we encountered, each containing the tuple type, unpack index, and union index |
| 260 | + # A list of all the unpack tuple types that we encountered, each containing the |
| 261 | + # tuple type, unpack index, and union index |
259 | 262 | unpack_tuple_types: list[tuple[TupleType, int, int]] = [] |
260 | 263 | for i, t in enumerate( |
261 | 264 | current_type.items if isinstance(current_type, UnionType) else [current_type] |
262 | 265 | ): |
263 | 266 | t = get_proper_type(t) |
264 | 267 | if isinstance(t, TupleType): |
265 | | - t_items = list(t.items) |
266 | | - unpack_index = find_unpack_in_list(t_items) |
| 268 | + tuple_items = list(t.items) |
| 269 | + unpack_index = find_unpack_in_list(tuple_items) |
267 | 270 | if unpack_index is None: |
268 | | - size_diff = len(t_items) - required_patterns |
| 271 | + size_diff = len(tuple_items) - required_patterns |
269 | 272 | if size_diff < 0: |
270 | 273 | continue |
271 | | - elif size_diff > 0 and star_position is None: |
| 274 | + if size_diff > 0 and star_position is None: |
272 | 275 | continue |
273 | | - elif not size_diff and star_position is not None: |
274 | | - t_items.append(UninhabitedType()) |
275 | | - tuple_types.append(t_items) |
| 276 | + if not size_diff and star_position is not None: |
| 277 | + # Above we subtract from required_patterns if star_position is not None |
| 278 | + tuple_items.append(UninhabitedType()) |
| 279 | + tuple_types.append(tuple_items) |
276 | 280 | else: |
277 | 281 | normalized_inner_types = [] |
278 | | - for it in t_items: |
| 282 | + for it in tuple_items: |
279 | 283 | # Unfortunately, it is not possible to "split" the TypeVarTuple |
280 | 284 | # into individual items, so we just use its upper bound for the whole |
281 | 285 | # analysis instead. |
282 | 286 | if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): |
283 | 287 | it = UnpackType(it.type.upper_bound) |
284 | 288 | normalized_inner_types.append(it) |
285 | | - t_items = normalized_inner_types |
286 | | - t = t.copy_modified(items=normalized_inner_types) |
287 | | - if len(t_items) - 1 > required_patterns and star_position is None: |
| 289 | + if ( |
| 290 | + len(normalized_inner_types) - 1 > required_patterns |
| 291 | + and star_position is None |
| 292 | + ): |
288 | 293 | continue |
| 294 | + t = t.copy_modified(items=normalized_inner_types) |
289 | 295 | unpack_tuple_types.append((t, unpack_index, i)) |
290 | | - # add the combined tuple type to the sequence types in case we have multiple unpacks we want to combine them all |
| 296 | + # In case we have multiple unpacks we want to combine them all, so add |
| 297 | + # the combined tuple type to the sequence types. |
291 | 298 | sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o)) |
292 | 299 | elif isinstance(t, AnyType): |
293 | 300 | sequence_types.append(AnyType(TypeOfAny.from_another_any, t)) |
294 | 301 | elif self.chk.type_is_iterable(t) and isinstance(t, Instance): |
295 | 302 | sequence_types.append(self.chk.iterable_item_type(t, o)) |
296 | 303 | else: |
297 | 304 | unknown_type = True |
298 | | - # if we only got one unpack tuple type, we can use that |
| 305 | + |
| 306 | + inner_types: list[Type] |
| 307 | + |
| 308 | + # If we only got one unpack tuple type, we can use that |
299 | 309 | unpack_index = None |
300 | 310 | if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types: |
301 | 311 | update_tuple_type, unpack_index, union_index = unpack_tuple_types[0] |
302 | | - inner_types: list[Type] = update_tuple_type.items |
| 312 | + inner_types = update_tuple_type.items |
303 | 313 | if isinstance(current_type, UnionType): |
304 | 314 | union_items = list(current_type.items) |
305 | 315 | union_items[union_index] = update_tuple_type |
306 | | - current_type = current_type.copy_modified(items=union_items) |
| 316 | + current_type = UnionType.make_union(items=union_items) |
307 | 317 | else: |
308 | 318 | current_type = update_tuple_type |
309 | | - # if we only got tuples we can't match, then exit early |
| 319 | + # If we only got tuples we can't match, then exit early |
310 | 320 | elif not tuple_types and not sequence_types and not unknown_type: |
311 | 321 | return self.early_non_match() |
312 | 322 | elif tuple_types: |
| 323 | + inner_types = [ |
| 324 | + make_simplified_union([*sequence_types, *[t for t in group if t is not None]]) |
| 325 | + for group in itertools.zip_longest(*tuple_types) |
| 326 | + ] |
313 | 327 | inner_types = [make_simplified_union([*sequence_types, *x]) for x in zip(*tuple_types)] |
| 328 | + elif sequence_types: |
| 329 | + inner_types = [make_simplified_union(sequence_types)] * len(o.patterns) |
314 | 330 | else: |
315 | | - object_type = self.chk.named_type("builtins.object") |
316 | | - unioned = make_simplified_union(sequence_types) if sequence_types else object_type |
317 | | - inner_types = [unioned] * len(o.patterns) |
| 331 | + inner_types = [self.chk.named_type("builtins.object")] * len(o.patterns) |
318 | 332 |
|
319 | 333 | # |
320 | 334 | # match inner patterns |
|
0 commit comments