Skip to content

Commit d25f8ad

Browse files
committed
dbg
stack-info: PR: #1911, branch: shunting314/stack/30
1 parent e9841d2 commit d25f8ad

File tree

4 files changed: +19 additions, −9 deletions

examples/distributed/fp8_matmul_reduce_scatter.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
from helion._testing import DEVICE
2525
from helion._testing import assert_close_with_mismatch_tolerance
2626
from helion._testing import run_example
27+
from helion.autotuner.base_search import (
28+
_assert_close as assert_close_with_mismatch_tolerance,
29+
)
2730
import helion.language as hl
2831
from helion.runtime.dist_utils import symm_mem_sync
2932

@@ -33,13 +36,17 @@
3336
"max_mismatch_pct": 1e-3,
3437
}
3538

39+
config = helion.Config(
40+
block_sizes=[64, 64, 32], # M, N, K
41+
num_warps=8,
42+
num_stages=3,
43+
)
44+
45+
# config = helion.Config(block_sizes=[64, 128, 128], indexing=['pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[1], load_eviction_policies=['last', '', '', '', '', 'first', '', '', '', '', '', 'last'], loop_orders=[[0, 1]], num_sm_multiplier=2, num_stages=1, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, None], range_multi_buffers=[True, None], range_unroll_factors=[3, 0], range_warp_specializes=[])
46+
3647

3748
@helion.kernel(
38-
config=helion.Config(
39-
block_sizes=[64, 64, 32], # M, N, K
40-
num_warps=8,
41-
num_stages=3,
42-
),
49+
config=config,
4350
static_shapes=True,
4451
ignore_warnings=[helion.exc.TensorOperationInWrapper],
4552
autotune_baseline_accuracy_check_fn=functools.partial(

helion/_testing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
from ._utils import counters
3535
from .runtime.settings import _get_backend
3636
from .runtime.settings import is_pallas_interpret
37+
from helion.autotuner.base_search import (
38+
_assert_close as assert_close_with_mismatch_tolerance,
39+
)
3740
from helion.autotuner.base_search import _clone_args
38-
from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
3941

4042
if _get_backend() == "pallas":
4143
from .autotuner.benchmarking import compute_repeat_generic as compute_repeat
@@ -1541,4 +1543,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]:
15411543
yield capture
15421544
finally:
15431545
sys.stdout, sys.stderr = old_stdout, old_stderr
1544-

helion/autotuner/base_search.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,8 +635,9 @@ def _validate_against_baseline(
635635
custom_check = self.settings.autotune_baseline_accuracy_check_fn
636636
if custom_check is not None:
637637
custom_check(output, self._baseline_output)
638-
if len(self._mutated_arg_indices) > 0:
639-
custom_check(args, self._baseline_post_args)
638+
if os.getenv("CHECK_INPUT_ACCURACY", "1") == "1":
639+
if len(self._mutated_arg_indices) > 0:
640+
custom_check(args, self._baseline_post_args)
640641
else:
641642
_assert_close(
642643
output,

helion/runtime/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,7 @@ class Settings(_Settings):
660660
"If True, allow torch.compile to fuse this Helion kernel with surrounding Inductor ops "
661661
"(prologue/epilogue) when used inside torch.compile. Default False. "
662662
"Set HELION_TORCH_COMPILE_FUSION=1 to enable globally."
663+
),
663664
"config_filter": (
664665
"Optional callable ``(config: Config) -> bool`` that the autotuner calls on every "
665666
"candidate config before compiling or benchmarking it. Configs for which the "

0 commit comments

Comments (0)