Skip to content

Commit 9e75d29

Browse files
author
Ethan Che
committed
Generalize codegen_atomic_add to codegen_atomic for all atomic ops
Refactor the atomic codegen path so all atomic operations (add, and, or, xor, max, min, xchg) route through a generic codegen_atomic(op, ...) method on IndexingStrategy. This enables tensor_descriptor-based TMA atomics for all supported reduction ops (add, and, max, min, or, xor), with automatic fallback to pointer for unsupported ops (xchg, cas), return-value-consuming calls, and non-relaxed memory semantics.
1 parent e351047 commit 9e75d29

File tree

3 files changed

+182
-85
lines changed

3 files changed

+182
-85
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import collections
55
import dataclasses
66
from typing import TYPE_CHECKING
7+
from typing import ClassVar
78
from typing import NamedTuple
89

910
import sympy
@@ -166,8 +167,9 @@ def codegen_store(
166167
) -> ast.AST:
167168
raise NotImplementedError
168169

169-
def codegen_atomic_add(
170+
def codegen_atomic(
170171
self,
172+
op: str,
171173
state: CodegenState,
172174
fake_tensor: torch.Tensor,
173175
subscript: list[object],
@@ -315,8 +317,9 @@ def codegen_store(
315317
mask=indexing.mask_expr,
316318
)
317319

318-
def codegen_atomic_add(
320+
def codegen_atomic(
319321
self,
322+
op: str,
320323
state: CodegenState,
321324
fake_tensor: torch.Tensor,
322325
subscript: list[object],
@@ -326,7 +329,7 @@ def codegen_atomic_add(
326329
indexing = SubscriptIndexing.create(state, fake_tensor, subscript)
327330
name = state.device_function.tensor_arg(fake_tensor).name
328331
return expr_from_string(
329-
f"tl.atomic_add({name} + {{offset}}, {{value}}, mask={{mask}}, sem={{sem}})",
332+
f"tl.{op}({name} + {{offset}}, {{value}}, mask={{mask}}, sem={{sem}})",
330333
offset=indexing.index_expr,
331334
value=value,
332335
mask=indexing.mask_expr,
@@ -394,7 +397,6 @@ def codegen_store(
394397
)
395398

396399

397-
398400
class TensorDescriptorIndexingStrategy(IndexingStrategy):
399401
"""Use TensorDescriptor to load/store from tensors"""
400402

@@ -573,23 +575,37 @@ def codegen_store(
573575
value=store_value,
574576
)
575577

576-
def codegen_atomic_add(
578+
# Ops supported by TMA cp.reduce.async.bulk.tensor via Triton descriptor API
579+
_TMA_ATOMIC_OPS: ClassVar[set[str]] = {
580+
"atomic_add",
581+
"atomic_and",
582+
"atomic_max",
583+
"atomic_min",
584+
"atomic_or",
585+
"atomic_xor",
586+
}
587+
588+
def codegen_atomic(
577589
self,
590+
op: str,
578591
state: CodegenState,
579592
fake_tensor: torch.Tensor,
580593
subscript: list[object],
581594
value: ast.AST,
582595
sem: ast.AST,
583596
) -> ast.AST:
584-
fallback = PointerIndexingStrategy().codegen_atomic_add
585-
# Descriptor atomic_add returns void; fall back if the return value is used
597+
fallback = PointerIndexingStrategy().codegen_atomic
598+
# Only certain ops are supported by TMA reduce
599+
if op not in self._TMA_ATOMIC_OPS:
600+
return fallback(op, state, fake_tensor, subscript, value, sem)
601+
# Descriptor atomics return void; fall back if the return value is used
586602
if state.fx_node is not None and len(state.fx_node.users) > 0:
587-
return fallback(state, fake_tensor, subscript, value, sem)
588-
# Descriptor atomic_add has no sem parameter; fall back for non-relaxed
603+
return fallback(op, state, fake_tensor, subscript, value, sem)
604+
# Descriptor atomics have no sem parameter; fall back for non-relaxed
589605
if isinstance(sem, ast.Constant) and sem.value != "relaxed":
590-
return fallback(state, fake_tensor, subscript, value, sem)
606+
return fallback(op, state, fake_tensor, subscript, value, sem)
591607
if not self.is_supported(state, fake_tensor, subscript):
592-
return fallback(state, fake_tensor, subscript, value, sem)
608+
return fallback(op, state, fake_tensor, subscript, value, sem)
593609
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
594610
desc_arg = indexing.tensor_descriptor_arg(state)
595611
atomic_value = indexing.reshape_store(state, value)
@@ -601,7 +617,7 @@ def codegen_atomic_add(
601617
)
602618

603619
return expr_from_string(
604-
f"{indexing.tensor_descriptor(state)}.atomic_add({indexing.offsets_str_permuted(state)}, {{value}})",
620+
f"{indexing.tensor_descriptor(state)}.{op}({indexing.offsets_str_permuted(state)}, {{value}})",
605621
value=atomic_value,
606622
)
607623

helion/language/atomic_ops.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ def _prepare_mem_args(
5959

6060

6161
def _codegen_common(
62-
tl_func: str, state: CodegenState, value_exprs: list[ast.AST]
62+
op: str, state: CodegenState, value_exprs: list[ast.AST]
6363
) -> ast.AST:
64+
"""Route any single-value atomic op through the atomic_indexing strategy."""
6465
target = state.proxy_arg(0)
6566
index = state.proxy_arg(1)
6667
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
@@ -70,23 +71,13 @@ def _codegen_common(
7071

7172
host_function = HostFunction.current()
7273
if target not in host_function.tensor_to_origin:
73-
raise exc.AtomicOnDeviceTensor(tl_func)
74+
raise exc.AtomicOnDeviceTensor(op)
7475

75-
indices = SubscriptIndexing.create(state, target, index)
76-
name = state.device_function.tensor_arg(target).name
77-
78-
placeholder_names = [f"v{i}" for i in range(len(value_exprs))]
79-
values_section = (
80-
", " + ", ".join([f"{{{n}}}" for n in placeholder_names]) if value_exprs else ""
81-
)
82-
placeholders = dict(zip(placeholder_names, value_exprs, strict=False))
83-
return expr_from_string(
84-
f"tl.{tl_func}({name} + {{offset}}{values_section}, mask={{mask}}, sem={{sem}})",
85-
offset=indices.index_expr,
86-
mask=indices.mask_expr,
87-
sem=sem,
88-
**placeholders,
89-
)
76+
device_fn = state.device_function
77+
indexing_idx = device_fn.atomic_op_index
78+
device_fn.atomic_op_index += 1
79+
strategy = device_fn.get_atomic_indexing_strategy(indexing_idx)
80+
return strategy.codegen_atomic(op, state, target, index, value_exprs[0], sem)
9081

9182

9283
def _cute_pointer_expr(
@@ -590,23 +581,8 @@ def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None:
590581

591582
@_decorators.codegen(atomic_add, "triton")
592583
def _(state: CodegenState) -> ast.AST:
593-
target = state.proxy_arg(0)
594-
index = state.proxy_arg(1)
595-
value_expr = _to_ast_values([state.ast_args[2]])[0]
596-
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
597-
598-
assert isinstance(target, torch.Tensor)
599-
assert isinstance(index, list)
600-
601-
host_function = HostFunction.current()
602-
if target not in host_function.tensor_to_origin:
603-
raise exc.AtomicOnDeviceTensor("atomic_add")
604-
605-
device_fn = state.device_function
606-
indexing_idx = device_fn.atomic_op_index
607-
device_fn.atomic_op_index += 1
608-
strategy = device_fn.get_atomic_indexing_strategy(indexing_idx)
609-
return strategy.codegen_atomic_add(state, target, index, value_expr, sem)
584+
value_expr = state.ast_args[2]
585+
return _codegen_common("atomic_add", state, _to_ast_values([value_expr]))
610586

611587

612588
@_decorators.codegen(atomic_add, "cute")

test/test_atomic_ops.py

Lines changed: 144 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,58 @@ def atomic_cas_kernel(
151151
return x
152152

153153

154+
# 2D kernels for tensor descriptor atomic tests (TD requires ndim >= 2 + static_shapes)
155+
156+
157+
@helion.kernel(static_shapes=True)
158+
def atomic_add_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
159+
for i, j in hl.tile([x.size(0), x.size(1)]):
160+
hl.atomic_add(x, [i, j], y[i, j])
161+
return x
162+
163+
164+
@helion.kernel(static_shapes=True)
165+
def atomic_and_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
166+
for i, j in hl.tile([x.size(0), x.size(1)]):
167+
hl.atomic_and(x, [i, j], y[i, j])
168+
return x
169+
170+
171+
@helion.kernel(static_shapes=True)
172+
def atomic_or_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
173+
for i, j in hl.tile([x.size(0), x.size(1)]):
174+
hl.atomic_or(x, [i, j], y[i, j])
175+
return x
176+
177+
178+
@helion.kernel(static_shapes=True)
179+
def atomic_xor_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
180+
for i, j in hl.tile([x.size(0), x.size(1)]):
181+
hl.atomic_xor(x, [i, j], y[i, j])
182+
return x
183+
184+
185+
@helion.kernel(static_shapes=True)
186+
def atomic_max_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
187+
for i, j in hl.tile([x.size(0), x.size(1)]):
188+
hl.atomic_max(x, [i, j], y[i, j])
189+
return x
190+
191+
192+
@helion.kernel(static_shapes=True)
193+
def atomic_min_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
194+
for i, j in hl.tile([x.size(0), x.size(1)]):
195+
hl.atomic_min(x, [i, j], y[i, j])
196+
return x
197+
198+
199+
@helion.kernel(static_shapes=True)
200+
def atomic_xchg_2d_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
201+
for i, j in hl.tile([x.size(0), x.size(1)]):
202+
hl.atomic_xchg(x, [i, j], y[i, j])
203+
return x
204+
205+
154206
@onlyBackends(["triton", "cute", "pallas"])
155207
class TestAtomicOperations(RefEagerTestBase, TestCase):
156208
def test_basic_atomic_add(self):
@@ -425,30 +477,8 @@ def test_atomic_cas(self):
425477

426478
@onlyBackends("triton")
427479
@skipIfRocm("Tensor descriptor not supported on ROCm")
428-
def test_atomic_add_tensor_descriptor(self):
429-
"""Test that atomic_add with tensor_descriptor indexing generates desc.atomic_add."""
430-
431-
@helion.kernel(
432-
config=helion.Config(
433-
block_sizes=[64, 64],
434-
indexing="tensor_descriptor",
435-
atomic_indexing="tensor_descriptor",
436-
),
437-
static_shapes=True,
438-
)
439-
def atomic_add_td_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
440-
for i, j in hl.tile([x.size(0), x.size(1)]):
441-
hl.atomic_add(x, [i, j], y[i, j])
442-
return x
443-
444-
M, N = 128, 64
445-
x = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
446-
y = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
447-
code, result = code_and_output(atomic_add_td_kernel, (x, y))
448-
expected = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
449-
torch.testing.assert_close(result, expected)
450-
self.assertIn("desc.atomic_add(", code)
451-
self.assertNotIn("tl.atomic_add", code)
480+
def test_atomic_td_fallbacks(self):
481+
"""Test that tensor_descriptor atomics fall back to pointer when needed."""
452482

453483
# Return value consumed: should fall back to pointer
454484
@helion.kernel(
@@ -466,14 +496,14 @@ def atomic_add_td_prev_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
466496
out[i, j] = prev
467497
return out
468498

469-
x2 = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
470-
y2 = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
471-
code2, result2 = code_and_output(atomic_add_td_prev_kernel, (x2, y2))
472-
# prev should be zeros (the old values before adding ones)
473-
expected2 = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
474-
torch.testing.assert_close(result2, expected2)
475-
self.assertIn("tl.atomic_add", code2)
476-
self.assertNotIn("desc.atomic_add(", code2)
499+
M, N = 128, 64
500+
x = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
501+
y = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
502+
code, result = code_and_output(atomic_add_td_prev_kernel, (x, y))
503+
expected = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
504+
torch.testing.assert_close(result, expected)
505+
self.assertIn("tl.atomic_add", code)
506+
self.assertNotIn("desc.atomic_add(", code)
477507

478508
# Non-relaxed sem: should fall back to pointer
479509
@helion.kernel(
@@ -491,13 +521,13 @@ def atomic_add_td_release_kernel(
491521
hl.atomic_add(x, [i, j], y[i, j], sem="release")
492522
return x
493523

494-
x3 = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
495-
y3 = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
496-
code3, result3 = code_and_output(atomic_add_td_release_kernel, (x3, y3))
497-
expected3 = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
498-
torch.testing.assert_close(result3, expected3)
499-
self.assertIn("tl.atomic_add", code3)
500-
self.assertNotIn("desc.atomic_add(", code3)
524+
x2 = torch.zeros(M, N, device=DEVICE, dtype=torch.float32)
525+
y2 = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
526+
code2, result2 = code_and_output(atomic_add_td_release_kernel, (x2, y2))
527+
expected2 = torch.ones(M, N, device=DEVICE, dtype=torch.float32)
528+
torch.testing.assert_close(result2, expected2)
529+
self.assertIn("tl.atomic_add", code2)
530+
self.assertNotIn("desc.atomic_add(", code2)
501531

502532
@onlyBackends("triton")
503533
@skipIfRocm("Tensor descriptor not supported on ROCm")
@@ -536,6 +566,81 @@ def two_atomic_adds(
536566
self.assertNotIn("out1_desc", code)
537567
self.assertNotIn("tl.atomic_add(out2", code)
538568

569+
@onlyBackends("triton")
570+
@skipIfRocm("Tensor descriptor not supported on ROCm")
571+
def test_atomic_ops_tensor_descriptor(self):
572+
"""Test all TMA-supported atomic ops generate desc.atomic_{op} codegen."""
573+
M, N = 128, 64
574+
td_config = {
575+
"block_sizes": [64, 64],
576+
"indexing": "tensor_descriptor",
577+
"atomic_indexing": "tensor_descriptor",
578+
}
579+
# (op_name, kernel, x, y, expected)
580+
cases = [
581+
(
582+
"add",
583+
atomic_add_2d_td_kernel,
584+
torch.zeros(M, N, device=DEVICE, dtype=torch.float32),
585+
torch.ones(M, N, device=DEVICE, dtype=torch.float32),
586+
torch.ones(M, N, device=DEVICE, dtype=torch.float32),
587+
),
588+
(
589+
"and",
590+
atomic_and_2d_td_kernel,
591+
torch.full((M, N), 0b1111, device=DEVICE, dtype=torch.int32),
592+
torch.full((M, N), 0b1010, device=DEVICE, dtype=torch.int32),
593+
torch.full((M, N), 0b1010, device=DEVICE, dtype=torch.int32),
594+
),
595+
(
596+
"or",
597+
atomic_or_2d_td_kernel,
598+
torch.zeros(M, N, device=DEVICE, dtype=torch.int32),
599+
torch.full((M, N), 0b1010, device=DEVICE, dtype=torch.int32),
600+
torch.full((M, N), 0b1010, device=DEVICE, dtype=torch.int32),
601+
),
602+
(
603+
"xor",
604+
atomic_xor_2d_td_kernel,
605+
torch.full((M, N), 0b1010, device=DEVICE, dtype=torch.int32),
606+
torch.full((M, N), 0b1100, device=DEVICE, dtype=torch.int32),
607+
torch.full((M, N), 0b0110, device=DEVICE, dtype=torch.int32),
608+
),
609+
(
610+
"max",
611+
atomic_max_2d_td_kernel,
612+
torch.ones(M, N, device=DEVICE, dtype=torch.int32),
613+
torch.full((M, N), 5, device=DEVICE, dtype=torch.int32),
614+
torch.full((M, N), 5, device=DEVICE, dtype=torch.int32),
615+
),
616+
(
617+
"min",
618+
atomic_min_2d_td_kernel,
619+
torch.full((M, N), 10, device=DEVICE, dtype=torch.int32),
620+
torch.full((M, N), 3, device=DEVICE, dtype=torch.int32),
621+
torch.full((M, N), 3, device=DEVICE, dtype=torch.int32),
622+
),
623+
]
624+
for op_name, kernel, x, y, expected in cases:
625+
with self.subTest(op=op_name):
626+
code, result = code_and_output(kernel, (x, y), **td_config)
627+
torch.testing.assert_close(result, expected)
628+
self.assertIn(f"desc.atomic_{op_name}(", code)
629+
self.assertNotIn(f"tl.atomic_{op_name}", code)
630+
631+
# xchg is NOT a TMA reduction op — should fall back to pointer
632+
with self.subTest(op="xchg_fallback"):
633+
x = torch.zeros(M, N, device=DEVICE, dtype=torch.int32)
634+
y = torch.ones(M, N, device=DEVICE, dtype=torch.int32)
635+
code, result = code_and_output(
636+
atomic_xchg_2d_td_kernel, (x, y), **td_config
637+
)
638+
torch.testing.assert_close(
639+
result, torch.ones(M, N, device=DEVICE, dtype=torch.int32)
640+
)
641+
self.assertIn("tl.atomic_xchg", code)
642+
self.assertNotIn("desc.atomic_xchg", code)
643+
539644

540645
if __name__ == "__main__":
541646
unittest.main()

0 commit comments

Comments
 (0)