Skip to content

Commit 1ca5230

Browse files
committed
add max_mismatch_pct argument to run_example
stack-info: PR: #1909, branch: shunting314/stack/28
1 parent 061553f commit 1ca5230

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

helion/_testing.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ def run_example(
955955
baseline_name: str = "torch",
956956
rtol: float = 1e-2,
957957
atol: float = 1e-1,
958+
max_mismatch_pct: float | None = None,
958959
bwd: bool = False,
959960
trace_path: str | None = None,
960961
process_group_name: str | None = None,
@@ -969,6 +970,8 @@ def run_example(
969970
baseline_name: Name for single baseline in output (default: "torch")
970971
rtol: Relative tolerance for correctness check (default: 1e-2)
971972
atol: Absolute tolerance for correctness check (default: 1e-1)
973+
max_mismatch_pct: If set, use assert_close_with_mismatch_tolerance with this mismatch
974+
fraction tolerance instead of strict assert_close (default: None)
972975
bwd: Whether to also test backward pass (default: False)
973976
trace_path: if not None, do profiling and save trace to this path
974977
"""
@@ -998,12 +1001,21 @@ def run_example(
9981001
# Clone args to avoid buffer donation issues (e.g., Pallas/TPU)
9991002
cloned_args = _clone_args(args, process_group_name=process_group_name)
10001003
result = func(*cloned_args).clone()
1001-
torch.testing.assert_close(
1002-
result.to(torch.float32),
1003-
expected.to(torch.float32),
1004-
rtol=rtol,
1005-
atol=atol,
1006-
)
1004+
if max_mismatch_pct is not None:
1005+
assert_close_with_mismatch_tolerance(
1006+
result.to(torch.float32),
1007+
expected.to(torch.float32),
1008+
atol=atol,
1009+
rtol=rtol,
1010+
max_mismatch_pct=max_mismatch_pct,
1011+
)
1012+
else:
1013+
torch.testing.assert_close(
1014+
result.to(torch.float32),
1015+
expected.to(torch.float32),
1016+
rtol=rtol,
1017+
atol=atol,
1018+
)
10071019

10081020
# Test backward pass
10091021
if bwd:

0 commit comments

Comments
 (0)