Skip to content

Commit c7c3e80

Browse files
committed
WIP: consolidate accuracy check APIs
1 parent 9ff9e88 commit c7c3e80

File tree

3 files changed

+47
-83
lines changed

3 files changed

+47
-83
lines changed

examples/distributed/fp8_matmul_reduce_scatter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
import torch.distributed._symmetric_memory as symm_mem
2222

2323
import helion
24-
from helion._testing import assert_close_with_mismatch_tolerance
24+
from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
2525
from helion._testing import DEVICE
2626
from helion._testing import run_example
2727
import helion.language as hl

helion/_testing.py

Lines changed: 1 addition & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
from ._utils import counters
3535
from .runtime.settings import _get_backend
3636
from helion.autotuner.base_search import _clone_args
37+
from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
3738

3839
if _get_backend() == "pallas":
3940
from .autotuner.benchmarking import compute_repeat_generic as compute_repeat
@@ -1531,82 +1532,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]:
15311532
finally:
15321533
sys.stdout, sys.stderr = old_stdout, old_stderr
15331534

1534-
1535-
def assert_close_with_mismatch_tolerance(
    actual: object,
    expected: object,
    *,
    atol: float = 1e-4,
    rtol: float = 1e-4,
    max_mismatch_pct: float = 0.01,
    max_abs_diff: float | None = None,
    max_rel_diff: float | None = None,
) -> None:
    """Check that actual and expected are close, tolerating a small fraction of mismatches.

    First tries ``torch.testing.assert_close`` with the given *atol*/*rtol*.
    If that fails **and** both arguments are tensors of the same shape, falls
    back to a relaxed check using the same mismatch definition as
    ``torch.testing.assert_close``
    (``|actual - expected| > atol + rtol * |expected|``):

    - *max_mismatch_pct*: maximum allowed fraction of mismatched elements
      (default 1%). Always checked.
    - *max_abs_diff*: if not None, the greatest absolute difference across
      all elements must not exceed this value.
    - *max_rel_diff*: if not None, the greatest relative difference
      (``|actual - expected| / |expected|``) must not exceed this value.

    Non-finite differences are never treated as "close": a NaN on either
    side, or a finite value compared against an infinity, counts as a
    mismatched element and also violates *max_abs_diff* / *max_rel_diff* when
    those are set.  Elements that compare exactly equal (including matching
    infinities) always count as matching.

    This is useful for kernels where most elements match but a tiny
    fraction have large relative differences. Pass this function directly as
    ``autotune_baseline_accuracy_check_fn`` for the default thresholds, or use
    ``functools.partial`` to customize them::

        from functools import partial
        from helion._testing import assert_close_with_mismatch_tolerance

        @helion.kernel(
            autotune_baseline_accuracy_check_fn=partial(
                assert_close_with_mismatch_tolerance,
                max_mismatch_pct=0.05,
                max_abs_diff=10.0,
                max_rel_diff=15.0,
            ),
        )
        def my_kernel(...): ...

    Raises:
        AssertionError: if the inputs are not tensors and the strict check
            fails, if the tensor shapes differ, or if any relaxed threshold
            is exceeded.
    """
    try:
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
        return
    except AssertionError:
        if not (
            isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)
        ):
            raise
        # A shape mismatch is structural, not a numerical outlier.  Without
        # this guard the subtraction below would silently broadcast and the
        # relaxed check could pass on tensors of different shapes.
        if actual.shape != expected.shape:
            raise

    total = actual.numel()
    if total == 0:
        # Nothing to compare element-wise; ``Tensor.max()`` would raise on an
        # empty tensor below.
        return

    # Exactly-equal elements have zero difference by definition.  This matters
    # for matching infinities, where ``inf - inf`` evaluates to NaN.
    exactly_equal = actual == expected
    abs_diff = (actual - expected).abs().masked_fill(exactly_equal, 0)

    # Use the same mismatch definition as torch.testing.assert_close:
    # an element is mismatched when |actual - expected| > atol + rtol * |expected|.
    # The check is phrased as "finite and within tolerance" so that NaN / inf
    # differences are counted as mismatches (``NaN > tol`` is always False and
    # would otherwise let NaN outputs through unnoticed).
    within_tolerance = torch.isfinite(abs_diff) & (
        abs_diff <= atol + rtol * expected.abs()
    )
    mismatched = (~within_tolerance).sum().item()
    mismatch_pct = mismatched / total

    if mismatch_pct > max_mismatch_pct:
        raise AssertionError(
            f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
        )

    if max_abs_diff is not None:
        worst_abs = abs_diff.max().item()
        # ``not (x <= limit)`` rather than ``x > limit`` so a NaN worst-case
        # difference fails the check instead of slipping through.
        if not (worst_abs <= max_abs_diff):
            raise AssertionError(
                f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
            )

    if max_rel_diff is not None:
        # Clamp the denominator so near-zero expected values do not blow up
        # the relative difference.
        rel_diff = abs_diff / expected.abs().clamp(min=1e-6)
        worst_rel = rel_diff.max().item()
        if not (worst_rel <= max_rel_diff):
            raise AssertionError(
                f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
            )

helion/autotuner/base_search.py

Lines changed: 45 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -212,8 +212,22 @@ class BenchmarkResult(NamedTuple):
212212
}
213213

214214

215-
def _assert_close(actual: object, expected: object, atol: float, rtol: float) -> None:
216-
"""Like torch.testing.assert_close but handles fp8 and uses chunked comparison for large tensors."""
215+
def _assert_close(
216+
actual: object,
217+
expected: object,
218+
*,
219+
atol: float,
220+
rtol: float,
221+
max_mismatch_pct: float | None = None,
222+
max_abs_diff: float | None = None,
223+
max_rel_diff: float | None = None,
224+
) -> None:
225+
"""Like torch.testing.assert_close but handles fp8, pytree structures, and strings.
226+
227+
For tensors, uses chunked comparison for large tensors. When
228+
*max_mismatch_pct* is set, falls back to a relaxed mismatch-fraction check
229+
instead of raising immediately on the first out-of-tolerance element.
230+
"""
217231

218232
def convert(t: torch.Tensor) -> torch.Tensor:
219233
return t.view(torch.uint8) if t.dtype in _FP8_DTYPES else t
@@ -234,7 +248,35 @@ def convert(t: torch.Tensor) -> torch.Tensor:
234248

235249
for a, e in zip(actual_flat, expected_flat, strict=True):
236250
if isinstance(a, torch.Tensor):
237-
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
251+
if max_mismatch_pct is not None:
252+
try:
253+
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
254+
continue
255+
except AssertionError:
256+
pass
257+
abs_diff = (a - e).abs()
258+
total = a.numel()
259+
mismatched = (abs_diff > atol + rtol * e.abs()).sum().item()
260+
mismatch_pct = mismatched / total if total > 0 else 0.0
261+
if mismatch_pct > max_mismatch_pct:
262+
raise AssertionError(
263+
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
264+
)
265+
if max_abs_diff is not None:
266+
worst_abs = abs_diff.max().item()
267+
if worst_abs > max_abs_diff:
268+
raise AssertionError(
269+
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
270+
)
271+
if max_rel_diff is not None:
272+
rel_diff = abs_diff / e.abs().clamp(min=1e-6)
273+
worst_rel = rel_diff.max().item()
274+
if worst_rel > max_rel_diff:
275+
raise AssertionError(
276+
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
277+
)
278+
else:
279+
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
238280
elif isinstance(a, str):
239281
if not isinstance(e, str):
240282
raise AssertionError(f"Type mismatch {a} vs {e}")

0 commit comments

Comments
 (0)