Skip to content

Commit 9649742

Browse files
committed
WIP: consolidate accuracy check APIs
stack-info: PR: #1910, branch: shunting314/stack/29
1 parent 1302286 commit 9649742

File tree

3 files changed

+47
-83
lines changed

3 files changed

+47
-83
lines changed

examples/distributed/fp8_matmul_reduce_scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch.distributed._symmetric_memory as symm_mem
2222

2323
import helion
24-
from helion._testing import assert_close_with_mismatch_tolerance
24+
from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
2525
from helion._testing import DEVICE
2626
from helion._testing import run_example
2727
import helion.language as hl

helion/_testing.py

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .runtime.settings import _get_backend
3636
from .runtime.settings import is_pallas_interpret
3737
from helion.autotuner.base_search import _clone_args
38+
from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
3839

3940
if _get_backend() == "pallas":
4041
from .autotuner.benchmarking import compute_repeat_generic as compute_repeat
@@ -1541,82 +1542,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]:
15411542
finally:
15421543
sys.stdout, sys.stderr = old_stdout, old_stderr
15431544

1544-
1545-
def assert_close_with_mismatch_tolerance(
1546-
actual: object,
1547-
expected: object,
1548-
*,
1549-
atol: float = 1e-4,
1550-
rtol: float = 1e-4,
1551-
max_mismatch_pct: float = 0.01,
1552-
max_abs_diff: float | None = None,
1553-
max_rel_diff: float | None = None,
1554-
) -> None:
1555-
"""Check that actual and expected are close, tolerating a small fraction of mismatches.
1556-
1557-
First tries ``torch.testing.assert_close`` with the given *atol*/*rtol*.
1558-
If that fails **and** both arguments are tensors, falls back to a relaxed
1559-
check using the same mismatch definition as ``torch.testing.assert_close``
1560-
(``|actual - expected| > atol + rtol * |expected|``):
1561-
1562-
- *max_mismatch_pct*: maximum allowed fraction of mismatched elements
1563-
(default 1%). Always checked.
1564-
- *max_abs_diff*: if not None, the greatest absolute difference across
1565-
all elements must not exceed this value.
1566-
- *max_rel_diff*: if not None, the greatest relative difference
1567-
(``|actual - expected| / |expected|``) must not exceed this value.
1568-
1569-
This is useful for kernels where most elements match but a tiny
1570-
fraction have large relative differences. Pass this function directly as
1571-
``autotune_baseline_accuracy_check_fn`` for the default thresholds, or use
1572-
``functools.partial`` to customize them::
1573-
1574-
from functools import partial
1575-
from helion._testing import assert_close_with_mismatch_tolerance
1576-
1577-
@helion.kernel(
1578-
autotune_baseline_accuracy_check_fn=partial(
1579-
assert_close_with_mismatch_tolerance,
1580-
max_mismatch_pct=0.05,
1581-
max_abs_diff=10.0,
1582-
max_rel_diff=15.0,
1583-
),
1584-
)
1585-
def my_kernel(...): ...
1586-
"""
1587-
try:
1588-
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
1589-
return
1590-
except AssertionError:
1591-
if not (
1592-
isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)
1593-
):
1594-
raise
1595-
1596-
abs_diff = (actual - expected).abs()
1597-
total = actual.numel()
1598-
1599-
# Use the same mismatch definition as torch.testing.assert_close:
1600-
# an element is mismatched when |actual - expected| > atol + rtol * |expected|
1601-
mismatched = (abs_diff > atol + rtol * expected.abs()).sum().item()
1602-
mismatch_pct = mismatched / total if total > 0 else 0.0
1603-
1604-
if mismatch_pct > max_mismatch_pct:
1605-
raise AssertionError(
1606-
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
1607-
)
1608-
1609-
if max_abs_diff is not None:
1610-
worst_abs = abs_diff.max().item()
1611-
if worst_abs > max_abs_diff:
1612-
raise AssertionError(
1613-
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
1614-
)
1615-
1616-
if max_rel_diff is not None:
1617-
rel_diff = abs_diff / expected.abs().clamp(min=1e-6)
1618-
worst_rel = rel_diff.max().item()
1619-
if worst_rel > max_rel_diff:
1620-
raise AssertionError(
1621-
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
1622-
)

helion/autotuner/base_search.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,22 @@ class BenchmarkResult(NamedTuple):
219219
}
220220

221221

222-
def _assert_close(actual: object, expected: object, atol: float, rtol: float) -> None:
223-
"""Like torch.testing.assert_close but handles fp8 and uses chunked comparison for large tensors."""
222+
def _assert_close(
223+
actual: object,
224+
expected: object,
225+
*,
226+
atol: float,
227+
rtol: float,
228+
max_mismatch_pct: float | None = None,
229+
max_abs_diff: float | None = None,
230+
max_rel_diff: float | None = None,
231+
) -> None:
232+
"""Like torch.testing.assert_close but handles fp8, pytree structures, and strings.
233+
234+
For tensors, uses chunked comparison for large tensors. When
235+
*max_mismatch_pct* is set, falls back to a relaxed mismatch-fraction check
236+
instead of raising immediately on the first out-of-tolerance element.
237+
"""
224238

225239
def convert(t: torch.Tensor) -> torch.Tensor:
226240
return t.view(torch.uint8) if t.dtype in _FP8_DTYPES else t
@@ -241,7 +255,35 @@ def convert(t: torch.Tensor) -> torch.Tensor:
241255

242256
for a, e in zip(actual_flat, expected_flat, strict=True):
243257
if isinstance(a, torch.Tensor):
244-
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
258+
if max_mismatch_pct is not None:
259+
try:
260+
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
261+
continue
262+
except AssertionError:
263+
pass
264+
abs_diff = (a - e).abs()
265+
total = a.numel()
266+
mismatched = (abs_diff > atol + rtol * e.abs()).sum().item()
267+
mismatch_pct = mismatched / total if total > 0 else 0.0
268+
if mismatch_pct > max_mismatch_pct:
269+
raise AssertionError(
270+
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
271+
)
272+
if max_abs_diff is not None:
273+
worst_abs = abs_diff.max().item()
274+
if worst_abs > max_abs_diff:
275+
raise AssertionError(
276+
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
277+
)
278+
if max_rel_diff is not None:
279+
rel_diff = abs_diff / e.abs().clamp(min=1e-6)
280+
worst_rel = rel_diff.max().item()
281+
if worst_rel > max_rel_diff:
282+
raise AssertionError(
283+
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
284+
)
285+
else:
286+
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
245287
elif isinstance(a, str):
246288
if not isinstance(e, str):
247289
raise AssertionError(f"Type mismatch {a} vs {e}")

0 commit comments

Comments
 (0)