Skip to content

Commit 43883fa

Browse files
authored
[mypyc] Optimize classmethod calls via cls (#14789)
If the class has no subclasses, we can statically bind the call target: ``` class C: @classmethod def f(cls) -> int: return cls.g() # This can be statically bound, same as C.g() @classmethod def g(cls) -> int: return 1 ``` For this to be safe, also reject assignments to the "cls" argument in classmethods in compiled code. This makes the deltablue benchmark about 11% faster.
1 parent 9393c22 commit 43883fa

File tree

7 files changed

+217
-50
lines changed

7 files changed

+217
-50
lines changed

mypy/nodes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ def deserialize(cls, data: JsonDict) -> Decorator:
901901

902902
VAR_FLAGS: Final = [
903903
"is_self",
904+
"is_cls",
904905
"is_initialized_in_class",
905906
"is_staticmethod",
906907
"is_classmethod",
@@ -935,6 +936,7 @@ class Var(SymbolNode):
935936
"type",
936937
"final_value",
937938
"is_self",
939+
"is_cls",
938940
"is_ready",
939941
"is_inferred",
940942
"is_initialized_in_class",
@@ -967,6 +969,8 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
967969
self.type: mypy.types.Type | None = type # Declared or inferred type, or None
968970
# Is this the first argument to an ordinary method (usually "self")?
969971
self.is_self = False
972+
# Is this the first argument to a classmethod (typically "cls")?
973+
self.is_cls = False
970974
self.is_ready = True # If inferred, is the inferred type available?
971975
self.is_inferred = self.type is None
972976
# Is this initialized explicitly to a non-None value in class body?

mypy/semanal.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,8 +1369,11 @@ def analyze_function_body(self, defn: FuncItem) -> None:
13691369
# The first argument of a non-static, non-class method is like 'self'
13701370
# (though the name could be different), having the enclosing class's
13711371
# instance type.
1372-
if is_method and not defn.is_static and not defn.is_class and defn.arguments:
1373-
defn.arguments[0].variable.is_self = True
1372+
if is_method and not defn.is_static and defn.arguments:
1373+
if not defn.is_class:
1374+
defn.arguments[0].variable.is_self = True
1375+
else:
1376+
defn.arguments[0].variable.is_cls = True
13741377

13751378
defn.body.accept(self)
13761379
self.function_stack.pop()

mypyc/ir/class_ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def __init__(
169169
self.base_mro: list[ClassIR] = [self]
170170

171171
# Direct subclasses of this class (use subclasses() to also include non-direct ones)
172-
# None if separate compilation prevents this from working
172+
# None if separate compilation prevents this from working.
173+
#
174+
# Often it's better to use has_no_subclasses() or subclasses() instead.
173175
self.children: list[ClassIR] | None = []
174176

175177
# Instance attributes that are initialized in the class body.
@@ -301,6 +303,9 @@ def get_method(self, name: str, *, prefer_method: bool = False) -> FuncIR | None
301303
def has_method_decl(self, name: str) -> bool:
302304
return any(name in ir.method_decls for ir in self.mro)
303305

306+
def has_no_subclasses(self) -> bool:
307+
return self.children == [] and not self.allow_interpreted_subclasses
308+
304309
def subclasses(self) -> set[ClassIR] | None:
305310
"""Return all subclasses of this class, both direct and indirect.
306311

mypyc/irbuild/builder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,11 @@ def load_final_literal_value(self, val: int | str | bytes | float | bool, line:
567567
else:
568568
assert False, "Unsupported final literal value"
569569

570-
def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTarget:
570+
def get_assignment_target(
571+
self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False
572+
) -> AssignmentTarget:
573+
if line == -1:
574+
line = lvalue.line
571575
if isinstance(lvalue, NameExpr):
572576
# If we are visiting a decorator, then the SymbolNode we really want to be looking at
573577
# is the function that is decorated, not the entire Decorator node itself.
@@ -578,6 +582,8 @@ def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTar
578582
# New semantic analyzer doesn't create ad-hoc Vars for special forms.
579583
assert lvalue.is_special_form
580584
symbol = Var(lvalue.name)
585+
if not for_read and isinstance(symbol, Var) and symbol.is_cls:
586+
self.error("Cannot assign to the first argument of classmethod", line)
581587
if lvalue.kind == LDEF:
582588
if symbol not in self.symtables[-1]:
583589
# If the function is a generator function, then first define a new variable

mypyc/irbuild/expression.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from mypy.types import Instance, ProperType, TupleType, TypeType, get_proper_type
5050
from mypyc.common import MAX_SHORT_INT
51+
from mypyc.ir.class_ir import ClassIR
5152
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD
5253
from mypyc.ir.ops import (
5354
Assign,
@@ -174,7 +175,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
174175
)
175176
return obj
176177
else:
177-
return builder.read(builder.get_assignment_target(expr), expr.line)
178+
return builder.read(builder.get_assignment_target(expr, for_read=True), expr.line)
178179

179180
return builder.load_global(expr)
180181

@@ -336,30 +337,7 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
336337
# Call a method via the *class*
337338
assert isinstance(callee.expr.node, TypeInfo)
338339
ir = builder.mapper.type_to_ir[callee.expr.node]
339-
decl = ir.method_decl(callee.name)
340-
args = []
341-
arg_kinds, arg_names = expr.arg_kinds[:], expr.arg_names[:]
342-
# Add the class argument for class methods in extension classes
343-
if decl.kind == FUNC_CLASSMETHOD and ir.is_ext_class:
344-
args.append(builder.load_native_type_object(callee.expr.node.fullname))
345-
arg_kinds.insert(0, ARG_POS)
346-
arg_names.insert(0, None)
347-
args += [builder.accept(arg) for arg in expr.args]
348-
349-
if ir.is_ext_class:
350-
return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line)
351-
else:
352-
obj = builder.accept(callee.expr)
353-
return builder.gen_method_call(
354-
obj,
355-
callee.name,
356-
args,
357-
builder.node_type(expr),
358-
expr.line,
359-
expr.arg_kinds,
360-
expr.arg_names,
361-
)
362-
340+
return call_classmethod(builder, ir, expr, callee)
363341
elif builder.is_module_member_expr(callee):
364342
# Fall back to a PyCall for non-native module calls
365343
function = builder.accept(callee)
@@ -368,6 +346,17 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
368346
function, args, expr.line, arg_kinds=expr.arg_kinds, arg_names=expr.arg_names
369347
)
370348
else:
349+
if isinstance(callee.expr, RefExpr):
350+
node = callee.expr.node
351+
if isinstance(node, Var) and node.is_cls:
352+
typ = get_proper_type(node.type)
353+
if isinstance(typ, TypeType) and isinstance(typ.item, Instance):
354+
class_ir = builder.mapper.type_to_ir.get(typ.item.type)
355+
if class_ir and class_ir.is_ext_class and class_ir.has_no_subclasses():
356+
# Call a native classmethod via cls that can be statically bound,
357+
# since the class has no subclasses.
358+
return call_classmethod(builder, class_ir, expr, callee)
359+
371360
receiver_typ = builder.node_type(callee.expr)
372361

373362
# If there is a specializer for this method name/type, try calling it.
@@ -389,6 +378,32 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
389378
)
390379

391380

381+
def call_classmethod(builder: IRBuilder, ir: ClassIR, expr: CallExpr, callee: MemberExpr) -> Value:
382+
decl = ir.method_decl(callee.name)
383+
args = []
384+
arg_kinds, arg_names = expr.arg_kinds[:], expr.arg_names[:]
385+
# Add the class argument for class methods in extension classes
386+
if decl.kind == FUNC_CLASSMETHOD and ir.is_ext_class:
387+
args.append(builder.load_native_type_object(ir.fullname))
388+
arg_kinds.insert(0, ARG_POS)
389+
arg_names.insert(0, None)
390+
args += [builder.accept(arg) for arg in expr.args]
391+
392+
if ir.is_ext_class:
393+
return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line)
394+
else:
395+
obj = builder.accept(callee.expr)
396+
return builder.gen_method_call(
397+
obj,
398+
callee.name,
399+
args,
400+
builder.node_type(expr),
401+
expr.line,
402+
expr.arg_kinds,
403+
expr.arg_names,
404+
)
405+
406+
392407
def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: SuperExpr) -> Value:
393408
if callee.info is None or (len(callee.call.args) != 0 and len(callee.call.args) != 2):
394409
return translate_call(builder, expr, callee)

mypyc/test-data/irbuild-classes.test

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,75 @@ L0:
656656
r3 = CPyTagged_Add(r0, r2)
657657
return r3
658658

659+
[case testCallClassMethodViaCls]
660+
class C:
661+
@classmethod
662+
def f(cls, x: int) -> int:
663+
return cls.g(x)
664+
665+
@classmethod
666+
def g(cls, x: int) -> int:
667+
return x
668+
669+
class D:
670+
@classmethod
671+
def f(cls, x: int) -> int:
672+
# TODO: This could aso be optimized, since g is not ever overridden
673+
return cls.g(x)
674+
675+
@classmethod
676+
def g(cls, x: int) -> int:
677+
return x
678+
679+
class DD(D):
680+
pass
681+
[out]
682+
def C.f(cls, x):
683+
cls :: object
684+
x :: int
685+
r0 :: object
686+
r1 :: int
687+
L0:
688+
r0 = __main__.C :: type
689+
r1 = C.g(r0, x)
690+
return r1
691+
def C.g(cls, x):
692+
cls :: object
693+
x :: int
694+
L0:
695+
return x
696+
def D.f(cls, x):
697+
cls :: object
698+
x :: int
699+
r0 :: str
700+
r1, r2 :: object
701+
r3 :: int
702+
L0:
703+
r0 = 'g'
704+
r1 = box(int, x)
705+
r2 = CPyObject_CallMethodObjArgs(cls, r0, r1, 0)
706+
r3 = unbox(int, r2)
707+
return r3
708+
def D.g(cls, x):
709+
cls :: object
710+
x :: int
711+
L0:
712+
return x
713+
714+
[case testCannotAssignToClsArgument]
715+
from typing import Any, cast
716+
717+
class C:
718+
@classmethod
719+
def m(cls) -> None:
720+
cls = cast(Any, D) # E: Cannot assign to the first argument of classmethod
721+
cls, x = cast(Any, D), 1 # E: Cannot assign to the first argument of classmethod
722+
cls, x = cast(Any, [1, 2]) # E: Cannot assign to the first argument of classmethod
723+
cls.m()
724+
725+
class D:
726+
pass
727+
659728
[case testSuper1]
660729
class A:
661730
def __init__(self, x: int) -> None:

mypyc/test-data/run-classes.test

Lines changed: 86 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -662,42 +662,107 @@ Traceback (most recent call last):
662662
AttributeError: attribute 'x' of 'X' undefined
663663

664664
[case testClassMethods]
665-
MYPY = False
666-
if MYPY:
667-
from typing import ClassVar
665+
from typing import ClassVar, Any
666+
from typing_extensions import final
667+
from mypy_extensions import mypyc_attr
668+
669+
from interp import make_interpreted_subclass
670+
668671
class C:
669-
lurr: 'ClassVar[int]' = 9
672+
lurr: ClassVar[int] = 9
670673
@staticmethod
671-
def foo(x: int) -> int: return 10 + x
674+
def foo(x: int) -> int:
675+
return 10 + x
672676
@classmethod
673-
def bar(cls, x: int) -> int: return cls.lurr + x
677+
def bar(cls, x: int) -> int:
678+
return cls.lurr + x
674679
@staticmethod
675-
def baz(x: int, y: int = 10) -> int: return y - x
680+
def baz(x: int, y: int = 10) -> int:
681+
return y - x
676682
@classmethod
677-
def quux(cls, x: int, y: int = 10) -> int: return y - x
683+
def quux(cls, x: int, y: int = 10) -> int:
684+
return y - x
685+
@classmethod
686+
def call_other(cls, x: int) -> int:
687+
return cls.quux(x, 3)
678688

679689
class D(C):
680690
def f(self) -> int:
681691
return super().foo(1) + super().bar(2) + super().baz(10) + super().quux(10)
682692

683-
def test1() -> int:
693+
def ctest1() -> int:
684694
return C.foo(1) + C.bar(2) + C.baz(10) + C.quux(10) + C.quux(y=10, x=9)
685-
def test2() -> int:
695+
696+
def ctest2() -> int:
686697
c = C()
687698
return c.foo(1) + c.bar(2) + c.baz(10)
688-
[file driver.py]
689-
from native import *
690-
assert C.foo(10) == 20
691-
assert C.bar(10) == 19
692-
c = C()
693-
assert c.foo(10) == 20
694-
assert c.bar(10) == 19
695699

696-
assert test1() == 23
697-
assert test2() == 22
700+
CAny: Any = C
701+
702+
def test_classmethod_using_any() -> None:
703+
assert CAny.foo(10) == 20
704+
assert CAny.bar(10) == 19
705+
706+
def test_classmethod_on_instance() -> None:
707+
c = C()
708+
assert c.foo(10) == 20
709+
assert c.bar(10) == 19
710+
assert c.call_other(1) == 2
711+
712+
def test_classmethod_misc() -> None:
713+
assert ctest1() == 23
714+
assert ctest2() == 22
715+
assert C.call_other(2) == 1
716+
717+
def test_classmethod_using_super() -> None:
718+
d = D()
719+
assert d.f() == 22
698720

699-
d = D()
700-
assert d.f() == 22
721+
@final
722+
class F1:
723+
@classmethod
724+
def f(cls, x: int) -> int:
725+
return cls.g(x)
726+
727+
@classmethod
728+
def g(cls, x: int) -> int:
729+
return x + 1
730+
731+
class F2: # Implicitly final (no subclasses)
732+
@classmethod
733+
def f(cls, x: int) -> int:
734+
return cls.g(x)
735+
736+
@classmethod
737+
def g(cls, x: int) -> int:
738+
return x + 1
739+
740+
def test_classmethod_of_final_class() -> None:
741+
assert F1.f(5) == 6
742+
assert F2.f(7) == 8
743+
744+
@mypyc_attr(allow_interpreted_subclasses=True)
745+
class CI:
746+
@classmethod
747+
def f(cls, x: int) -> int:
748+
return cls.g(x)
749+
750+
@classmethod
751+
def g(cls, x: int) -> int:
752+
return x + 1
753+
754+
def test_classmethod_with_allow_interpreted() -> None:
755+
assert CI.f(4) == 5
756+
sub = make_interpreted_subclass(CI)
757+
assert sub.f(4) == 7
758+
759+
[file interp.py]
760+
def make_interpreted_subclass(base):
761+
class Sub(base):
762+
@classmethod
763+
def g(cls, x: int) -> int:
764+
return x + 3
765+
return Sub
701766

702767
[case testSuper]
703768
from mypy_extensions import trait

0 commit comments

Comments
 (0)