Skip to content

Commit a18a0db

Browse files
authored
[mypyc] Optimize away some bool/bit registers (#17022)
If a register is always used in a branch immediately after assignment, and it isn't used for anything else, we can replace the assignment with a branch op. This avoids some assignment ops and gotos. This is not a very interesting optimization in general, but it will help a lot with tagged integer operations once I refactor them to be generated in the lowering pass (in follow-up PRs).
1 parent a00fcba commit a18a0db

File tree

5 files changed

+445
-14
lines changed

5 files changed

+445
-14
lines changed

mypyc/codegen/emitmodule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from mypyc.options import CompilerOptions
5959
from mypyc.transform.copy_propagation import do_copy_propagation
6060
from mypyc.transform.exceptions import insert_exception_handling
61+
from mypyc.transform.flag_elimination import do_flag_elimination
6162
from mypyc.transform.refcount import insert_ref_count_opcodes
6263
from mypyc.transform.uninit import insert_uninit_checks
6364

@@ -234,8 +235,9 @@ def compile_scc_to_ir(
234235
insert_exception_handling(fn)
235236
# Insert refcount handling.
236237
insert_ref_count_opcodes(fn)
237-
# Perform copy propagation optimization.
238+
# Perform optimizations.
238239
do_copy_propagation(fn, compiler_options)
240+
do_flag_elimination(fn, compiler_options)
239241

240242
return modules
241243

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
-- Test cases for "flag elimination" optimization. Used to optimize away
2+
-- registers that are always used immediately after assignment as branch conditions.
3+
4+
[case testFlagEliminationSimple]
5+
def c() -> bool:
6+
return True
7+
def d() -> bool:
8+
return True
9+
10+
def f(x: bool) -> int:
11+
if x:
12+
b = c()
13+
else:
14+
b = d()
15+
if b:
16+
return 1
17+
else:
18+
return 2
19+
[out]
20+
def c():
21+
L0:
22+
return 1
23+
def d():
24+
L0:
25+
return 1
26+
def f(x):
27+
x, r0, r1 :: bool
28+
L0:
29+
if x goto L1 else goto L2 :: bool
30+
L1:
31+
r0 = c()
32+
if r0 goto L4 else goto L5 :: bool
33+
L2:
34+
r1 = d()
35+
if r1 goto L4 else goto L5 :: bool
36+
L3:
37+
unreachable
38+
L4:
39+
return 2
40+
L5:
41+
return 4
42+
43+
[case testFlagEliminationOneAssignment]
44+
def c() -> bool:
45+
return True
46+
47+
def f(x: bool) -> int:
48+
# Not applied here
49+
b = c()
50+
if b:
51+
return 1
52+
else:
53+
return 2
54+
[out]
55+
def c():
56+
L0:
57+
return 1
58+
def f(x):
59+
x, r0, b :: bool
60+
L0:
61+
r0 = c()
62+
b = r0
63+
if b goto L1 else goto L2 :: bool
64+
L1:
65+
return 2
66+
L2:
67+
return 4
68+
69+
[case testFlagEliminationThreeCases]
70+
def c(x: int) -> bool:
71+
return True
72+
73+
def f(x: bool, y: bool) -> int:
74+
if x:
75+
b = c(1)
76+
elif y:
77+
b = c(2)
78+
else:
79+
b = c(3)
80+
if b:
81+
return 1
82+
else:
83+
return 2
84+
[out]
85+
def c(x):
86+
x :: int
87+
L0:
88+
return 1
89+
def f(x, y):
90+
x, y, r0, r1, r2 :: bool
91+
L0:
92+
if x goto L1 else goto L2 :: bool
93+
L1:
94+
r0 = c(2)
95+
if r0 goto L6 else goto L7 :: bool
96+
L2:
97+
if y goto L3 else goto L4 :: bool
98+
L3:
99+
r1 = c(4)
100+
if r1 goto L6 else goto L7 :: bool
101+
L4:
102+
r2 = c(6)
103+
if r2 goto L6 else goto L7 :: bool
104+
L5:
105+
unreachable
106+
L6:
107+
return 2
108+
L7:
109+
return 4
110+
111+
[case testFlagEliminationAssignmentNotLastOp]
112+
def f(x: bool) -> int:
113+
y = 0
114+
if x:
115+
b = True
116+
y = 1
117+
else:
118+
b = False
119+
if b:
120+
return 1
121+
else:
122+
return 2
123+
[out]
124+
def f(x):
125+
x :: bool
126+
y :: int
127+
b :: bool
128+
L0:
129+
y = 0
130+
if x goto L1 else goto L2 :: bool
131+
L1:
132+
b = 1
133+
y = 2
134+
goto L3
135+
L2:
136+
b = 0
137+
L3:
138+
if b goto L4 else goto L5 :: bool
139+
L4:
140+
return 2
141+
L5:
142+
return 4
143+
144+
[case testFlagEliminationAssignmentNoDirectGoto]
145+
def f(x: bool) -> int:
146+
if x:
147+
b = True
148+
else:
149+
b = False
150+
if x:
151+
if b:
152+
return 1
153+
else:
154+
return 2
155+
return 4
156+
[out]
157+
def f(x):
158+
x, b :: bool
159+
L0:
160+
if x goto L1 else goto L2 :: bool
161+
L1:
162+
b = 1
163+
goto L3
164+
L2:
165+
b = 0
166+
L3:
167+
if x goto L4 else goto L7 :: bool
168+
L4:
169+
if b goto L5 else goto L6 :: bool
170+
L5:
171+
return 2
172+
L6:
173+
return 4
174+
L7:
175+
return 8
176+
177+
[case testFlagEliminationBranchNotNextOpAfterGoto]
178+
def f(x: bool) -> int:
179+
if x:
180+
b = True
181+
else:
182+
b = False
183+
y = 1 # Prevents the optimization
184+
if b:
185+
return 1
186+
else:
187+
return 2
188+
[out]
189+
def f(x):
190+
x, b :: bool
191+
y :: int
192+
L0:
193+
if x goto L1 else goto L2 :: bool
194+
L1:
195+
b = 1
196+
goto L3
197+
L2:
198+
b = 0
199+
L3:
200+
y = 2
201+
if b goto L4 else goto L5 :: bool
202+
L4:
203+
return 2
204+
L5:
205+
return 4
206+
207+
[case testFlagEliminationFlagReadTwice]
208+
def f(x: bool) -> bool:
209+
if x:
210+
b = True
211+
else:
212+
b = False
213+
if b:
214+
return b # Prevents the optimization
215+
else:
216+
return False
217+
[out]
218+
def f(x):
219+
x, b :: bool
220+
L0:
221+
if x goto L1 else goto L2 :: bool
222+
L1:
223+
b = 1
224+
goto L3
225+
L2:
226+
b = 0
227+
L3:
228+
if b goto L4 else goto L5 :: bool
229+
L4:
230+
return b
231+
L5:
232+
return 0
233+
234+
[case testFlagEliminationArgumentNotEligible]
235+
def f(x: bool, b: bool) -> bool:
236+
if x:
237+
b = True
238+
else:
239+
b = False
240+
if b:
241+
return True
242+
else:
243+
return False
244+
[out]
245+
def f(x, b):
246+
x, b :: bool
247+
L0:
248+
if x goto L1 else goto L2 :: bool
249+
L1:
250+
b = 1
251+
goto L3
252+
L2:
253+
b = 0
254+
L3:
255+
if b goto L4 else goto L5 :: bool
256+
L4:
257+
return 1
258+
L5:
259+
return 0
260+
261+
[case testFlagEliminationFlagNotAlwaysDefined]
262+
def f(x: bool, y: bool) -> bool:
263+
if x:
264+
b = True
265+
elif y:
266+
b = False
267+
else:
268+
bb = False # b not assigned here -> can't optimize
269+
if b:
270+
return True
271+
else:
272+
return False
273+
[out]
274+
def f(x, y):
275+
x, y, r0, b, bb, r1 :: bool
276+
L0:
277+
r0 = <error> :: bool
278+
b = r0
279+
if x goto L1 else goto L2 :: bool
280+
L1:
281+
b = 1
282+
goto L5
283+
L2:
284+
if y goto L3 else goto L4 :: bool
285+
L3:
286+
b = 0
287+
goto L5
288+
L4:
289+
bb = 0
290+
L5:
291+
if is_error(b) goto L6 else goto L7
292+
L6:
293+
r1 = raise UnboundLocalError('local variable "b" referenced before assignment')
294+
unreachable
295+
L7:
296+
if b goto L8 else goto L9 :: bool
297+
L8:
298+
return 1
299+
L9:
300+
return 0
Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Runner for copy propagation optimization tests."""
1+
"""Runner for IR optimization tests."""
22

33
from __future__ import annotations
44

@@ -8,6 +8,7 @@
88
from mypy.test.config import test_temp_dir
99
from mypy.test.data import DataDrivenTestCase
1010
from mypyc.common import TOP_LEVEL_NAME
11+
from mypyc.ir.func_ir import FuncIR
1112
from mypyc.ir.pprint import format_func
1213
from mypyc.options import CompilerOptions
1314
from mypyc.test.testutil import (
@@ -19,13 +20,16 @@
1920
use_custom_builtins,
2021
)
2122
from mypyc.transform.copy_propagation import do_copy_propagation
23+
from mypyc.transform.flag_elimination import do_flag_elimination
2224
from mypyc.transform.uninit import insert_uninit_checks
2325

24-
files = ["opt-copy-propagation.test"]
2526

27+
class OptimizationSuite(MypycDataSuite):
28+
"""Base class for IR optimization test suites.
29+
30+
To use this, add a base class and define "files" and "do_optimizations".
31+
"""
2632

27-
class TestCopyPropagation(MypycDataSuite):
28-
files = files
2933
base_path = test_temp_dir
3034

3135
def run_case(self, testcase: DataDrivenTestCase) -> None:
@@ -41,7 +45,24 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
4145
if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"):
4246
continue
4347
insert_uninit_checks(fn)
44-
do_copy_propagation(fn, CompilerOptions())
48+
self.do_optimizations(fn)
4549
actual.extend(format_func(fn))
4650

4751
assert_test_output(testcase, actual, "Invalid source code output", expected_output)
52+
53+
def do_optimizations(self, fn: FuncIR) -> None:
54+
raise NotImplementedError
55+
56+
57+
class TestCopyPropagation(OptimizationSuite):
58+
files = ["opt-copy-propagation.test"]
59+
60+
def do_optimizations(self, fn: FuncIR) -> None:
61+
do_copy_propagation(fn, CompilerOptions())
62+
63+
64+
class TestFlagElimination(OptimizationSuite):
65+
files = ["opt-flag-elimination.test"]
66+
67+
def do_optimizations(self, fn: FuncIR) -> None:
68+
do_flag_elimination(fn, CompilerOptions())

0 commit comments

Comments
 (0)