@@ -186,24 +186,14 @@ def compute_pvalue(self):
186186 if self .approx_pvalue :
187187 lam = ctrl + pc
188188 k = treat + pc
189- tiny = mx .array (np .finfo (np .float32 ).tiny , dtype = FLOAT32_MX )
190- log2 = mx .log (mx .array (2.0 , dtype = FLOAT32_MX ))
191- log10_const = mx .log (mx .array (10.0 , dtype = FLOAT32_MX ))
192189 sqrt2 = mx .sqrt (mx .array (2.0 , dtype = FLOAT32_MX ))
193190 z = (k + 0.5 - lam ) / mx .sqrt (lam + 1e-9 )
194- erfc_arg = z / sqrt2
195- erf_val = mx .erf (erfc_arg )
196- # Clamp away from 1 to keep log1p well-defined for extreme tails.
197- erf_clamped = mx .minimum (
198- erf_val ,
199- mx .array (1.0 - np .finfo (np .float32 ).eps , dtype = FLOAT32_MX ),
200- )
201- # Compute -log10(tail) in log space to preserve precision for tiny tails.
202- log_tail = mx .log1p (- erf_clamped )
203- log_tail = mx .where (mx .isfinite (log_tail ), log_tail , mx .log (tiny ))
204- log_tail = log_tail - log2
205- score_backend = - (log_tail / log10_const )
191+ tail = 0.5 * (1.0 - mx .erf (z / sqrt2 ))
192+ tail = mx .maximum (tail , mx .array (1e-30 , dtype = FLOAT32_MX ))
193+ score_backend = - _log10 (tail )
206194 score_np = _to_numpy (score_backend )
195+ lam_np = _to_numpy (lam )
196+ k_np = _to_numpy (k )
207197 score = mx .array (score_np , dtype = FLOAT32_MX )
208198 else :
209199 # exact Poisson on CPU
0 commit comments