66
77from onnxscript .rewriter import _fusion_utils , pattern
88
9- _sqrt_two_over_pi = math .sqrt (2.0 / math .pi )
9+ _SQRT_TWO_OVER_PI = math .sqrt (2.0 / math .pi )
10+ _SQRT_TWO = math .sqrt (2.0 )
1011
1112
1213class GeluTanhFusion (pattern .RewriteRuleClassBase ):
@@ -16,7 +17,7 @@ def pattern(self, op, x):
1617 t2 = op .Mul (0.044715 , t1 )
1718 t3 = op .Add (x , t2 )
1819
19- t4 = op .Mul (_sqrt_two_over_pi , t3 )
20+ t4 = op .Mul (_SQRT_TWO_OVER_PI , t3 )
2021 t5 = op .Tanh (t4 )
2122 t6 = op .Add (t5 , 1 )
2223 t7 = op .Mul (0.5 , t6 )
@@ -27,9 +28,23 @@ def rewrite(self, op, x):
2728 return op .FastGelu (x , _domain = "com.microsoft" )
2829
2930
30- _rule = GeluTanhFusion .rule ()
31+ class GeluErfFusion (pattern .RewriteRuleClassBase ):
32+ def pattern (self , op , x ):
33+ # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
34+ t1 = op .Div (x , _SQRT_TWO )
35+ t2 = op .Erf (t1 )
36+ t3 = op .Add (t2 , 1.0 )
37+ t4 = op .Mul (x , t3 )
38+ result = op .Mul (t4 , 0.5 )
39+ return result
40+
41+ def rewrite (self , op , x ):
42+ return op .Gelu (x , _domain = "com.microsoft" )
43+
3144
32- gelu_rules = pattern .RewriteRuleSet ([_rule ])
45+ _tanh_rule = GeluTanhFusion .rule ()
46+ _erf_rule = GeluErfFusion .rule ()
3347
48+ gelu_rules = pattern .RewriteRuleSet ([_tanh_rule , _erf_rule ])
3449
3550fuse_gelu = _fusion_utils .apply_fusion_rules (gelu_rules )
0 commit comments