@@ -479,33 +479,31 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
479479 return result
480480
481481
482- @torch_op ("aten::gelu" , private = True )
483482def _aten_gelu_approximate_none (self : TReal ) -> TReal :
484483 """gelu(Tensor self, *, str approximate='none') -> Tensor"""
485484
486485 # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)]
487- inner = op .Div (self , 1.4142135623730951 )
486+ inner = op .Div (self , ir . tensor ( 1.4142135623730951 , dtype = self . dtype ) )
488487 erf = op .Erf (inner )
489- inner = op .Add (erf , 1 )
490- inner = op .Mul (0.5 , inner )
488+ inner = op .Add (erf , ir . tensor ( 1 , dtype = self . dtype ) )
489+ inner = op .Mul (ir . tensor ( 0.5 , dtype = self . dtype ) , inner )
491490 result = op .Mul (self , inner )
492491 return result
493492
494493
495- @torch_op ("aten::gelu" , private = True )
496494def _aten_gelu_approximate_tanh (self : TReal ) -> TReal :
497495 """gelu(Tensor self, *, str approximate='none') -> Tensor"""
498496
499497 # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]}
500- cubed = op .Pow (self , 3 )
501- inner = op .Mul (0.044715 , cubed )
498+ cubed = op .Pow (self , ir . tensor ( 3 , dtype = self . dtype ) )
499+ inner = op .Mul (ir . tensor ( 0.044715 , dtype = self . dtype ) , cubed )
502500 inner = op .Add (self , inner )
503- # Prefer explicit graph construction over precomputed constants for clarity.
504- two_over_pi = op . CastLike ( op . Div ( 2.0 , _MATH_PI ), self )
505- inner = op .Mul (op . Sqrt ( two_over_pi ) , inner )
501+ # math.sqrt(2.0/math.pi) = 0.7978845608028654
502+ sqrt_two_over_pi = ir . tensor ( 0.7978845608028654 , dtype = self . dtype )
503+ inner = op .Mul (sqrt_two_over_pi , inner )
506504 inner = op .Tanh (inner )
507- inner = op .Add (inner , 1 )
508- inner = op .Mul (0.5 , inner )
505+ inner = op .Add (inner , ir . tensor ( 1 , dtype = self . dtype ) )
506+ inner = op .Mul (ir . tensor ( 0.5 , dtype = self . dtype ) , inner )
509507 result = op .Mul (self , inner )
510508 return result
511509
0 commit comments