Skip to content

Commit 44a653c

Browse files
authored
[dataclass_transform] support class decorator parameters (#14561)
The initial implementation of `typing.dataclass_transform` only supported the no-argument `@decorator` form; this adds support for the `@decorator(...)` form supporting the same arguments we support for `dataclasses.dataclass`. This also matches the list of arguments specified in PEP 681. Co-authored-by: Wesley Collin Wright <wesleyw@dropbox.com>
1 parent c375009 commit 44a653c

File tree

3 files changed

+79
-13
lines changed

3 files changed

+79
-13
lines changed

mypy/plugins/common.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Expression,
1414
FuncDef,
1515
JsonDict,
16+
Node,
1617
PassStmt,
1718
RefExpr,
1819
SymbolTableNode,
@@ -68,19 +69,7 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None:
6869
#
6970
# Note: I'm not hard-coding the index so that in the future we can support other
7071
# attrib and class makers.
71-
if not isinstance(call.callee, RefExpr):
72-
return None
73-
74-
callee_type = None
75-
callee_node = call.callee.node
76-
if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
77-
callee_node_type = get_proper_type(callee_node.type)
78-
if isinstance(callee_node_type, Overloaded):
79-
# We take the last overload.
80-
callee_type = callee_node_type.items[-1]
81-
elif isinstance(callee_node_type, CallableType):
82-
callee_type = callee_node_type
83-
72+
callee_type = _get_callee_type(call)
8473
if not callee_type:
8574
return None
8675

@@ -94,6 +83,31 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None:
9483
return attr_value
9584
if attr_name == argument.name:
9685
return attr_value
86+
87+
return None
88+
89+
90+
def _get_callee_type(call: CallExpr) -> CallableType | None:
91+
"""Return the type of the callee, regardless of its syntatic form."""
92+
93+
callee_node: Node | None = call.callee
94+
95+
if isinstance(callee_node, RefExpr):
96+
callee_node = callee_node.node
97+
98+
# Some decorators may be using typing.dataclass_transform, which is itself a decorator, so we
99+
# need to unwrap them to get at the true callee
100+
if isinstance(callee_node, Decorator):
101+
callee_node = callee_node.func
102+
103+
if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
104+
callee_node_type = get_proper_type(callee_node.type)
105+
if isinstance(callee_node_type, Overloaded):
106+
# We take the last overload.
107+
return callee_node_type.items[-1]
108+
elif isinstance(callee_node_type, CallableType):
109+
return callee_node_type
110+
97111
return None
98112

99113

mypy/semanal.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6646,5 +6646,16 @@ def halt(self, reason: str = ...) -> NoReturn:
66466646
def is_dataclass_transform_decorator(node: Node | None) -> bool:
66476647
if isinstance(node, RefExpr):
66486648
return is_dataclass_transform_decorator(node.node)
6649+
if isinstance(node, CallExpr):
6650+
# Like dataclasses.dataclass, transform-based decorators can be applied either with or
6651+
# without parameters; ie, both of these forms are accepted:
6652+
#
6653+
# @typing.dataclass_transform
6654+
# class Foo: ...
6655+
# @typing.dataclass_transform(eq=True, order=True, ...)
6656+
# class Bar: ...
6657+
#
6658+
# We need to unwrap the call for the second variant.
6659+
return is_dataclass_transform_decorator(node.callee)
66496660

66506661
return isinstance(node, Decorator) and node.func.is_dataclass_transform

test-data/unit/check-dataclass-transform.test

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,44 @@ Person('Jonh', 21, None) # E: Too many arguments for "Person"
4444

4545
[typing fixtures/typing-full.pyi]
4646
[builtins fixtures/dataclasses.pyi]
47+
48+
[case testDataclassTransformParametersAreApplied]
49+
# flags: --python-version 3.7
50+
from typing import dataclass_transform, Callable, Type
51+
52+
@dataclass_transform()
53+
def my_dataclass(*, eq: bool, order: bool) -> Callable[[Type], Type]:
54+
def transform(cls: Type) -> Type:
55+
return cls
56+
return transform
57+
58+
@my_dataclass(eq=False, order=True)
59+
class Person: # E: eq must be True if order is True
60+
name: str
61+
age: int
62+
63+
reveal_type(Person) # N: Revealed type is "def (name: builtins.str, age: builtins.int) -> __main__.Person"
64+
Person('John', 32)
65+
Person('John', 21, None) # E: Too many arguments for "Person"
66+
67+
[typing fixtures/typing-medium.pyi]
68+
[builtins fixtures/dataclasses.pyi]
69+
70+
[case testDataclassTransformParametersMustBeBoolLiterals]
71+
# flags: --python-version 3.7
72+
from typing import dataclass_transform, Callable, Type
73+
74+
@dataclass_transform()
75+
def my_dataclass(*, eq: bool = True, order: bool = False) -> Callable[[Type], Type]:
76+
def transform(cls: Type) -> Type:
77+
return cls
78+
return transform
79+
80+
BOOL_CONSTANT = True
81+
@my_dataclass(eq=BOOL_CONSTANT) # E: "eq" argument must be True or False.
82+
class A: ...
83+
@my_dataclass(order=not False) # E: "order" argument must be True or False.
84+
class B: ...
85+
86+
[typing fixtures/typing-medium.pyi]
87+
[builtins fixtures/dataclasses.pyi]

0 commit comments

Comments
 (0)