Skip to content

Commit 59ab660

Browse files
committed
dbg
stack-info: PR: #1911, branch: shunting314/stack/30
1 parent bd61daf commit 59ab660

File tree

4 files changed: +30 -14 lines changed

examples/distributed/fp8_matmul_reduce_scatter.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,11 @@
  21   21   import torch.distributed._symmetric_memory as symm_mem
  22   22
  23   23   import helion
  24      - from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
  25   24   from helion._testing import DEVICE
  26   25   from helion._testing import run_example
       26 + from helion.autotuner.base_search import (
       27 + _assert_close as assert_close_with_mismatch_tolerance,
       28 + )
  27   29   import helion.language as hl
  28   30   from helion.runtime.dist_utils import symm_mem_sync
  29   31
@@ -33,13 +35,17 @@
  33   35   max_mismatch_pct=1e-3,
  34   36   )
  35   37
       38 + config = helion.Config(
       39 + block_sizes=[64, 64, 32], # M, N, K
       40 + num_warps=8,
       41 + num_stages=3,
       42 + )
       43 +
       44 + # config = helion.Config(block_sizes=[64, 128, 128], indexing=['pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[1], load_eviction_policies=['last', '', '', '', '', 'first', '', '', '', '', '', 'last'], loop_orders=[[0, 1]], num_sm_multiplier=2, num_stages=1, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, None], range_multi_buffers=[True, None], range_unroll_factors=[3, 0], range_warp_specializes=[])
       45 +
  36   46
  37   47   @helion.kernel(
  38      - config=helion.Config(
  39      - block_sizes=[64, 64, 32], # M, N, K
  40      - num_warps=8,
  41      - num_stages=3,
  42      - ),
       48 + config=config,
  43   49   static_shapes=True,
  44   50   ignore_warnings=[helion.exc.TensorOperationInWrapper],
  45   51   autotune_baseline_accuracy_check_fn=functools.partial(
@@ -82,8 +88,10 @@ def fp8_matmul_reduce_scatter_kernel(
  82   88   acc = hl.dot(a[tile_m, tile_k], b[tile_k, tile_n], acc=acc)
  83   89
  84   90   # Apply per-row and per-column scales
  85      - acc = acc * scale_a[tile_m, :].to(torch.float32) * scale_b[:, tile_n].to(
  86      - torch.float32
       91 + acc = (
       92 + acc
       93 + * scale_a[tile_m, :].to(torch.float32)
       94 + * scale_b[:, tile_n].to(torch.float32)
  87   95   )
  88   96
  89   97   # Store bfloat16 partial result to this rank's symmetric-memory buffer
@@ -165,7 +173,9 @@ def reference_fp8_matmul_reduce_scatter(
 165  173   if group is None:
 166  174   raise RuntimeError("Distributed group is not initialized")
 167  175
 168      - c = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
      176 + c = torch._scaled_mm(
      177 + a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
      178 + )
 169  179
 170  180   world_size = dist.get_world_size(group)
 171  181   M_scatter = c.shape[0] // world_size
@@ -223,7 +233,10 @@ def test(M: int, N: int, K: int, device: torch.device) -> None:
 223  233
 224  234   run_example(
 225  235   functools.partial(helion_fp8_matmul_reduce_scatter, symm_mem_buffer),
 226      - {"nccl+cublas": reference_fp8_matmul_reduce_scatter, "fused_baseline": reference_fused_scaled_matmul_reduce_scatter},
      236 + {
      237 + "nccl+cublas": reference_fp8_matmul_reduce_scatter,
      238 + "fused_baseline": reference_fused_scaled_matmul_reduce_scatter,
      239 + },
 227  240   (a, b, scale_a, scale_b),
 228  241   **tolerance,
 229  242   )

helion/_testing.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,10 @@
  34   34   from ._utils import counters
  35   35   from .runtime.settings import _get_backend
  36   36   from .runtime.settings import is_pallas_interpret
       37 + from helion.autotuner.base_search import (
       38 + _assert_close as assert_close_with_mismatch_tolerance,
       39 + )
  37   40   from helion.autotuner.base_search import _clone_args
  38      - from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance
  39   41
  40   42   if _get_backend() == "pallas":
  41   43   from .autotuner.benchmarking import compute_repeat_generic as compute_repeat
@@ -1541,4 +1543,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]:
1541 1543   yield capture
1542 1544   finally:
1543 1545   sys.stdout, sys.stderr = old_stdout, old_stderr
1544      -

helion/autotuner/base_search.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -635,8 +635,9 @@ def _validate_against_baseline(
 635  635   custom_check = self.settings.autotune_baseline_accuracy_check_fn
 636  636   if custom_check is not None:
 637  637   custom_check(output, self._baseline_output)
 638      - if len(self._mutated_arg_indices) > 0:
 639      - custom_check(args, self._baseline_post_args)
      638 + if os.getenv("CHECK_INPUT_ACCURACY", "1") == "1":
      639 + if len(self._mutated_arg_indices) > 0:
      640 + custom_check(args, self._baseline_post_args)
 640  641   else:
 641  642   _assert_close(
 642  643   output,

helion/runtime/settings.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -660,6 +660,7 @@ class Settings(_Settings):
 660  660   "If True, allow torch.compile to fuse this Helion kernel with surrounding Inductor ops "
 661  661   "(prologue/epilogue) when used inside torch.compile. Default False. "
 662  662   "Set HELION_TORCH_COMPILE_FUSION=1 to enable globally."
      663 + ),
 663  664   "config_filter": (
 664  665   "Optional callable ``(config: Config) -> bool`` that the autotuner calls on every "
 665  666   "candidate config before compiling or benchmarking it. Configs for which the "

0 commit comments

Comments
 (0)