Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.plugins import dataclasses as dataclasses_plugin
from mypy.scope import Scope
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS
Expand Down Expand Up @@ -1044,6 +1045,9 @@ def check_func_item(

if name == "__exit__":
self.check__exit__return_type(defn)
if name == "__post_init__":
if dataclasses_plugin.is_processed_dataclass(defn.info):
dataclasses_plugin.check_post_init(self, defn, defn.info)
Comment thread
sobolevn marked this conversation as resolved.

@contextmanager
def enter_attribute_inference_context(self) -> Iterator[None]:
Expand Down Expand Up @@ -1851,7 +1855,7 @@ def check_method_or_accessor_override_for_base(
found_base_method = True

# Check the type of override.
if name not in ("__init__", "__new__", "__init_subclass__"):
if name not in ("__init__", "__new__", "__init_subclass__", "__post_init__"):
# Check method override
# (__init__, __new__, __init_subclass__ are special).
if self.check_method_override_for_base_with_name(defn, name, base):
Expand Down Expand Up @@ -2812,6 +2816,9 @@ def check_assignment(
if name == "__match_args__" and inferred is not None:
typ = self.expr_checker.accept(rvalue)
self.check_match_args(inferred, typ, lvalue)
if name == "__post_init__":
if dataclasses_plugin.is_processed_dataclass(self.scope.active_class()):
self.fail(message_registry.DATACLASS_POST_INIT_MUST_BE_A_FUNCTION, rvalue)

# Defer PartialType's super type checking.
if (
Expand Down
1 change: 1 addition & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
'"alias" argument to dataclass field must be a string literal'
)
DATACLASS_POST_INIT_MUST_BE_A_FUNCTION: Final = '"__post_init__" method must be an instance method'

# fastparse
FAILED_TO_MERGE_OVERLOADS: Final = ErrorMessage(
Expand Down
86 changes: 82 additions & 4 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Iterator, Optional
from typing import TYPE_CHECKING, Iterator, Optional
from typing_extensions import Final

from mypy import errorcodes, message_registry
Expand All @@ -26,6 +26,7 @@
DataclassTransformSpec,
Expression,
FuncDef,
FuncItem,
IfStmt,
JsonDict,
NameExpr,
Expand Down Expand Up @@ -55,6 +56,7 @@
from mypy.types import (
AnyType,
CallableType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Expand All @@ -69,19 +71,23 @@
)
from mypy.typevars import fill_typevars

if TYPE_CHECKING:
from mypy.checker import TypeChecker

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}


SELF_TVAR_NAME: Final = "_DT"
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
_TRANSFORM_SPEC_FOR_DATACLASSES: Final = DataclassTransformSpec(
eq_default=True,
order_default=False,
kw_only_default=False,
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)
_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace"
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__"


class DataclassAttribute:
Expand Down Expand Up @@ -352,6 +358,8 @@ def transform(self) -> bool:

if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
self._add_internal_replace_method(attributes)
if "__post_init__" in info.names:
self._add_internal_post_init_method(attributes)

info.metadata["dataclass"] = {
"attributes": [attr.serialize() for attr in attributes],
Expand Down Expand Up @@ -387,7 +395,47 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
fallback=self._api.named_type("builtins.function"),
)

self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
)

def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None:
arg_types: list[Type] = [fill_typevars(self._cls.info)]
arg_kinds = [ARG_POS]
arg_names: list[str | None] = ["self"]

info = self._cls.info
for attr in attributes:
if not attr.is_init_var:
continue
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to called directly and duplicating default values
Comment thread
sobolevn marked this conversation as resolved.
Outdated
# for the sake of type-checking is unpleasant.
arg_kinds.append(ARG_POS)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
name="__post_init__",
)

info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
)

Expand Down Expand Up @@ -1054,3 +1102,33 @@ def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)


def is_processed_dataclass(info: TypeInfo | None) -> bool:
return info is not None and "dataclass" in info.metadata


def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None:
if defn.type is None:
return

ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
if ideal_sig is None or ideal_sig.type is None:
return

# We set it ourself, so it is always fine:
assert isinstance(ideal_sig.type, ProperType)
assert isinstance(ideal_sig.type, FunctionLike)
# Type of `FuncItem` is always `FunctionLike`:
assert isinstance(defn.type, FunctionLike)

api.check_override(
override=defn.type,
original=ideal_sig.type,
name="__post_init__",
name_in_super="__post_init__",
supertype="dataclass",
original_class_or_static=False,
override_class_or_static=False,
node=defn,
)
166 changes: 166 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -2197,6 +2197,172 @@ reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
a2 = replace(a, x='42') # E: Argument "x" to "replace" of "A[int]" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"

[case testPostInitCorrectSignature]
from dataclasses import dataclass, InitVar

@dataclass
class Test1:
x: int
def __post_init__(self) -> None: ...

@dataclass
class Test2:
x: int
y: InitVar[int]
z: str
def __post_init__(self, y: int) -> None: ...

@dataclass
class Test3:
x: InitVar[int]
y: InitVar[str]
def __post_init__(self, x: int, y: str) -> None: ...

@dataclass
class Test4:
x: int
y: InitVar[str]
z: InitVar[bool] = True
def __post_init__(self, y: str, z: bool) -> None: ...

@dataclass
class Test5:
y: InitVar[str] = 'a'
z: InitVar[bool] = True
def __post_init__(self, y: str = 'a', z: bool = True) -> None: ...

from typing import Any, Callable, TypeVar
F = TypeVar('F', bound=Callable[..., Any])
def identity(f: F) -> F: return f

@dataclass
class Test6:
y: InitVar[str]
@identity # decorated method works
def __post_init__(self, y: str) -> None: ...
[builtins fixtures/dataclasses.pyi]

[case testPostInitSubclassing]
from dataclasses import dataclass, InitVar

@dataclass
class Base:
a: str
x: InitVar[int]
def __post_init__(self, x: int) -> None: ...

@dataclass
class Child(Base):
b: str
y: InitVar[str]
def __post_init__(self, x: int, y: str) -> None: ...

@dataclass
class GrandChild(Child):
c: int
z: InitVar[str] = "a"
def __post_init__(self, x: int, y: str, z: str) -> None: ...
[builtins fixtures/dataclasses.pyi]

[case testPostInitNotADataclassCheck]
from dataclasses import dataclass, InitVar

class Regular:
__post_init__ = 1 # can be whatever

class Base:
x: InitVar[int]
def __post_init__(self) -> None: ... # can be whatever

@dataclass
class Child(Base):
y: InitVar[str]
def __post_init__(self, y: str) -> None: ...
[builtins fixtures/dataclasses.pyi]

[case testPostInitMissingParam]
from dataclasses import dataclass, InitVar

@dataclass
class Child:
y: InitVar[str]
def __post_init__(self) -> None: ...
[builtins fixtures/dataclasses.pyi]
[out]
main:6: error: Signature of "__post_init__" incompatible with supertype "dataclass"
main:6: note: Superclass:
main:6: note: def __post_init__(self: Child, y: str) -> None
main:6: note: Subclass:
main:6: note: def __post_init__(self: Child) -> None

[case testPostInitWrongTypeAndName]
from dataclasses import dataclass, InitVar

@dataclass
class Child:
y: InitVar[str]
def __post_init__(self, x: int) -> None: ...
[builtins fixtures/dataclasses.pyi]
[out]
main:6: error: Argument 2 of "__post_init__" is incompatible with supertype "dataclass"; supertype defines the argument type as "str"
main:6: note: This violates the Liskov substitution principle
main:6: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
Comment thread
sobolevn marked this conversation as resolved.
Outdated

[case testPostInitExtraParam]
from dataclasses import dataclass, InitVar

@dataclass
class Child:
y: InitVar[str]
def __post_init__(self, y: str, z: int) -> None: ...
[builtins fixtures/dataclasses.pyi]
[out]
main:6: error: Signature of "__post_init__" incompatible with supertype "dataclass"
main:6: note: Superclass:
main:6: note: def __post_init__(self: Child, y: str) -> None
main:6: note: Subclass:
main:6: note: def __post_init__(self: Child, y: str, z: int) -> None

[case testPostInitReturnType]
from dataclasses import dataclass, InitVar

@dataclass
class Child:
y: InitVar[str]
def __post_init__(self, y: str) -> int: ...
[builtins fixtures/dataclasses.pyi]
[out]
main:6: error: Return type "int" of "__post_init__" incompatible with return type "None" in supertype "dataclass"

[case testPostInitDecoratedMethodError]
from dataclasses import dataclass, InitVar
from typing import Any, Callable, TypeVar

F = TypeVar('F', bound=Callable[..., Any])
def identity(f: F) -> F: return f

@dataclass
class Klass:
y: InitVar[str]
@identity
def __post_init__(self) -> None: ...
[builtins fixtures/dataclasses.pyi]
[out]
main:11: error: Signature of "__post_init__" incompatible with supertype "dataclass"
main:11: note: Superclass:
main:11: note: def __post_init__(self: Klass, y: str) -> None
main:11: note: Subclass:
main:11: note: def __post_init__(self: Klass) -> None

[case testPostInitIsNotAFunction]
from dataclasses import dataclass, InitVar

@dataclass
class Child:
y: InitVar[str]
__post_init__ = 1 # E: "__post_init__" method must be an instance method
[builtins fixtures/dataclasses.pyi]
Comment thread
sobolevn marked this conversation as resolved.

[case testProtocolNoCrash]
from typing import Protocol, Union, ClassVar
from dataclasses import dataclass, field
Expand Down