Skip to content

Commit 283fe3d

Browse files
committed
Add improved union support from python#15050
1 parent 9b491f5 commit 283fe3d

File tree

2 files changed

+174
-43
lines changed

2 files changed

+174
-43
lines changed

mypy/plugins/dataclasses.py

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from mypy import errorcodes, message_registry
99
from mypy.expandtype import expand_type, expand_type_by_instance
10+
from mypy.meet import meet_types
1011
from mypy.messages import format_type_bare
1112
from mypy.nodes import (
1213
ARG_NAMED,
@@ -57,10 +58,13 @@
5758
Instance,
5859
LiteralType,
5960
NoneType,
61+
ProperType,
6062
TupleType,
6163
Type,
6264
TypeOfAny,
6365
TypeVarType,
66+
UninhabitedType,
67+
UnionType,
6468
get_proper_type,
6569
)
6670
from mypy.typevars import fill_typevars
@@ -372,7 +376,6 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
372376
arg_names=arg_names,
373377
ret_type=NoneType(),
374378
fallback=self._api.named_type("builtins.function"),
375-
name=f"replace of {self._cls.info.name}",
376379
)
377380

378381
self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
@@ -923,6 +926,91 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool:
923926
)
924927

925928

929+
def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None:
930+
t_name = format_type_bare(t, ctx.api.options)
931+
if parent_t is t:
932+
msg = (
933+
f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass'
934+
if isinstance(t, TypeVarType)
935+
else f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass'
936+
)
937+
else:
938+
pt_name = format_type_bare(parent_t, ctx.api.options)
939+
msg = (
940+
f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass'
941+
if isinstance(t, TypeVarType)
942+
else f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass'
943+
)
944+
945+
ctx.api.fail(msg, ctx.context)
946+
947+
948+
def _get_expanded_dataclasses_fields(
949+
ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType
950+
) -> list[CallableType] | None:
951+
"""
952+
For a given type, determine what dataclasses it can be: for each class, return the field types.
953+
For generic classes, the field types are expanded.
954+
If the type contains Any or a non-dataclass, returns None; in the latter case, also reports an error.
955+
"""
956+
if isinstance(typ, AnyType):
957+
return None
958+
elif isinstance(typ, UnionType):
959+
ret: list[CallableType] | None = []
960+
for item in typ.relevant_items():
961+
item = get_proper_type(item)
962+
item_types = _get_expanded_dataclasses_fields(ctx, item, item, parent_typ)
963+
if ret is not None and item_types is not None:
964+
ret += item_types
965+
else:
966+
ret = None # but keep iterating to emit all errors
967+
return ret
968+
elif isinstance(typ, TypeVarType):
969+
return _get_expanded_dataclasses_fields(
970+
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
971+
)
972+
elif isinstance(typ, Instance):
973+
replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
974+
if replace_sym is None:
975+
_fail_not_dataclass(ctx, display_typ, parent_typ)
976+
return None
977+
replace_sig = get_proper_type(replace_sym.type)
978+
assert isinstance(replace_sig, CallableType)
979+
return [expand_type_by_instance(replace_sig, typ)]
980+
else:
981+
_fail_not_dataclass(ctx, display_typ, parent_typ)
982+
return None
983+
984+
985+
def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType:
986+
"""
987+
Produces the lowest bound of the 'replace' signatures of multiple dataclasses.
988+
"""
989+
args = {
990+
name: (typ, kind)
991+
for name, typ, kind in zip(sigs[0].arg_names, sigs[0].arg_types, sigs[0].arg_kinds)
992+
}
993+
994+
for sig in sigs[1:]:
995+
sig_args = {
996+
name: (typ, kind)
997+
for name, typ, kind in zip(sig.arg_names, sig.arg_types, sig.arg_kinds)
998+
}
999+
for name in (*args.keys(), *sig_args.keys()):
1000+
sig_typ, sig_kind = args.get(name, (UninhabitedType(), ARG_NAMED_OPT))
1001+
sig2_typ, sig2_kind = sig_args.get(name, (UninhabitedType(), ARG_NAMED_OPT))
1002+
args[name] = (
1003+
meet_types(sig_typ, sig2_typ),
1004+
ARG_NAMED_OPT if sig_kind == sig2_kind == ARG_NAMED_OPT else ARG_NAMED,
1005+
)
1006+
1007+
return sigs[0].copy_modified(
1008+
arg_names=list(args.keys()),
1009+
arg_types=[typ for typ, _ in args.values()],
1010+
arg_kinds=[kind for _, kind in args.values()],
1011+
)
1012+
1013+
9261014
def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
9271015
"""
9281016
Returns a signature for the 'dataclasses.replace' function that's dependent on the type
@@ -946,34 +1034,18 @@ def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
9461034
# </hack>
9471035

9481036
obj_type = get_proper_type(obj_type)
949-
obj_type_str = format_type_bare(obj_type)
950-
if isinstance(obj_type, AnyType):
951-
return ctx.default_signature # replace(Any, ...) -> Any
1037+
inst_type_str = format_type_bare(obj_type, ctx.api.options)
9521038

953-
dataclass_type = get_proper_type(
954-
obj_type.upper_bound if isinstance(obj_type, TypeVarType) else obj_type
955-
)
956-
replace_func = None
957-
if isinstance(dataclass_type, Instance):
958-
replace_func = dataclass_type.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
959-
if replace_func is None:
960-
ctx.api.fail(
961-
f'Argument 1 to "replace" has variable type "{obj_type_str}" not bound to a dataclass'
962-
if isinstance(obj_type, TypeVarType)
963-
else f'Argument 1 to "replace" has incompatible type "{obj_type_str}"; expected a dataclass',
964-
ctx.context,
965-
)
1039+
replace_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type)
1040+
if replace_sigs is None:
9661041
return ctx.default_signature
967-
assert isinstance(dataclass_type, Instance)
968-
969-
signature = get_proper_type(replace_func.type)
970-
assert isinstance(signature, CallableType)
971-
signature = expand_type_by_instance(signature, dataclass_type)
972-
# re-add the instance type
973-
return signature.copy_modified(
974-
arg_types=[obj_type, *signature.arg_types],
975-
arg_kinds=[ARG_POS, *signature.arg_kinds],
976-
arg_names=[None, *signature.arg_names],
1042+
replace_sig = _meet_replace_sigs(replace_sigs)
1043+
1044+
return replace_sig.copy_modified(
1045+
arg_names=[None, *replace_sig.arg_names],
1046+
arg_kinds=[ARG_POS, *replace_sig.arg_kinds],
1047+
arg_types=[obj_type, *replace_sig.arg_types],
9771048
ret_type=obj_type,
978-
name=f"{ctx.default_signature.name} of {obj_type_str}",
1049+
fallback=ctx.default_signature.fallback,
1050+
name=f"{ctx.default_signature.name} of {inst_type_str}",
9791051
)

test-data/unit/check-dataclasses.test

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,39 +2059,101 @@ a2 = replace(a, x='42', q=42) # E: Argument "x" to "replace" of "A" has incompa
20592059
a2 = replace(a, q='42') # E: Argument "q" to "replace" of "A" has incompatible type "str"; expected "int"
20602060
reveal_type(a2) # N: Revealed type is "__main__.A"
20612061

2062+
[case testReplaceUnion]
2063+
# flags: --strict-optional
2064+
from typing import Generic, Union, TypeVar
2065+
from dataclasses import dataclass, replace, InitVar
2066+
2067+
T = TypeVar('T')
2068+
2069+
@dataclass
2070+
class A(Generic[T]):
2071+
x: T # exercises meet(T=int, int) = int
2072+
y: bool # exercises meet(bool, int) = bool
2073+
z: str # exercises meet(str, bytes) = <nothing>
2074+
w: dict # exercises meet(dict, <nothing>) = <nothing>
2075+
a: InitVar[int] # exercises (non-optional, optional) = non-optional
2076+
2077+
@dataclass
2078+
class B:
2079+
x: int
2080+
y: bool
2081+
z: bytes
2082+
a: int
2083+
2084+
2085+
a_or_b: Union[A[int], B]
2086+
_ = replace(a_or_b, x=42, y=True, a=42)
2087+
_ = replace(a_or_b, x=42, y=True) # E: Missing named argument "a" for "replace" of "Union[A[int], B]"
2088+
_ = replace(a_or_b, x=42, y=True, z='42', a=42) # E: Argument "z" to "replace" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
2089+
_ = replace(a_or_b, x=42, y=True, w={}, a=42) # E: Argument "w" to "replace" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>
2090+
20622091
[builtins fixtures/dataclasses.pyi]
20632092

2064-
[case testReplaceTypeVar]
2093+
[case testReplaceUnionOfTypeVar]
2094+
# flags: --strict-optional
2095+
from typing import Generic, Union, TypeVar
20652096
from dataclasses import dataclass, replace
2066-
from typing import TypeVar
20672097

20682098
@dataclass
20692099
class A:
20702100
x: int
2101+
y: int
2102+
z: str
2103+
w: dict
2104+
2105+
class B:
2106+
pass
20712107

20722108
TA = TypeVar('TA', bound=A)
2109+
TB = TypeVar('TB', bound=B)
2110+
2111+
def f(b_or_t: Union[TA, TB, int]) -> None:
2112+
a2 = replace(b_or_t) # E: Argument 1 to "replace" has type "Union[TA, TB, int]" whose item "TB" is not bound to a dataclass # E: Argument 1 to "replace" has incompatible type "Union[TA, TB, int]" whose item "int" is not a dataclass
2113+
2114+
[case testReplaceTypeVarBoundNotDataclass]
2115+
from dataclasses import dataclass, replace
2116+
from typing import Union, TypeVar
2117+
20732118
TInt = TypeVar('TInt', bound=int)
20742119
TAny = TypeVar('TAny')
20752120
TNone = TypeVar('TNone', bound=None)
2121+
TUnion = TypeVar('TUnion', bound=Union[str, int])
20762122

2123+
def f1(t: TInt) -> None:
2124+
_ = replace(t, x=42) # E: Argument 1 to "replace" has a variable type "TInt" not bound to a dataclass
20772125

2078-
def f(t: TA) -> TA:
2079-
_ = replace(t, x='spam') # E: Argument "x" to "replace" of "TA" has incompatible type "str"; expected "int"
2080-
return replace(t, x=42)
2126+
def f2(t: TAny) -> TAny:
2127+
return replace(t, x='spam') # E: Argument 1 to "replace" has a variable type "TAny" not bound to a dataclass
20812128

2129+
def f3(t: TNone) -> TNone:
2130+
return replace(t, x='spam') # E: Argument 1 to "replace" has a variable type "TNone" not bound to a dataclass
20822131

2083-
def g(t: TInt) -> None:
2084-
_ = replace(t, x=42) # E: Argument 1 to "replace" has variable type "TInt" not bound to a dataclass
2132+
def f4(t: TUnion) -> TUnion:
2133+
return replace(t, x='spam') # E: Argument 1 to "replace" has incompatible type "TUnion" whose item "str" is not a dataclass # E: Argument 1 to "replace" has incompatible type "TUnion" whose item "int" is not a dataclass
20852134

2135+
[case testReplaceTypeVarBound]
2136+
from dataclasses import dataclass, replace
2137+
from typing import TypeVar
20862138

2087-
def h(t: TAny) -> TAny:
2088-
return replace(t, x='spam') # E: Argument 1 to "replace" has variable type "TAny" not bound to a dataclass
2139+
@dataclass
2140+
class A:
2141+
x: int
20892142

2143+
@dataclass
2144+
class B(A):
2145+
pass
20902146

2091-
def q(t: TNone) -> TNone:
2092-
return replace(t, x='spam') # E: Argument 1 to "replace" has variable type "TNone" not bound to a dataclass
2147+
TA = TypeVar('TA', bound=A)
20932148

2094-
[builtins fixtures/dataclasses.pyi]
2149+
def f(t: TA) -> TA:
2150+
t2 = replace(t, x=42)
2151+
reveal_type(t2) # N: Revealed type is "TA`-1"
2152+
_ = replace(t, x='42') # E: Argument "x" to "replace" of "TA" has incompatible type "str"; expected "int"
2153+
return t2
2154+
2155+
f(A(x=42))
2156+
f(B(x=42))
20952157

20962158
[case testReplaceAny]
20972159
from dataclasses import replace
@@ -2101,8 +2163,6 @@ a: Any
21012163
a2 = replace(a)
21022164
reveal_type(a2) # N: Revealed type is "Any"
21032165

2104-
[builtins fixtures/dataclasses.pyi]
2105-
21062166
[case testReplaceNotDataclass]
21072167
from dataclasses import replace
21082168

@@ -2125,7 +2185,6 @@ T = TypeVar('T')
21252185
class A(Generic[T]):
21262186
x: T
21272187

2128-
21292188
a = A(x=42)
21302189
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
21312190
a2 = replace(a, x=42)

0 commit comments

Comments
 (0)