Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# This is to match the direction the implementation's return
# needs to be compatible in.
if impl_type.variables:
impl = unify_generic_callable(
impl_type,
sig1,
impl: CallableType | None = unify_generic_callable(
# Normalize both before unifying
impl_type.with_unpacked_kwargs(),
sig1.with_unpacked_kwargs(),
ignore_return=False,
return_constraint_direction=SUPERTYPE_OF,
)
Expand Down Expand Up @@ -1166,7 +1167,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
if not isinstance(arg_type, ParamSpecType):
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
arg_type = self.named_generic_type(
"builtins.dict", [self.str_type(), arg_type]
)
Expand Down
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,8 @@ def check_callable_call(

See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2057,6 +2059,8 @@ def check_overload_call(
context: Context,
) -> tuple[Type, Type]:
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
Expand Down
6 changes: 5 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,13 @@ def infer_constraints_from_protocol_members(
return res

def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Normalize callables before matching against each other.
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
Expand Down
18 changes: 17 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Tuple

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
Expand Down Expand Up @@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:

def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
"""Return a simple least upper bound given the declared type."""
# TODO: check infinite recursion for aliases here.
# TODO: check infinite recursion for aliases here?
declaration = get_proper_type(declaration)
s = get_proper_type(s)
t = get_proper_type(t)
Expand Down Expand Up @@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

value = t.accept(TypeJoinVisitor(s))
if declaration is None or is_subtype(value, declaration):
return value
Expand Down Expand Up @@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
elif isinstance(t, PlaceholderType):
return AnyType(TypeOfAny.from_error)

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

# Use a visitor to handle non-trivial cases.
return t.accept(TypeJoinVisitor(s, instance_joiner))

Expand Down Expand Up @@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
return False


def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
if isinstance(s, (CallableType, Overloaded)):
s = s.with_unpacked_kwargs()
if isinstance(t, (CallableType, Overloaded)):
t = t.with_unpacked_kwargs()
return s, t


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
Expand Down
4 changes: 4 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
return t
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = join.normalize_callables(s, t)

return t.accept(TypeMeetVisitor(s))


Expand Down
5 changes: 4 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,10 @@ def [T <: int] f(self, x: int, y: T) -> None
name = tp.arg_names[i]
if name:
s += name + ": "
s += format_type_bare(tp.arg_types[i])
type_str = format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if tp.arg_kinds[i].is_optional():
s += " = ..."

Expand Down
26 changes: 26 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@
TypeVarLikeType,
TypeVarType,
UnboundType,
UnpackType,
get_proper_type,
get_proper_types,
invalid_recursive_alias,
Expand Down Expand Up @@ -830,6 +831,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.defer(defn)
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
result = self.remove_unpack_kwargs(defn, result)
defn.type = result
self.add_type_alias_deps(analyzer.aliases_used)
self.check_function_signature(defn)
Expand Down Expand Up @@ -872,6 +875,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
defn.type = defn.type.copy_modified(ret_type=ret_type)
self.wrapped_coro_return_types[defn] = defn.type

def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
return typ
last_type = get_proper_type(typ.arg_types[-1])
if not isinstance(last_type, UnpackType):
return typ
last_type = get_proper_type(last_type.type)
if not isinstance(last_type, TypedDictType):
self.fail("Unpack item in ** argument must be a TypedDict", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
overlap = set(typ.arg_names) & set(last_type.items)
# It is OK for TypedDict to have a key named 'kwargs'.
overlap.discard(typ.arg_names[-1])
if overlap:
overlapped = ", ".join([f'"{name}"' for name in overlap])
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
# OK, everything looks right now, mark the callable type as using unpack.
new_arg_types = typ.arg_types[:-1] + [last_type]
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)

def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
"""Check basic signature validity and tweak annotation of self/cls argument."""
# Only non-static methods are special.
Expand Down
25 changes: 16 additions & 9 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Instance,
LiteralType,
NoneType,
NormalizedCallableType,
Overloaded,
Parameters,
ParamSpecType,
Expand Down Expand Up @@ -626,8 +627,10 @@ def visit_unpack_type(self, left: UnpackType) -> bool:
return False

def visit_parameters(self, left: Parameters) -> bool:
right = self.right
if isinstance(right, Parameters) or isinstance(right, CallableType):
if isinstance(self.right, Parameters) or isinstance(self.right, CallableType):
right = self.right
if isinstance(right, CallableType):
right = right.with_unpacked_kwargs()
return are_parameters_compatible(
left,
right,
Expand Down Expand Up @@ -671,7 +674,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Parameters):
# this doesn't check return types.... but is needed for is_equivalent
return are_parameters_compatible(
left,
left.with_unpacked_kwargs(),
right,
is_compat=self._is_subtype,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
Expand Down Expand Up @@ -1249,6 +1252,10 @@ def g(x: int) -> int: ...
If the 'some_check' function is also symmetric, the two calls would be equivalent
whether or not we check the args covariantly.
"""
# Normalize both types before comparing them.
left = left.with_unpacked_kwargs()
right = right.with_unpacked_kwargs()

if is_compat_return is None:
is_compat_return = is_compat

Expand Down Expand Up @@ -1313,8 +1320,8 @@ def g(x: int) -> int: ...


def are_parameters_compatible(
left: Parameters | CallableType,
right: Parameters | CallableType,
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
*,
is_compat: Callable[[Type, Type], bool],
ignore_pos_arg_names: bool = False,
Expand Down Expand Up @@ -1535,11 +1542,11 @@ def new_is_compat(left: Type, right: Type) -> bool:


def unify_generic_callable(
type: CallableType,
target: CallableType,
type: NormalizedCallableType,
target: NormalizedCallableType,
ignore_return: bool,
return_constraint_direction: int | None = None,
) -> CallableType | None:
) -> NormalizedCallableType | None:
"""Try to unify a generic callable type with another callable type.

Return unified CallableType if successful; otherwise, return None.
Expand Down Expand Up @@ -1576,7 +1583,7 @@ def report(*args: Any) -> None:
)
if had_errors:
return None
return applied
return cast(NormalizedCallableType, applied)


def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
# We don't want people to try to use this yet.
if not self.options.enable_incomplete_features:
self.fail('"Unpack" is not supported by mypy yet', t)
self.fail('"Unpack" is not supported yet, use --enable-incomplete-features', t)
return AnyType(TypeOfAny.from_error)
return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
return None
Expand Down
Loading