Skip to content

Commit 9a35360

Browse files
Add add_overloaded_method_to_class helper to plugins/common.py (#16038)
There are several changes: 1. `add_overloaded_method_to_class` itself. It is very useful for plugin authors, because right now it is quite easy to add a regular method, but it is very hard to add a method with `@overload`s. I don't think that user must face all the chalenges that I've covered in this method. Moreover, it is quite easy even for experienced developers to forget some flags / props / etc (I am pretty sure that I might forgot something in the implementation) 2. `add_overloaded_method_to_class` and `add_method_to_class` now return added nodes, it is also helpful if you want to do something with this node in your plugin after it is created 3. I've refactored how `add_method_to_class` works and reused its parts in the new method as well 4. `tvar_def` in `add_method_to_class` can now accept a list of type vars, not just one Notice that `add_method_to_class` is unchanged from the user's POV, it should continue to work as before. Tests are also updated to check that our overloads are correct. Things to do later (in the next PRs / releases): 1. We can possibly add `is_final` param to methods as well 2. We can also support `@property` in a separate method at some point --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ed18fea commit 9a35360

File tree

5 files changed

+222
-23
lines changed

5 files changed

+222
-23
lines changed

mypy/plugins/common.py

Lines changed: 116 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import NamedTuple
4+
35
from mypy.argmap import map_actuals_to_formals
46
from mypy.fixup import TypeFixer
57
from mypy.nodes import (
@@ -16,9 +18,11 @@
1618
JsonDict,
1719
NameExpr,
1820
Node,
21+
OverloadedFuncDef,
1922
PassStmt,
2023
RefExpr,
2124
SymbolTableNode,
25+
TypeInfo,
2226
Var,
2327
)
2428
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
@@ -209,24 +213,99 @@ def add_method(
209213
)
210214

211215

216+
class MethodSpec(NamedTuple):
217+
"""Represents a method signature to be added, except for `name`."""
218+
219+
args: list[Argument]
220+
return_type: Type
221+
self_type: Type | None = None
222+
tvar_defs: list[TypeVarType] | None = None
223+
224+
212225
def add_method_to_class(
213226
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
214227
cls: ClassDef,
215228
name: str,
229+
# MethodSpec items kept for backward compatibility:
216230
args: list[Argument],
217231
return_type: Type,
218232
self_type: Type | None = None,
219-
tvar_def: TypeVarType | None = None,
233+
tvar_def: list[TypeVarType] | TypeVarType | None = None,
220234
is_classmethod: bool = False,
221235
is_staticmethod: bool = False,
222-
) -> None:
236+
) -> FuncDef | Decorator:
223237
"""Adds a new method to a class definition."""
238+
_prepare_class_namespace(cls, name)
224239

225-
assert not (
226-
is_classmethod is True and is_staticmethod is True
227-
), "Can't add a new method that's both staticmethod and classmethod."
240+
if tvar_def is not None and not isinstance(tvar_def, list):
241+
tvar_def = [tvar_def]
242+
243+
func, sym = _add_method_by_spec(
244+
api,
245+
cls.info,
246+
name,
247+
MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def),
248+
is_classmethod=is_classmethod,
249+
is_staticmethod=is_staticmethod,
250+
)
251+
cls.info.names[name] = sym
252+
cls.info.defn.defs.body.append(func)
253+
return func
228254

255+
256+
def add_overloaded_method_to_class(
257+
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
258+
cls: ClassDef,
259+
name: str,
260+
items: list[MethodSpec],
261+
is_classmethod: bool = False,
262+
is_staticmethod: bool = False,
263+
) -> OverloadedFuncDef:
264+
"""Adds a new overloaded method to a class definition."""
265+
assert len(items) >= 2, "Overloads must contain at least two cases"
266+
267+
# Save old definition, if it exists.
268+
_prepare_class_namespace(cls, name)
269+
270+
# Create function bodies for each passed method spec.
271+
funcs: list[Decorator | FuncDef] = []
272+
for item in items:
273+
func, _sym = _add_method_by_spec(
274+
api,
275+
cls.info,
276+
name=name,
277+
spec=item,
278+
is_classmethod=is_classmethod,
279+
is_staticmethod=is_staticmethod,
280+
)
281+
if isinstance(func, FuncDef):
282+
var = Var(func.name, func.type)
283+
var.set_line(func.line)
284+
func.is_decorated = True
285+
func.deco_line = func.line
286+
287+
deco = Decorator(func, [], var)
288+
else:
289+
deco = func
290+
deco.is_overload = True
291+
funcs.append(deco)
292+
293+
# Create the final OverloadedFuncDef node:
294+
overload_def = OverloadedFuncDef(funcs)
295+
overload_def.info = cls.info
296+
overload_def.is_class = is_classmethod
297+
overload_def.is_static = is_staticmethod
298+
sym = SymbolTableNode(MDEF, overload_def)
299+
sym.plugin_generated = True
300+
301+
cls.info.names[name] = sym
302+
cls.info.defn.defs.body.append(overload_def)
303+
return overload_def
304+
305+
306+
def _prepare_class_namespace(cls: ClassDef, name: str) -> None:
229307
info = cls.info
308+
assert info
230309

231310
# First remove any previously generated methods with the same name
232311
# to avoid clashes and problems in the semantic analyzer.
@@ -235,6 +314,29 @@ def add_method_to_class(
235314
if sym.plugin_generated and isinstance(sym.node, FuncDef):
236315
cls.defs.body.remove(sym.node)
237316

317+
# NOTE: we would like the plugin generated node to dominate, but we still
318+
# need to keep any existing definitions so they get semantically analyzed.
319+
if name in info.names:
320+
# Get a nice unique name instead.
321+
r_name = get_unique_redefinition_name(name, info.names)
322+
info.names[r_name] = info.names[name]
323+
324+
325+
def _add_method_by_spec(
326+
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
327+
info: TypeInfo,
328+
name: str,
329+
spec: MethodSpec,
330+
*,
331+
is_classmethod: bool,
332+
is_staticmethod: bool,
333+
) -> tuple[FuncDef | Decorator, SymbolTableNode]:
334+
args, return_type, self_type, tvar_defs = spec
335+
336+
assert not (
337+
is_classmethod is True and is_staticmethod is True
338+
), "Can't add a new method that's both staticmethod and classmethod."
339+
238340
if isinstance(api, SemanticAnalyzerPluginInterface):
239341
function_type = api.named_type("builtins.function")
240342
else:
@@ -258,8 +360,8 @@ def add_method_to_class(
258360
arg_kinds.append(arg.kind)
259361

260362
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
261-
if tvar_def:
262-
signature.variables = [tvar_def]
363+
if tvar_defs:
364+
signature.variables = tvar_defs
263365

264366
func = FuncDef(name, args, Block([PassStmt()]))
265367
func.info = info
@@ -269,13 +371,6 @@ def add_method_to_class(
269371
func._fullname = info.fullname + "." + name
270372
func.line = info.line
271373

272-
# NOTE: we would like the plugin generated node to dominate, but we still
273-
# need to keep any existing definitions so they get semantically analyzed.
274-
if name in info.names:
275-
# Get a nice unique name instead.
276-
r_name = get_unique_redefinition_name(name, info.names)
277-
info.names[r_name] = info.names[name]
278-
279374
# Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
280375
if is_staticmethod:
281376
func.is_decorated = True
@@ -286,12 +381,12 @@ def add_method_to_class(
286381
dec = Decorator(func, [], v)
287382
dec.line = info.line
288383
sym = SymbolTableNode(MDEF, dec)
289-
else:
290-
sym = SymbolTableNode(MDEF, func)
291-
sym.plugin_generated = True
292-
info.names[name] = sym
384+
sym.plugin_generated = True
385+
return dec, sym
293386

294-
info.defn.defs.body.append(func)
387+
sym = SymbolTableNode(MDEF, func)
388+
sym.plugin_generated = True
389+
return func, sym
295390

296391

297392
def add_attribute_to_class(
@@ -304,7 +399,7 @@ def add_attribute_to_class(
304399
override_allow_incompatible: bool = False,
305400
fullname: str | None = None,
306401
is_classvar: bool = False,
307-
) -> None:
402+
) -> Var:
308403
"""
309404
Adds a new attribute to a class definition.
310405
This currently only generates the symbol table entry and no corresponding AssignmentStatement
@@ -335,6 +430,7 @@ def add_attribute_to_class(
335430
info.names[name] = SymbolTableNode(
336431
MDEF, node, plugin_generated=True, no_serialize=no_serialize
337432
)
433+
return node
338434

339435

340436
def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:

test-data/unit/check-custom-plugin.test

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,13 +1011,35 @@ class BaseAddMethod: pass
10111011
class MyClass(BaseAddMethod):
10121012
pass
10131013

1014-
my_class = MyClass()
10151014
reveal_type(MyClass.foo_classmethod) # N: Revealed type is "def ()"
10161015
reveal_type(MyClass.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"
1016+
1017+
my_class = MyClass()
1018+
reveal_type(my_class.foo_classmethod) # N: Revealed type is "def ()"
1019+
reveal_type(my_class.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"
10171020
[file mypy.ini]
10181021
\[mypy]
10191022
plugins=<ROOT>/test-data/unit/plugins/add_classmethod.py
10201023

1024+
[case testAddOverloadedMethodPlugin]
1025+
# flags: --config-file tmp/mypy.ini
1026+
class AddOverloadedMethod: pass
1027+
1028+
class MyClass(AddOverloadedMethod):
1029+
pass
1030+
1031+
reveal_type(MyClass.method) # N: Revealed type is "Overload(def (self: __main__.MyClass, arg: builtins.int) -> builtins.str, def (self: __main__.MyClass, arg: builtins.str) -> builtins.int)"
1032+
reveal_type(MyClass.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
1033+
reveal_type(MyClass.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
1034+
1035+
my_class = MyClass()
1036+
reveal_type(my_class.method) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
1037+
reveal_type(my_class.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
1038+
reveal_type(my_class.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
1039+
[file mypy.ini]
1040+
\[mypy]
1041+
plugins=<ROOT>/test-data/unit/plugins/add_overloaded_method.py
1042+
10211043
[case testCustomErrorCodePlugin]
10221044
# flags: --config-file tmp/mypy.ini --show-error-codes
10231045
def main() -> int:

test-data/unit/check-incremental.test

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5935,6 +5935,44 @@ tmp/b.py:4: note: Revealed type is "def ()"
59355935
tmp/b.py:5: note: Revealed type is "def (builtins.int) -> builtins.str"
59365936
tmp/b.py:6: note: Revealed type is "def ()"
59375937
tmp/b.py:7: note: Revealed type is "def (builtins.int) -> builtins.str"
5938+
5939+
[case testIncrementalAddOverloadedMethodPlugin]
5940+
# flags: --config-file tmp/mypy.ini
5941+
import b
5942+
5943+
[file mypy.ini]
5944+
\[mypy]
5945+
plugins=<ROOT>/test-data/unit/plugins/add_overloaded_method.py
5946+
5947+
[file a.py]
5948+
class AddOverloadedMethod: pass
5949+
5950+
class MyClass(AddOverloadedMethod):
5951+
pass
5952+
5953+
[file b.py]
5954+
import a
5955+
5956+
[file b.py.2]
5957+
import a
5958+
5959+
reveal_type(a.MyClass.method)
5960+
reveal_type(a.MyClass.clsmethod)
5961+
reveal_type(a.MyClass.stmethod)
5962+
5963+
my_class = a.MyClass()
5964+
reveal_type(my_class.method)
5965+
reveal_type(my_class.clsmethod)
5966+
reveal_type(my_class.stmethod)
5967+
[rechecked b]
5968+
[out2]
5969+
tmp/b.py:3: note: Revealed type is "Overload(def (self: a.MyClass, arg: builtins.int) -> builtins.str, def (self: a.MyClass, arg: builtins.str) -> builtins.int)"
5970+
tmp/b.py:4: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
5971+
tmp/b.py:5: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
5972+
tmp/b.py:8: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
5973+
tmp/b.py:9: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
5974+
tmp/b.py:10: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
5975+
59385976
[case testGenericNamedTupleSerialization]
59395977
import b
59405978
[file a.py]

test-data/unit/deps.test

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,12 +1387,13 @@ class B(A):
13871387
<m.A.(abstract)> -> <m.B.__init__>, m
13881388
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
13891389
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
1390-
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
1390+
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m, m.B.__mypy-replace
13911391
<m.A.__new__> -> <m.B.__new__>
13921392
<m.A.x> -> <m.B.x>
13931393
<m.A.y> -> <m.B.y>
13941394
<m.A> -> m, m.A, m.B
13951395
<m.A[wildcard]> -> m
1396+
<m.B.__mypy-replace> -> m
13961397
<m.B.y> -> m
13971398
<m.B> -> m.B
13981399
<m.Z> -> m
@@ -1419,12 +1420,13 @@ class B(A):
14191420
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
14201421
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
14211422
<m.A.__match_args__> -> <m.B.__match_args__>
1422-
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
1423+
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m, m.B.__mypy-replace
14231424
<m.A.__new__> -> <m.B.__new__>
14241425
<m.A.x> -> <m.B.x>
14251426
<m.A.y> -> <m.B.y>
14261427
<m.A> -> m, m.A, m.B
14271428
<m.A[wildcard]> -> m
1429+
<m.B.__mypy-replace> -> m
14281430
<m.B.y> -> m
14291431
<m.B> -> m.B
14301432
<m.Z> -> m
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
5+
from mypy.nodes import ARG_POS, Argument, Var
6+
from mypy.plugin import ClassDefContext, Plugin
7+
from mypy.plugins.common import MethodSpec, add_overloaded_method_to_class
8+
9+
10+
class OverloadedMethodPlugin(Plugin):
11+
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
12+
if "AddOverloadedMethod" in fullname:
13+
return add_overloaded_method_hook
14+
return None
15+
16+
17+
def add_overloaded_method_hook(ctx: ClassDefContext) -> None:
18+
add_overloaded_method_to_class(ctx.api, ctx.cls, "method", _generate_method_specs(ctx))
19+
add_overloaded_method_to_class(
20+
ctx.api, ctx.cls, "clsmethod", _generate_method_specs(ctx), is_classmethod=True
21+
)
22+
add_overloaded_method_to_class(
23+
ctx.api, ctx.cls, "stmethod", _generate_method_specs(ctx), is_staticmethod=True
24+
)
25+
26+
27+
def _generate_method_specs(ctx: ClassDefContext) -> list[MethodSpec]:
28+
return [
29+
MethodSpec(
30+
args=[Argument(Var("arg"), ctx.api.named_type("builtins.int"), None, ARG_POS)],
31+
return_type=ctx.api.named_type("builtins.str"),
32+
),
33+
MethodSpec(
34+
args=[Argument(Var("arg"), ctx.api.named_type("builtins.str"), None, ARG_POS)],
35+
return_type=ctx.api.named_type("builtins.int"),
36+
),
37+
]
38+
39+
40+
def plugin(version: str) -> type[OverloadedMethodPlugin]:
41+
return OverloadedMethodPlugin

0 commit comments

Comments
 (0)