Skip to content

Commit d9a1db1

Browse files
committed
no tensor descriptor for the 2d parallel gemm example
stack-info: PR: #1848, branch: shunting314/stack/26
1 parent a42cffc commit d9a1db1

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

examples/distributed/two_dim_parallel_matmul.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@
4949
from helion.runtime.dist_utils import symm_mem_sync
5050

5151

52+
def _no_tensor_descriptor(config: helion.Config) -> bool:
53+
"""
54+
Reject configs that use tensor_descriptor indexing.
55+
56+
TMA (cp.async.bulk.tensor) cannot access symmetric-memory virtual address
57+
ranges, so any config with tensor_descriptor indexing will fault with
58+
'misaligned address' at runtime.
59+
60+
Work around https://github.com/pytorch/helion/issues/1846
61+
"""
62+
indexing = config.get("indexing")
63+
if isinstance(indexing, list):
64+
return "tensor_descriptor" not in indexing
65+
return indexing != "tensor_descriptor"
66+
67+
5268
@helion.kernel(
5369
config=helion.Config(
5470
block_sizes=[64, 64, 32],
@@ -58,6 +74,7 @@
5874
),
5975
static_shapes=True,
6076
ignore_warnings=[helion.exc.TensorOperationInWrapper],
77+
config_filter=_no_tensor_descriptor,
6178
)
6279
def two_dim_parallel_matmul_kernel(
6380
a_local: torch.Tensor, # [M/SP, K/TP]

0 commit comments

Comments
 (0)