Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 1 addition & 79 deletions helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .runtime.settings import _get_backend
from .runtime.settings import is_pallas_interpret
from helion.autotuner.base_search import _clone_args
from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance

if _get_backend() == "pallas":
from .autotuner.benchmarking import compute_repeat_generic as compute_repeat
Expand Down Expand Up @@ -1541,82 +1542,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]:
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr


def assert_close_with_mismatch_tolerance(
actual: object,
expected: object,
*,
atol: float = 1e-4,
rtol: float = 1e-4,
max_mismatch_pct: float = 0.01,
max_abs_diff: float | None = None,
max_rel_diff: float | None = None,
) -> None:
"""Check that actual and expected are close, tolerating a small fraction of mismatches.

First tries ``torch.testing.assert_close`` with the given *atol*/*rtol*.
If that fails **and** both arguments are tensors, falls back to a relaxed
check using the same mismatch definition as ``torch.testing.assert_close``
(``|actual - expected| > atol + rtol * |expected|``):

- *max_mismatch_pct*: maximum allowed fraction of mismatched elements
(default 1%). Always checked.
- *max_abs_diff*: if not None, the greatest absolute difference across
all elements must not exceed this value.
- *max_rel_diff*: if not None, the greatest relative difference
(``|actual - expected| / |expected|``) must not exceed this value.

This is useful for kernels where most elements match but a tiny
fraction have large relative differences. Pass this function directly as
``autotune_baseline_accuracy_check_fn`` for the default thresholds, or use
``functools.partial`` to customize them::

from functools import partial
from helion._testing import assert_close_with_mismatch_tolerance

@helion.kernel(
autotune_baseline_accuracy_check_fn=partial(
assert_close_with_mismatch_tolerance,
max_mismatch_pct=0.05,
max_abs_diff=10.0,
max_rel_diff=15.0,
),
)
def my_kernel(...): ...
"""
try:
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
return
except AssertionError:
if not (
isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)
):
raise

abs_diff = (actual - expected).abs()
total = actual.numel()

# Use the same mismatch definition as torch.testing.assert_close:
# an element is mismatched when |actual - expected| > atol + rtol * |expected|
mismatched = (abs_diff > atol + rtol * expected.abs()).sum().item()
mismatch_pct = mismatched / total if total > 0 else 0.0

if mismatch_pct > max_mismatch_pct:
raise AssertionError(
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
)

if max_abs_diff is not None:
worst_abs = abs_diff.max().item()
if worst_abs > max_abs_diff:
raise AssertionError(
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
)

if max_rel_diff is not None:
rel_diff = abs_diff / expected.abs().clamp(min=1e-6)
worst_rel = rel_diff.max().item()
if worst_rel > max_rel_diff:
raise AssertionError(
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
)
48 changes: 45 additions & 3 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,22 @@ class BenchmarkResult(NamedTuple):
}


def _assert_close(actual: object, expected: object, atol: float, rtol: float) -> None:
"""Like torch.testing.assert_close but handles fp8 and uses chunked comparison for large tensors."""
def _assert_close(
actual: object,
expected: object,
*,
atol: float,
rtol: float,
max_mismatch_pct: float | None = None,
max_abs_diff: float | None = None,
max_rel_diff: float | None = None,
) -> None:
"""Like torch.testing.assert_close but handles fp8, pytree structures, and strings.

For tensors, uses chunked comparison for large tensors. When
*max_mismatch_pct* is set, falls back to a relaxed mismatch-fraction check
instead of raising immediately on the first out-of-tolerance element.
"""

def convert(t: torch.Tensor) -> torch.Tensor:
return t.view(torch.uint8) if t.dtype in _FP8_DTYPES else t
Expand All @@ -241,7 +255,35 @@ def convert(t: torch.Tensor) -> torch.Tensor:

for a, e in zip(actual_flat, expected_flat, strict=True):
if isinstance(a, torch.Tensor):
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
if max_mismatch_pct is not None:
try:
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
continue
except AssertionError:
pass
abs_diff = (a - e).abs()
total = a.numel()
mismatched = (abs_diff > atol + rtol * e.abs()).sum().item()
mismatch_pct = mismatched / total if total > 0 else 0.0
if mismatch_pct > max_mismatch_pct:
raise AssertionError(
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
)
if max_abs_diff is not None:
worst_abs = abs_diff.max().item()
if worst_abs > max_abs_diff:
raise AssertionError(
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
)
if max_rel_diff is not None:
rel_diff = abs_diff / e.abs().clamp(min=1e-6)
worst_rel = rel_diff.max().item()
if worst_rel > max_rel_diff:
raise AssertionError(
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
)
else:
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
elif isinstance(a, str):
if not isinstance(e, str):
raise AssertionError(f"Type mismatch {a} vs {e}")
Expand Down
Loading