Skip to content

Commit 9b46217

Browse files
committed
add max_mismatch_pct argument to run_example
stack-info: PR: #1909, branch: shunting314/stack/28
1 parent 8830fec commit 9b46217

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

helion/_testing.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,7 @@ def run_example(
945945
baseline_name: str = "torch",
946946
rtol: float = 1e-2,
947947
atol: float = 1e-1,
948+
max_mismatch_pct: float | None = None,
948949
bwd: bool = False,
949950
trace_path: str | None = None,
950951
process_group_name: str | None = None,
@@ -959,6 +960,8 @@ def run_example(
959960
baseline_name: Name for single baseline in output (default: "torch")
960961
rtol: Relative tolerance for correctness check (default: 1e-2)
961962
atol: Absolute tolerance for correctness check (default: 1e-1)
963+
max_mismatch_pct: If set, use assert_close_with_mismatch_tolerance with this mismatch
964+
fraction tolerance instead of strict assert_close (default: None)
962965
bwd: Whether to also test backward pass (default: False)
963966
trace_path: if not None, do profiling and save trace to this path
964967
"""
@@ -988,12 +991,21 @@ def run_example(
988991
# Clone args to avoid buffer donation issues (e.g., Pallas/TPU)
989992
cloned_args = _clone_args(args, process_group_name=process_group_name)
990993
result = func(*cloned_args).clone()
991-
torch.testing.assert_close(
992-
result.to(torch.float32),
993-
expected.to(torch.float32),
994-
rtol=rtol,
995-
atol=atol,
996-
)
994+
if max_mismatch_pct is not None:
995+
assert_close_with_mismatch_tolerance(
996+
result.to(torch.float32),
997+
expected.to(torch.float32),
998+
atol=atol,
999+
rtol=rtol,
1000+
max_mismatch_pct=max_mismatch_pct,
1001+
)
1002+
else:
1003+
torch.testing.assert_close(
1004+
result.to(torch.float32),
1005+
expected.to(torch.float32),
1006+
rtol=rtol,
1007+
atol=atol,
1008+
)
9971009

9981010
# Test backward pass
9991011
if bwd:

0 commit comments

Comments
 (0)