Skip to content

Commit f33c9a3

Browse files
Some final touches for variadic types support (#16334)
I decided to go again over various parts of variadic types implementation to double-check nothing is missing, checked interaction with various "advanced" features (dataclasses, protocols, self-types, match statement, etc.), added some more tests (including incremental), and `grep`ed for potentially unhandled cases (and did found few crashes). This mostly touches only variadic types but one thing goes beyond, the fix for self-types upper bound, I think it is correct and should be safe. If there are no objections, next PR will flip the switch. --------- Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
1 parent 6c7faf3 commit f33c9a3

19 files changed

+726
-72
lines changed

mypy/applytype.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Callable, Sequence
44

55
import mypy.subtypes
6+
from mypy.erasetype import erase_typevars
67
from mypy.expandtype import expand_type
78
from mypy.nodes import Context
89
from mypy.types import (
@@ -62,6 +63,11 @@ def get_target_type(
6263
report_incompatible_typevar_value(callable, type, tvar.name, context)
6364
else:
6465
upper_bound = tvar.upper_bound
66+
if tvar.name == "Self":
67+
# Internally constructed Self-types contain class type variables in upper bound,
68+
# so we need to erase them to avoid false positives. This is safe because we do
69+
# not support type variables in upper bounds of user defined types.
70+
upper_bound = erase_typevars(upper_bound)
6571
if not mypy.subtypes.is_subtype(type, upper_bound):
6672
if skip_unsatisfied:
6773
return None
@@ -121,6 +127,7 @@ def apply_generic_arguments(
121127
# Apply arguments to argument types.
122128
var_arg = callable.var_arg()
123129
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
130+
# Same as for ParamSpec, callable with variadic types needs to be expanded as a whole.
124131
callable = expand_type(callable, id_to_type)
125132
assert isinstance(callable, CallableType)
126133
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])

mypy/checker.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,7 +1852,6 @@ def expand_typevars(
18521852
if defn.info:
18531853
# Class type variables
18541854
tvars += defn.info.defn.type_vars or []
1855-
# TODO(PEP612): audit for paramspec
18561855
for tvar in tvars:
18571856
if isinstance(tvar, TypeVarType) and tvar.values:
18581857
subst.append([(tvar.id, value) for value in tvar.values])
@@ -2538,6 +2537,9 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
25382537
object_type = Instance(info.mro[-1], [])
25392538
tvars = info.defn.type_vars
25402539
for i, tvar in enumerate(tvars):
2540+
if not isinstance(tvar, TypeVarType):
2541+
# Variance of TypeVarTuple and ParamSpec is underspecified by PEPs.
2542+
continue
25412543
up_args: list[Type] = [
25422544
object_type if i == j else AnyType(TypeOfAny.special_form)
25432545
for j, _ in enumerate(tvars)
@@ -2554,7 +2556,7 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
25542556
expected = CONTRAVARIANT
25552557
else:
25562558
expected = INVARIANT
2557-
if isinstance(tvar, TypeVarType) and expected != tvar.variance:
2559+
if expected != tvar.variance:
25582560
self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)
25592561

25602562
def check_multiple_inheritance(self, typ: TypeInfo) -> None:
@@ -6695,19 +6697,6 @@ def check_possible_missing_await(
66956697
return
66966698
self.msg.possible_missing_await(context, code)
66976699

6698-
def contains_none(self, t: Type) -> bool:
6699-
t = get_proper_type(t)
6700-
return (
6701-
isinstance(t, NoneType)
6702-
or (isinstance(t, UnionType) and any(self.contains_none(ut) for ut in t.items))
6703-
or (isinstance(t, TupleType) and any(self.contains_none(tt) for tt in t.items))
6704-
or (
6705-
isinstance(t, Instance)
6706-
and bool(t.args)
6707-
and any(self.contains_none(it) for it in t.args)
6708-
)
6709-
)
6710-
67116700
def named_type(self, name: str) -> Instance:
67126701
"""Return an instance type with given name and implicit Any type args.
67136702
@@ -7471,10 +7460,22 @@ def builtin_item_type(tp: Type) -> Type | None:
74717460
return None
74727461
if not isinstance(get_proper_type(tp.args[0]), AnyType):
74737462
return tp.args[0]
7474-
elif isinstance(tp, TupleType) and all(
7475-
not isinstance(it, AnyType) for it in get_proper_types(tp.items)
7476-
):
7477-
return make_simplified_union(tp.items) # this type is not externally visible
7463+
elif isinstance(tp, TupleType):
7464+
normalized_items = []
7465+
for it in tp.items:
7466+
# This use case is probably rare, but not handling unpacks here can cause crashes.
7467+
if isinstance(it, UnpackType):
7468+
unpacked = get_proper_type(it.type)
7469+
if isinstance(unpacked, TypeVarTupleType):
7470+
unpacked = get_proper_type(unpacked.upper_bound)
7471+
assert (
7472+
isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
7473+
)
7474+
normalized_items.append(unpacked.args[0])
7475+
else:
7476+
normalized_items.append(it)
7477+
if all(not isinstance(it, AnyType) for it in get_proper_types(normalized_items)):
7478+
return make_simplified_union(normalized_items) # this type is not externally visible
74787479
elif isinstance(tp, TypedDictType):
74797480
# TypedDict always has non-optional string keys. Find the key type from the Mapping
74807481
# base class.

mypy/checkexpr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
410410
result = self.alias_type_in_runtime_context(
411411
node, ctx=e, alias_definition=e.is_alias_rvalue or lvalue
412412
)
413-
elif isinstance(node, (TypeVarExpr, ParamSpecExpr)):
413+
elif isinstance(node, (TypeVarExpr, ParamSpecExpr, TypeVarTupleExpr)):
414414
result = self.object_type()
415415
else:
416416
if isinstance(node, PlaceholderNode):
@@ -3316,6 +3316,7 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
33163316

33173317
def concat_tuples(self, left: TupleType, right: TupleType) -> TupleType:
33183318
"""Concatenate two fixed length tuples."""
3319+
assert not (find_unpack_in_list(left.items) and find_unpack_in_list(right.items))
33193320
return TupleType(
33203321
items=left.items + right.items, fallback=self.named_type("builtins.tuple")
33213322
)
@@ -6507,8 +6508,8 @@ def merge_typevars_in_callables_by_name(
65076508
for tv in target.variables:
65086509
name = tv.fullname
65096510
if name not in unique_typevars:
6510-
# TODO(PEP612): fix for ParamSpecType
6511-
if isinstance(tv, ParamSpecType):
6511+
# TODO: support ParamSpecType and TypeVarTuple.
6512+
if isinstance(tv, (ParamSpecType, TypeVarTupleType)):
65126513
continue
65136514
assert isinstance(tv, TypeVarType)
65146515
unique_typevars[name] = tv

mypy/checkpattern.py

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,13 @@
4545
Type,
4646
TypedDictType,
4747
TypeOfAny,
48+
TypeVarTupleType,
4849
UninhabitedType,
4950
UnionType,
51+
UnpackType,
52+
find_unpack_in_list,
5053
get_proper_type,
54+
split_with_prefix_and_suffix,
5155
)
5256
from mypy.typevars import fill_typevars
5357
from mypy.visitor import PatternVisitor
@@ -239,13 +243,29 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
239243
#
240244
# get inner types of original type
241245
#
246+
unpack_index = None
242247
if isinstance(current_type, TupleType):
243248
inner_types = current_type.items
244-
size_diff = len(inner_types) - required_patterns
245-
if size_diff < 0:
246-
return self.early_non_match()
247-
elif size_diff > 0 and star_position is None:
248-
return self.early_non_match()
249+
unpack_index = find_unpack_in_list(inner_types)
250+
if unpack_index is None:
251+
size_diff = len(inner_types) - required_patterns
252+
if size_diff < 0:
253+
return self.early_non_match()
254+
elif size_diff > 0 and star_position is None:
255+
return self.early_non_match()
256+
else:
257+
normalized_inner_types = []
258+
for it in inner_types:
259+
# Unfortunately, it is not possible to "split" the TypeVarTuple
260+
# into individual items, so we just use its upper bound for the whole
261+
# analysis instead.
262+
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
263+
it = UnpackType(it.type.upper_bound)
264+
normalized_inner_types.append(it)
265+
inner_types = normalized_inner_types
266+
current_type = current_type.copy_modified(items=normalized_inner_types)
267+
if len(inner_types) - 1 > required_patterns and star_position is None:
268+
return self.early_non_match()
249269
else:
250270
inner_type = self.get_sequence_type(current_type, o)
251271
if inner_type is None:
@@ -270,18 +290,18 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
270290
self.update_type_map(captures, type_map)
271291

272292
new_inner_types = self.expand_starred_pattern_types(
273-
contracted_new_inner_types, star_position, len(inner_types)
293+
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
274294
)
275295
rest_inner_types = self.expand_starred_pattern_types(
276-
contracted_rest_inner_types, star_position, len(inner_types)
296+
contracted_rest_inner_types, star_position, len(inner_types), unpack_index is not None
277297
)
278298

279299
#
280300
# Calculate new type
281301
#
282302
new_type: Type
283303
rest_type: Type = current_type
284-
if isinstance(current_type, TupleType):
304+
if isinstance(current_type, TupleType) and unpack_index is None:
285305
narrowed_inner_types = []
286306
inner_rest_types = []
287307
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
@@ -301,6 +321,14 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
301321
if all(is_uninhabited(typ) for typ in inner_rest_types):
302322
# All subpatterns always match, so we can apply negative narrowing
303323
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
324+
elif isinstance(current_type, TupleType):
325+
# For variadic tuples it is too tricky to match individual items like for fixed
326+
# tuples, so we instead try to narrow the entire type.
327+
# TODO: use more precise narrowing when possible (e.g. for identical shapes).
328+
new_tuple_type = TupleType(new_inner_types, current_type.partial_fallback)
329+
new_type, rest_type = self.chk.conditional_types_with_intersection(
330+
new_tuple_type, [get_type_range(current_type)], o, default=new_tuple_type
331+
)
304332
else:
305333
new_inner_type = UninhabitedType()
306334
for typ in new_inner_types:
@@ -345,17 +373,45 @@ def contract_starred_pattern_types(
345373
346374
If star_pos in None the types are returned unchanged.
347375
"""
348-
if star_pos is None:
349-
return types
350-
new_types = types[:star_pos]
351-
star_length = len(types) - num_patterns
352-
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
353-
new_types += types[star_pos + star_length :]
354-
355-
return new_types
376+
unpack_index = find_unpack_in_list(types)
377+
if unpack_index is not None:
378+
# Variadic tuples require "re-shaping" to match the requested pattern.
379+
unpack = types[unpack_index]
380+
assert isinstance(unpack, UnpackType)
381+
unpacked = get_proper_type(unpack.type)
382+
# This should be guaranteed by the normalization in the caller.
383+
assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
384+
if star_pos is None:
385+
missing = num_patterns - len(types) + 1
386+
new_types = types[:unpack_index]
387+
new_types += [unpacked.args[0]] * missing
388+
new_types += types[unpack_index + 1 :]
389+
return new_types
390+
prefix, middle, suffix = split_with_prefix_and_suffix(
391+
tuple([UnpackType(unpacked) if isinstance(t, UnpackType) else t for t in types]),
392+
star_pos,
393+
num_patterns - star_pos,
394+
)
395+
new_middle = []
396+
for m in middle:
397+
# The existing code expects the star item type, rather than the type of
398+
# the whole tuple "slice".
399+
if isinstance(m, UnpackType):
400+
new_middle.append(unpacked.args[0])
401+
else:
402+
new_middle.append(m)
403+
return list(prefix) + [make_simplified_union(new_middle)] + list(suffix)
404+
else:
405+
if star_pos is None:
406+
return types
407+
new_types = types[:star_pos]
408+
star_length = len(types) - num_patterns
409+
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
410+
new_types += types[star_pos + star_length :]
411+
return new_types
356412

357413
def expand_starred_pattern_types(
358-
self, types: list[Type], star_pos: int | None, num_types: int
414+
self, types: list[Type], star_pos: int | None, num_types: int, original_unpack: bool
359415
) -> list[Type]:
360416
"""Undoes the contraction done by contract_starred_pattern_types.
361417
@@ -364,6 +420,17 @@ def expand_starred_pattern_types(
364420
"""
365421
if star_pos is None:
366422
return types
423+
if original_unpack:
424+
# In the case where original tuple type has an unpack item, it is not practical
425+
# to coerce pattern type back to the original shape (and may not even be possible),
426+
# so we only restore the type of the star item.
427+
res = []
428+
for i, t in enumerate(types):
429+
if i != star_pos:
430+
res.append(t)
431+
else:
432+
res.append(UnpackType(self.chk.named_generic_type("builtins.tuple", [t])))
433+
return res
367434
new_types = types[:star_pos]
368435
star_length = num_types - len(types) + 1
369436
new_types += [types[star_pos]] * star_length
@@ -459,7 +526,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
459526
return self.early_non_match()
460527
if isinstance(type_info, TypeInfo):
461528
any_type = AnyType(TypeOfAny.implementation_artifact)
462-
typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
529+
args: list[Type] = []
530+
for tv in type_info.defn.type_vars:
531+
if isinstance(tv, TypeVarTupleType):
532+
args.append(
533+
UnpackType(self.chk.named_generic_type("builtins.tuple", [any_type]))
534+
)
535+
else:
536+
args.append(any_type)
537+
typ: Type = Instance(type_info, args)
463538
elif isinstance(type_info, TypeAlias):
464539
typ = type_info.target
465540
elif (

mypy/constraints.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Instance,
2929
LiteralType,
3030
NoneType,
31+
NormalizedCallableType,
3132
Overloaded,
3233
Parameters,
3334
ParamSpecType,
@@ -1388,7 +1389,7 @@ def find_matching_overload_items(
13881389
return res
13891390

13901391

1391-
def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo | None:
1392+
def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo:
13921393
"""Get builtins.tuple type from available types to construct homogeneous tuples."""
13931394
tp = get_proper_type(unpack.type)
13941395
if isinstance(tp, Instance) and tp.type.fullname == "builtins.tuple":
@@ -1399,10 +1400,10 @@ def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo | None:
13991400
for base in tp.partial_fallback.type.mro:
14001401
if base.fullname == "builtins.tuple":
14011402
return base
1402-
return None
1403+
assert False, "Invalid unpack type"
14031404

14041405

1405-
def repack_callable_args(callable: CallableType, tuple_type: TypeInfo | None) -> list[Type]:
1406+
def repack_callable_args(callable: CallableType, tuple_type: TypeInfo) -> list[Type]:
14061407
"""Present callable with star unpack in a normalized form.
14071408
14081409
Since positional arguments cannot follow star argument, they are packed in a suffix,
@@ -1417,12 +1418,8 @@ def repack_callable_args(callable: CallableType, tuple_type: TypeInfo | None) ->
14171418
star_type = callable.arg_types[star_index]
14181419
suffix_types = []
14191420
if not isinstance(star_type, UnpackType):
1420-
if tuple_type is not None:
1421-
# Re-normalize *args: X -> *args: *tuple[X, ...]
1422-
star_type = UnpackType(Instance(tuple_type, [star_type]))
1423-
else:
1424-
# This is unfortunate, something like tuple[Any, ...] would be better.
1425-
star_type = UnpackType(AnyType(TypeOfAny.from_error))
1421+
# Re-normalize *args: X -> *args: *tuple[X, ...]
1422+
star_type = UnpackType(Instance(tuple_type, [star_type]))
14261423
else:
14271424
tp = get_proper_type(star_type.type)
14281425
if isinstance(tp, TupleType):
@@ -1544,7 +1541,9 @@ def infer_directed_arg_constraints(left: Type, right: Type, direction: int) -> l
15441541

15451542

15461543
def infer_callable_arguments_constraints(
1547-
template: CallableType | Parameters, actual: CallableType | Parameters, direction: int
1544+
template: NormalizedCallableType | Parameters,
1545+
actual: NormalizedCallableType | Parameters,
1546+
direction: int,
15481547
) -> list[Constraint]:
15491548
"""Infer constraints between argument types of two callables.
15501549

mypy/erasetype.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
100100
raise RuntimeError("Parameters should have been bound to a class")
101101

102102
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
103-
return AnyType(TypeOfAny.special_form)
103+
# Likely, we can never get here because of aggressive erasure of types that
104+
# can contain this, but better still return a valid replacement.
105+
return t.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
104106

105107
def visit_unpack_type(self, t: UnpackType) -> ProperType:
106108
return AnyType(TypeOfAny.special_form)

0 commit comments

Comments
 (0)