88from typing import Union
99from typing_extensions import Final
1010
11- from mypy .nodes import Expression , FloatExpr , IntExpr , NameExpr , OpExpr , StrExpr , UnaryExpr , Var
11+ from mypy .nodes import (
12+ ComplexExpr ,
13+ Expression ,
14+ FloatExpr ,
15+ IntExpr ,
16+ NameExpr ,
17+ OpExpr ,
18+ StrExpr ,
19+ UnaryExpr ,
20+ Var ,
21+ )
1222
1323# All possible result types of constant folding
14- ConstantValue = Union [int , bool , float , str ]
15- CONST_TYPES : Final = (int , bool , float , str )
24+ ConstantValue = Union [int , bool , float , complex , str ]
25+ CONST_TYPES : Final = (int , bool , float , complex , str )
1626
1727
1828def constant_fold_expr (expr : Expression , cur_mod_id : str ) -> ConstantValue | None :
@@ -39,6 +49,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
3949 return expr .value
4050 if isinstance (expr , FloatExpr ):
4151 return expr .value
52+ if isinstance (expr , ComplexExpr ):
53+ return expr .value
4254 elif isinstance (expr , NameExpr ):
4355 if expr .name == "True" :
4456 return True
@@ -56,26 +68,60 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
5668 elif isinstance (expr , OpExpr ):
5769 left = constant_fold_expr (expr .left , cur_mod_id )
5870 right = constant_fold_expr (expr .right , cur_mod_id )
59- if isinstance (left , int ) and isinstance (right , int ):
60- return constant_fold_binary_int_op (expr .op , left , right )
61- elif isinstance (left , str ) and isinstance (right , str ):
62- return constant_fold_binary_str_op (expr .op , left , right )
71+ if left is not None and right is not None :
72+ return constant_fold_binary_op (expr .op , left , right )
6373 elif isinstance (expr , UnaryExpr ):
6474 value = constant_fold_expr (expr .expr , cur_mod_id )
65- if isinstance (value , int ):
66- return constant_fold_unary_int_op (expr .op , value )
67- if isinstance (value , float ):
68- return constant_fold_unary_float_op (expr .op , value )
75+ if value is not None :
76+ return constant_fold_unary_op (expr .op , value )
6977 return None
7078
7179
72- def constant_fold_binary_int_op (op : str , left : int , right : int ) -> int | None :
80+ def constant_fold_binary_op (
81+ op : str , left : ConstantValue , right : ConstantValue
82+ ) -> ConstantValue | None :
83+ if isinstance (left , int ) and isinstance (right , int ):
84+ return constant_fold_binary_int_op (op , left , right )
85+
86+ # Float and mixed int/float arithmetic.
87+ if isinstance (left , float ) and isinstance (right , float ):
88+ return constant_fold_binary_float_op (op , left , right )
89+ elif isinstance (left , float ) and isinstance (right , int ):
90+ return constant_fold_binary_float_op (op , left , right )
91+ elif isinstance (left , int ) and isinstance (right , float ):
92+ return constant_fold_binary_float_op (op , left , right )
93+
94+ # String concatenation and multiplication.
95+ if op == "+" and isinstance (left , str ) and isinstance (right , str ):
96+ return left + right
97+ elif op == "*" and isinstance (left , str ) and isinstance (right , int ):
98+ return left * right
99+ elif op == "*" and isinstance (left , int ) and isinstance (right , str ):
100+ return left * right
101+
102+ # Complex construction.
103+ if op == "+" and isinstance (left , (int , float )) and isinstance (right , complex ):
104+ return left + right
105+ elif op == "+" and isinstance (left , complex ) and isinstance (right , (int , float )):
106+ return left + right
107+ elif op == "-" and isinstance (left , (int , float )) and isinstance (right , complex ):
108+ return left - right
109+ elif op == "-" and isinstance (left , complex ) and isinstance (right , (int , float )):
110+ return left - right
111+
112+ return None
113+
114+
115+ def constant_fold_binary_int_op (op : str , left : int , right : int ) -> int | float | None :
73116 if op == "+" :
74117 return left + right
75118 if op == "-" :
76119 return left - right
77120 elif op == "*" :
78121 return left * right
122+ elif op == "/" :
123+ if right != 0 :
124+ return left / right
79125 elif op == "//" :
80126 if right != 0 :
81127 return left // right
@@ -102,25 +148,41 @@ def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
102148 return None
103149
104150
105- def constant_fold_unary_int_op (op : str , value : int ) -> int | None :
106- if op == "-" :
107- return - value
108- elif op == "~" :
109- return ~ value
110- elif op == "+" :
111- return value
151+ def constant_fold_binary_float_op (op : str , left : int | float , right : int | float ) -> float | None :
152+ assert not (isinstance (left , int ) and isinstance (right , int )), (op , left , right )
153+ if op == "+" :
154+ return left + right
155+ elif op == "-" :
156+ return left - right
157+ elif op == "*" :
158+ return left * right
159+ elif op == "/" :
160+ if right != 0 :
161+ return left / right
162+ elif op == "//" :
163+ if right != 0 :
164+ return left // right
165+ elif op == "%" :
166+ if right != 0 :
167+ return left % right
168+ elif op == "**" :
169+ if (left < 0 and isinstance (right , int )) or left > 0 :
170+ try :
171+ ret = left ** right
172+ except OverflowError :
173+ return None
174+ else :
175+ assert isinstance (ret , float ), ret
176+ return ret
177+
112178 return None
113179
114180
115- def constant_fold_unary_float_op (op : str , value : float ) -> float | None :
116- if op == "-" :
181+ def constant_fold_unary_op (op : str , value : ConstantValue ) -> int | float | None :
182+ if op == "-" and isinstance ( value , ( int , float )) :
117183 return - value
118- elif op == "+" :
184+ elif op == "~" and isinstance (value , int ):
185+ return ~ value
186+ elif op == "+" and isinstance (value , (int , float )):
119187 return value
120188 return None
121-
122-
123- def constant_fold_binary_str_op (op : str , left : str , right : str ) -> str | None :
124- if op == "+" :
125- return left + right
126- return None
0 commit comments