|
35 | 35 | from .runtime.settings import _get_backend |
36 | 36 | from .runtime.settings import is_pallas_interpret |
37 | 37 | from helion.autotuner.base_search import _clone_args |
| 38 | +from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance |
38 | 39 |
|
39 | 40 | if _get_backend() == "pallas": |
40 | 41 | from .autotuner.benchmarking import compute_repeat_generic as compute_repeat |
@@ -1541,82 +1542,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]: |
1541 | 1542 | finally: |
1542 | 1543 | sys.stdout, sys.stderr = old_stdout, old_stderr |
1543 | 1544 |
|
1544 | | - |
1545 | | -def assert_close_with_mismatch_tolerance( |
1546 | | - actual: object, |
1547 | | - expected: object, |
1548 | | - *, |
1549 | | - atol: float = 1e-4, |
1550 | | - rtol: float = 1e-4, |
1551 | | - max_mismatch_pct: float = 0.01, |
1552 | | - max_abs_diff: float | None = None, |
1553 | | - max_rel_diff: float | None = None, |
1554 | | -) -> None: |
1555 | | - """Check that actual and expected are close, tolerating a small fraction of mismatches. |
1556 | | -
|
1557 | | - First tries ``torch.testing.assert_close`` with the given *atol*/*rtol*. |
1558 | | - If that fails **and** both arguments are tensors, falls back to a relaxed |
1559 | | - check using the same mismatch definition as ``torch.testing.assert_close`` |
1560 | | - (``|actual - expected| > atol + rtol * |expected|``): |
1561 | | -
|
1562 | | - - *max_mismatch_pct*: maximum allowed fraction of mismatched elements |
1563 | | - (default 1%). Always checked. |
1564 | | - - *max_abs_diff*: if not None, the greatest absolute difference across |
1565 | | - all elements must not exceed this value. |
1566 | | - - *max_rel_diff*: if not None, the greatest relative difference |
1567 | | - (``|actual - expected| / |expected|``) must not exceed this value. |
1568 | | -
|
1569 | | - This is useful for kernels where most elements match but a tiny |
1570 | | - fraction have large relative differences. Pass this function directly as |
1571 | | - ``autotune_baseline_accuracy_check_fn`` for the default thresholds, or use |
1572 | | - ``functools.partial`` to customize them:: |
1573 | | -
|
1574 | | - from functools import partial |
1575 | | - from helion._testing import assert_close_with_mismatch_tolerance |
1576 | | -
|
1577 | | - @helion.kernel( |
1578 | | - autotune_baseline_accuracy_check_fn=partial( |
1579 | | - assert_close_with_mismatch_tolerance, |
1580 | | - max_mismatch_pct=0.05, |
1581 | | - max_abs_diff=10.0, |
1582 | | - max_rel_diff=15.0, |
1583 | | - ), |
1584 | | - ) |
1585 | | - def my_kernel(...): ... |
1586 | | - """ |
1587 | | - try: |
1588 | | - torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) |
1589 | | - return |
1590 | | - except AssertionError: |
1591 | | - if not ( |
1592 | | - isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor) |
1593 | | - ): |
1594 | | - raise |
1595 | | - |
1596 | | - abs_diff = (actual - expected).abs() |
1597 | | - total = actual.numel() |
1598 | | - |
1599 | | - # Use the same mismatch definition as torch.testing.assert_close: |
1600 | | - # an element is mismatched when |actual - expected| > atol + rtol * |expected| |
1601 | | - mismatched = (abs_diff > atol + rtol * expected.abs()).sum().item() |
1602 | | - mismatch_pct = mismatched / total if total > 0 else 0.0 |
1603 | | - |
1604 | | - if mismatch_pct > max_mismatch_pct: |
1605 | | - raise AssertionError( |
1606 | | - f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}" |
1607 | | - ) |
1608 | | - |
1609 | | - if max_abs_diff is not None: |
1610 | | - worst_abs = abs_diff.max().item() |
1611 | | - if worst_abs > max_abs_diff: |
1612 | | - raise AssertionError( |
1613 | | - f"Absolute diff too large: {worst_abs} > {max_abs_diff}" |
1614 | | - ) |
1615 | | - |
1616 | | - if max_rel_diff is not None: |
1617 | | - rel_diff = abs_diff / expected.abs().clamp(min=1e-6) |
1618 | | - worst_rel = rel_diff.max().item() |
1619 | | - if worst_rel > max_rel_diff: |
1620 | | - raise AssertionError( |
1621 | | - f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}" |
1622 | | - ) |
0 commit comments