Skip to content
Merged
4 changes: 3 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,9 @@ def check_func_def(
):
if defn.is_class or defn.name == "__new__":
ref_type = mypy.types.TypeType.make_normalized(ref_type)
erased = get_proper_type(erase_to_bound(arg_type))
# This level of erasure matches the one in checkmember.check_self_arg(),
# better keep these two checks consistent.
erased = get_proper_type(erase_typevars(erase_to_bound(arg_type)))
Comment thread
ilevkivskyi marked this conversation as resolved.
if not is_subtype(ref_type, erased, ignore_type_params=True):
if (
isinstance(erased, Instance)
Expand Down
7 changes: 7 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return self.replacement
return t

def visit_callable_type(self, t: CallableType) -> Type:
result = super().visit_callable_type(t)
if t.param_spec():
assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.erased = True
return result

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Type alias target can't contain bound type variables (not bound by the type
# alias itself), so it is safe to just erase the arguments.
Expand Down
3 changes: 3 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,9 @@ def report_protocol_problems(
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
or not supertype.type.defn.type_vars
# Always show detailed message for ParamSpec
or subtype.type.has_param_spec_type
or supertype.type.has_param_spec_type
):
type_name = format_type(subtype, self.options, module_names=True)
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)
Expand Down
21 changes: 18 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,18 @@ def are_trivial_parameters(param: Parameters | NormalizedCallableType) -> bool:
)


def is_trivial_suffix(param: Parameters | NormalizedCallableType) -> bool:
param_star = param.var_arg()
param_star2 = param.kw_arg()
return (
param.arg_kinds[-2:] == [ARG_STAR, ARG_STAR2]
and param_star is not None
and isinstance(get_proper_type(param_star.typ), AnyType)
and param_star2 is not None
and isinstance(get_proper_type(param_star2.typ), AnyType)
)


def are_parameters_compatible(
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
Expand All @@ -1498,6 +1510,9 @@ def are_parameters_compatible(
if are_trivial_parameters(right):
return True

# Parameters should not contain nested ParamSpec, so erasure doesn't make them less general.
trivial_suffix = isinstance(right, CallableType) and right.erased and is_trivial_suffix(right)

# Match up corresponding arguments and check them for compatibility. In
# every pair (argL, argR) of corresponding arguments from L and R, argL must
# be "more general" than argR if L is to be a subtype of R.
Expand Down Expand Up @@ -1527,7 +1542,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
if right_arg is None:
return False
if left_arg is None:
return not allow_partial_overlap
return not allow_partial_overlap and not trivial_suffix
return not is_compat(right_arg.typ, left_arg.typ)

if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
Expand All @@ -1551,7 +1566,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None:
if right_star is not None and not trivial_suffix:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand All @@ -1578,7 +1593,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1d: Check kw args. Right has an infinite series of optional named
# arguments. Get all further named args of left, and make sure
# they're more general than the corresponding member in right.
if right_star2 is not None:
if right_star2 is not None and not trivial_suffix:
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
Expand Down
4 changes: 4 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ def supported_self_type(typ: ProperType) -> bool:
"""
if isinstance(typ, TypeType):
return supported_self_type(typ.item)
if isinstance(typ, CallableType):
# Special case: allow class callable instead of Type[...] as cls annotation,
# as well as callable self for callback protocols.
return True
return isinstance(typ, TypeVarType) or (
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
)
Expand Down
7 changes: 7 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,6 +1778,7 @@ class CallableType(FunctionLike):
# (this is used for error messages)
"imprecise_arg_kinds",
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
"erased", # Is this callable created as an erased form of a more precise type?
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels kind of ad hoc -- and there seems to overlap with is_ellipsis_args. Is there a way to merge these two? For example, allow is_ellipsis_args to be used with an argument prefix. Then erasing a ParamSpec could produce a type with is_ellipsis_args=True.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. This looks quite ad-hoc. The more think about it the more I think we should do this unconditionally. This will make the whole thing much simpler.

)

def __init__(
Expand All @@ -1803,6 +1804,7 @@ def __init__(
from_concatenate: bool = False,
imprecise_arg_kinds: bool = False,
unpack_kwargs: bool = False,
erased: bool = False,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1850,6 +1852,7 @@ def __init__(
self.def_extras = {}
self.type_guard = type_guard
self.unpack_kwargs = unpack_kwargs
self.erased = erased

def copy_modified(
self: CT,
Expand All @@ -1873,6 +1876,7 @@ def copy_modified(
from_concatenate: Bogus[bool] = _dummy,
imprecise_arg_kinds: Bogus[bool] = _dummy,
unpack_kwargs: Bogus[bool] = _dummy,
erased: Bogus[bool] = _dummy,
) -> CT:
modified = CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
Expand Down Expand Up @@ -1903,6 +1907,7 @@ def copy_modified(
else self.imprecise_arg_kinds
),
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
erased=erased if erased is not _dummy else self.erased,
)
# Optimization: Only NewTypes are supported as subtypes since
# the class is effectively final, so we can use a cast safely.
Expand Down Expand Up @@ -2220,6 +2225,7 @@ def serialize(self) -> JsonDict:
"from_concatenate": self.from_concatenate,
"imprecise_arg_kinds": self.imprecise_arg_kinds,
"unpack_kwargs": self.unpack_kwargs,
"erased": self.erased,
}

@classmethod
Expand All @@ -2244,6 +2250,7 @@ def deserialize(cls, data: JsonDict) -> CallableType:
from_concatenate=data["from_concatenate"],
imprecise_arg_kinds=data["imprecise_arg_kinds"],
unpack_kwargs=data["unpack_kwargs"],
erased=data["erased"],
)


Expand Down
113 changes: 112 additions & 1 deletion test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,12 @@ class A(Protocol[P]):
...

def bar(b: A[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") \
# N: Following member(s) of "A[P]" have conflicts: \
# N: Expected: \
# N: def foo(self, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
# N: Got: \
# N: def foo(self, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingValidNonStrict]
Expand Down Expand Up @@ -1825,6 +1830,112 @@ c: C[int, [int, str], str] # E: Nested parameter specifications are not allowed
reveal_type(c) # N: Revealed type is "__main__.C[Any]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateSelfType]
from typing import Callable
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
class A:
def __init__(self, a_param_1: str) -> None: ...

@classmethod
def add_params(cls: Callable[P, A]) -> Callable[Concatenate[float, P], A]:
def new_constructor(i: float, *args: P.args, **kwargs: P.kwargs) -> A:
return cls(*args, **kwargs)
return new_constructor

@classmethod
def remove_params(cls: Callable[Concatenate[str, P], A]) -> Callable[P, A]:
def new_constructor(*args: P.args, **kwargs: P.kwargs) -> A:
return cls("my_special_str", *args, **kwargs)
return new_constructor

reveal_type(A.add_params()) # N: Revealed type is "def (builtins.float, a_param_1: builtins.str) -> __main__.A"
reveal_type(A.remove_params()) # N: Revealed type is "def () -> __main__.A"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackProtocol]
from typing import Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

class Path: ...

class Function(Protocol[P, R]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def file_cache(fn: Function[Concatenate[Path, P], R]) -> Function[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "__main__.Function[[*, some_arg: builtins.int], builtins.int]"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateKeywordOnly]
from typing import Callable, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R")

class Path: ...

def file_cache(fn: Callable[Concatenate[Path, P], R]) -> Callable[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "def (*, some_arg: builtins.int) -> builtins.int"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackApply]
from typing import Callable, Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class FuncType(Protocol[P]):
def __call__(self, x: int, s: str, *args: P.args, **kw_args: P.kwargs) -> str: ...

def forwarder1(fp: FuncType[P], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def forwarder2(fp: Callable[Concatenate[int, str, P], str], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def my_f(x: int, s: str, d: bool) -> str: ...
forwarder1(my_f, True) # OK
forwarder2(my_f, True) # OK
forwarder1(my_f, 1.0) # E: Argument 2 to "forwarder1" has incompatible type "float"; expected "bool"
forwarder2(my_f, 1.0) # E: Argument 2 to "forwarder2" has incompatible type "float"; expected "bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecCallbackProtocolSelf]
from typing import Callable, Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

Params = ParamSpec("Params")
Result = TypeVar("Result", covariant=True)

class FancyMethod(Protocol):
def __call__(self, arg1: int, arg2: str) -> bool: ...
def return_me(self: Callable[Params, Result]) -> Callable[Params, Result]: ...
def return_part(self: Callable[Concatenate[int, Params], Result]) -> Callable[Params, Result]: ...

m: FancyMethod
reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2: builtins.str) -> builtins.bool"
reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecInferenceWithCallbackProtocol]
from typing import Protocol, Callable, ParamSpec

Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/paramspec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class object:

class function: ...
class ellipsis: ...
class classmethod: ...

class type:
def __init__(self, *a: object) -> None: ...
Expand Down