Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ def __init__(
self._indexing_config = config.indexing
self.indexing_strategies: list[IndexingStrategy] = []

# Atomic indexing config (separate from load/store indexing)
self._atomic_indexing_config = config.atomic_indexing
self.atomic_indexing_strategies: list[IndexingStrategy] = []
self.atomic_op_index = 0

self.rng_seed_count = 0
self.device_load_index = 0
self.device_store_index = 0
Expand Down Expand Up @@ -354,6 +359,34 @@ def get_indexing_strategy(self, index: int) -> IndexingStrategy:

return self.indexing_strategies[index]

def get_atomic_indexing_strategy(self, index: int) -> IndexingStrategy:
    """Return the indexing strategy for the ``index``-th atomic operation.

    Strategies are created lazily, in order, until ``index`` is covered, and
    are cached in ``self.atomic_indexing_strategies``. How each one is chosen
    depends on ``self._atomic_indexing_config``:

    - a string: one strategy is selected once and shared by every atomic op;
    - a non-empty list: one strategy per atomic op, selected by position;
    - anything else (e.g. an empty list): ``PointerIndexingStrategy``.

    Raises:
        AssertionError: if a list config has no entry for an atomic op.
    """
    from .indexing_strategy import IndexingStrategy
    from .indexing_strategy import PointerIndexingStrategy

    strategies = self.atomic_indexing_strategies
    while len(strategies) <= index:
        # Position of the atomic op whose strategy is created this iteration.
        idx = len(strategies)

        if isinstance(self._atomic_indexing_config, str):
            # Reuse the first instance so all atomic ops share one strategy.
            strategy = (
                strategies[0]
                if strategies
                else IndexingStrategy.select(self._atomic_indexing_config)
            )
        elif (
            isinstance(self._atomic_indexing_config, list)
            and self._atomic_indexing_config
        ):
            assert idx < len(self._atomic_indexing_config), (
                f"Atomic operation {idx} exceeds atomic_indexing config length "
                f"{len(self._atomic_indexing_config)}. Please specify atomic_indexing for all atomic ops."
            )
            strategy = IndexingStrategy.select(self._atomic_indexing_config[idx])
        else:
            strategy = PointerIndexingStrategy()

        strategies.append(strategy)

    return strategies[index]

def has_rng_ops(self) -> bool:
    """Report whether this kernel uses any RNG operations.

    True only when at least one RNG seed has been counted and a seed buffer
    parameter name has been set.
    """
    if not self.rng_seed_count > 0:
        return False
    return self.rng_seed_buffer_param_name is not None
Expand Down
28 changes: 28 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,18 @@ def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
return total_load_count, loads_without_eviction_policy, store_count


def _count_device_atomics(device_ir: DeviceIR) -> int:
    """Count the atomic operations in device code for autotuning.

    A node is counted when it is a ``call_function`` node whose target is one
    of the module-level values of the ``atomic_ops`` module. Every graph in
    ``device_ir.graphs`` is scanned.
    """
    from ..language import atomic_ops

    # Bind the module-values view once instead of rebuilding it per node.
    atomic_targets = vars(atomic_ops).values()
    return sum(
        1
        for graph_info in device_ir.graphs
        for node in graph_info.graph.nodes
        if node.op == "call_function" and node.target in atomic_targets
    )


def _register_load_store_tunables(
total_load_count: int, loads_without_eviction_policy: int, store_count: int
) -> None:
Expand Down Expand Up @@ -1773,6 +1785,21 @@ def _register_load_store_tunables(
)


def _register_atomic_tunables(atomic_count: int) -> None:
    """Register the ``atomic_indexing`` tunable for all atomic operations.

    Replaces ``config_spec.atomic_indexing`` on the current compile
    environment with a list fragment of length ``atomic_count`` whose entries
    choose among ``valid_atomic_indexing_types()``. Does nothing when
    ``atomic_count`` is zero.
    """
    if atomic_count == 0:
        return

    from ..autotuner.config_fragment import EnumFragment
    from ..autotuner.config_fragment import ListOf

    config_spec = CompileEnvironment.current().config_spec
    per_op_choice = EnumFragment(choices=config_spec.valid_atomic_indexing_types())
    config_spec.atomic_indexing = ListOf(per_op_choice, length=atomic_count)


def lower_to_device_ir(func: HostFunction) -> DeviceIR:
device_ir = DeviceIR()
with func, device_ir, compile_lock:
Expand Down Expand Up @@ -1834,6 +1861,7 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
_register_load_store_tunables(
total_load_count, loads_without_eviction_policy, store_count
)
_register_atomic_tunables(_count_device_atomics(device_ir))

return device_ir

Expand Down
77 changes: 77 additions & 0 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import collections
import dataclasses
from typing import TYPE_CHECKING
from typing import ClassVar
from typing import NamedTuple

import sympy
Expand Down Expand Up @@ -166,6 +167,17 @@ def codegen_store(
) -> ast.AST:
raise NotImplementedError

def codegen_atomic(
    self,
    op: str,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    subscript: list[object],
    value: ast.AST,
    sem: ast.AST,
) -> ast.AST:
    """Generate the expression for an atomic op; subclasses must override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

@staticmethod
def select(indexing_literal: IndexingLiteral) -> IndexingStrategy:
if indexing_literal == "pointer":
Expand Down Expand Up @@ -305,6 +317,25 @@ def codegen_store(
mask=indexing.mask_expr,
)

def codegen_atomic(
    self,
    op: str,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    subscript: list[object],
    value: ast.AST,
    sem: ast.AST,
) -> ast.AST:
    """Generate a pointer-based ``tl.<op>`` atomic call.

    The address is the tensor argument's name plus the offset computed from
    ``subscript``; the mask computed from ``subscript`` and the caller's
    ``sem`` expression are passed through as keyword arguments.
    """
    indexing = SubscriptIndexing.create(state, fake_tensor, subscript)
    name = state.device_function.tensor_arg(fake_tensor).name
    template = f"tl.{op}({name} + {{offset}}, {{value}}, mask={{mask}}, sem={{sem}})"
    placeholders = {
        "offset": indexing.index_expr,
        "value": value,
        "mask": indexing.mask_expr,
        "sem": sem,
    }
    return expr_from_string(template, **placeholders)


class BlockPtrIndexingStrategy(IndexingStrategy):
"""Use block_ptr to load/store from tensors"""
Expand Down Expand Up @@ -544,6 +575,52 @@ def codegen_store(
value=store_value,
)

# Ops supported by TMA cp.reduce.async.bulk.tensor via Triton descriptor API.
# Immutable on purpose: this is a shared class-level lookup table that is only
# ever used for membership tests.
_TMA_ATOMIC_OPS: ClassVar[frozenset[str]] = frozenset(
    {
        "atomic_add",
        "atomic_and",
        "atomic_max",
        "atomic_min",
        "atomic_or",
        "atomic_xor",
    }
)

def codegen_atomic(
    self,
    op: str,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    subscript: list[object],
    value: ast.AST,
    sem: ast.AST,
) -> ast.AST:
    """Generate a tensor-descriptor atomic, or fall back to pointer indexing.

    The pointer strategy is used instead whenever any of these holds:

    - ``op`` is not in ``_TMA_ATOMIC_OPS``;
    - the result of the atomic is used (descriptor atomics return void);
    - ``sem`` is not provably the constant ``"relaxed"`` (descriptor atomics
      take no ``sem`` argument, so any other ordering would be dropped);
    - ``is_supported`` rejects this tensor/subscript combination.

    Args:
        op: Atomic op name, used as the method name on the descriptor.
        state: Current codegen state.
        fake_tensor: Tensor being updated.
        subscript: Index expression applied to ``fake_tensor``.
        value: AST of the value operand.
        sem: AST of the memory-semantics argument.

    Returns:
        AST of the generated atomic call.
    """
    fallback = PointerIndexingStrategy().codegen_atomic
    # Only certain ops are supported by TMA reduce
    if op not in self._TMA_ATOMIC_OPS:
        return fallback(op, state, fake_tensor, subscript, value, sem)
    # Descriptor atomics return void; fall back if the return value is used
    if state.fx_node is not None and len(state.fx_node.users) > 0:
        return fallback(op, state, fake_tensor, subscript, value, sem)
    # Descriptor atomics have no sem parameter. Take the descriptor path only
    # when sem is provably "relaxed"; a non-constant sem must not be dropped.
    if not (isinstance(sem, ast.Constant) and sem.value == "relaxed"):
        return fallback(op, state, fake_tensor, subscript, value, sem)
    if not self.is_supported(state, fake_tensor, subscript):
        return fallback(op, state, fake_tensor, subscript, value, sem)
    indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
    desc_arg = indexing.tensor_descriptor_arg(state)
    atomic_value = indexing.reshape_store(state, value)

    # Reorder the value to match the descriptor when it carries a permutation.
    if desc_arg.permutation is not None:
        atomic_value = expr_from_string(
            f"tl.permute({{value}}, {desc_arg.permutation!r})",
            value=atomic_value,
        )

    return expr_from_string(
        f"{indexing.tensor_descriptor(state)}.{op}({indexing.offsets_str_permuted(state)}, {{value}})",
        value=atomic_value,
    )


class StackIndexingStrategy:
"""
Expand Down
18 changes: 18 additions & 0 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _get_backend_tunable_keys() -> frozenset[str]:
"num_sm_multiplier",
"maxnreg",
"indexing",
"atomic_indexing",
"load_eviction_policies",
"pallas_loop_type",
*BACKEND_TUNABLE_KEYS,
Expand Down Expand Up @@ -179,6 +180,10 @@ def __init__(
EnumFragment(choices=self.valid_indexing_types()),
length=0,
)
self.atomic_indexing = ListOf(
EnumFragment(choices=self.valid_atomic_indexing_types()),
length=0,
)
self.epilogue_subtile_candidate_enabled: bool = False
self.epilogue_subtile_autotune_choices: tuple[int | None, ...] | None = None
self.epilogue_subtile_k_hint: int = 0
Expand Down Expand Up @@ -271,6 +276,12 @@ def valid_indexing_types(self) -> tuple[IndexingLiteral, ...]:
return ("pointer",)
return ("pointer", "block_ptr")

def valid_atomic_indexing_types(self) -> tuple[IndexingLiteral, ...]:
    """Return the indexing strategies that atomic ops may use.

    Atomic ops only support pointer and tensor_descriptor (no block_ptr);
    ``"tensor_descriptor"`` is offered only when ``supports_tensor_descriptor()``
    is true.
    """
    if not supports_tensor_descriptor():
        return ("pointer",)
    return ("pointer", "tensor_descriptor")

def _remove_duplicates(self) -> None:
self.num_threads._remove_duplicates()
self.loop_orders._remove_duplicates()
Expand Down Expand Up @@ -448,6 +459,7 @@ def normalize(
"static_ranges",
"load_eviction_policies",
"indexing",
"atomic_indexing",
):
if not config.get(name):
config.pop(name, None)
Expand All @@ -458,6 +470,7 @@ def normalize(
"num_stages",
"load_eviction_policies",
"indexing",
"atomic_indexing",
"pid_type",
"num_sm_multiplier",
"maxnreg",
Expand All @@ -475,6 +488,8 @@ def normalize(
)
if self.supports_config_key("indexing"):
config.setdefault("indexing", self.indexing.default())
if self.supports_config_key("atomic_indexing"):
config.setdefault("atomic_indexing", self.atomic_indexing.default())
for key, fragment in self.backend_tunable_fragments.items():
config.setdefault(key, fragment.default())

Expand Down Expand Up @@ -756,6 +771,8 @@ def _flat_fields(
fields["num_stages"] = num_stages_fragment
if self.supports_config_key("indexing"):
fields["indexing"] = self.indexing
if self.supports_config_key("atomic_indexing"):
fields["atomic_indexing"] = self.atomic_indexing
if self.supports_config_key("pid_type"):
fields["pid_type"] = EnumFragment(self.allowed_pid_types)
if self.supports_config_key("num_sm_multiplier"):
Expand Down Expand Up @@ -840,6 +857,7 @@ def flat_config(
"static_ranges",
"load_eviction_policies",
"indexing",
"atomic_indexing",
):
if not config.get(name):
config.pop(name, None)
Expand Down
25 changes: 8 additions & 17 deletions helion/language/atomic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def _prepare_mem_args(


def _codegen_common(
tl_func: str, state: CodegenState, value_exprs: list[ast.AST]
op: str, state: CodegenState, value_exprs: list[ast.AST]
) -> ast.AST:
"""Route any single-value atomic op through the atomic_indexing strategy."""
target = state.proxy_arg(0)
index = state.proxy_arg(1)
sem = expr_from_string(repr(state.proxy_arg(len(state.ast_args) - 1)))
Expand All @@ -70,23 +71,13 @@ def _codegen_common(

host_function = HostFunction.current()
if target not in host_function.tensor_to_origin:
raise exc.AtomicOnDeviceTensor(tl_func)
raise exc.AtomicOnDeviceTensor(op)

indices = SubscriptIndexing.create(state, target, index)
name = state.device_function.tensor_arg(target).name

placeholder_names = [f"v{i}" for i in range(len(value_exprs))]
values_section = (
", " + ", ".join([f"{{{n}}}" for n in placeholder_names]) if value_exprs else ""
)
placeholders = dict(zip(placeholder_names, value_exprs, strict=False))
return expr_from_string(
f"tl.{tl_func}({name} + {{offset}}{values_section}, mask={{mask}}, sem={{sem}})",
offset=indices.index_expr,
mask=indices.mask_expr,
sem=sem,
**placeholders,
)
device_fn = state.device_function
indexing_idx = device_fn.atomic_op_index
device_fn.atomic_op_index += 1
strategy = device_fn.get_atomic_indexing_strategy(indexing_idx)
return strategy.codegen_atomic(op, state, target, index, value_exprs[0], sem)


def _cute_pointer_expr(
Expand Down
12 changes: 12 additions & 0 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
num_sm_multiplier: NumSmMultiplierLiteral | None = None,
maxnreg: MaxnregLiteral | None = None,
indexing: IndexingLiteral | list[IndexingLiteral] | None = None,
atomic_indexing: IndexingLiteral | list[IndexingLiteral] | None = None,
advanced_controls_file: str | None = None,
epilogue_subtile: int | None = None,
# For user-defined properties
Expand Down Expand Up @@ -78,6 +79,9 @@ def __init__(
indexing=["pointer", "block_ptr", "tensor_descriptor"]
- Empty/omitted (all loads/stores default to "pointer")
Valid strategies: "pointer", "tensor_descriptor", "block_ptr"
atomic_indexing: Indexing strategy for atomic operations (e.g., hl.atomic_add).
Same format as ``indexing`` (a single string or a list per atomic op).
Defaults to "pointer" when omitted.
advanced_controls_file: Path to a PTXAS control file applied during compilation, or empty string for none.
epilogue_subtile: Split factor for the epilogue (post-matmul pointwise + store) along
the N dimension. None = disabled (default), valid values are 2 or 4.
Expand All @@ -101,6 +105,7 @@ def __init__(
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
"atomic_indexing": atomic_indexing,
"pid_type": pid_type,
"num_sm_multiplier": num_sm_multiplier,
"maxnreg": maxnreg,
Expand Down Expand Up @@ -302,6 +307,13 @@ def indexing(self) -> IndexingLiteral | list[IndexingLiteral]:
"IndexingLiteral | list[IndexingLiteral]", self.config.get("indexing", [])
)

@property
def atomic_indexing(self) -> IndexingLiteral | list[IndexingLiteral]:
    """Indexing strategy configured for atomic operations.

    Returns the ``"atomic_indexing"`` entry of ``self.config``, or an empty
    list when that key is absent.
    """
    configured = self.config.get("atomic_indexing", [])
    return cast("IndexingLiteral | list[IndexingLiteral]", configured)

@property
def epilogue_subtile(self) -> int | None:
    """Configured ``"epilogue_subtile"`` value, or ``None`` when the key is absent."""
    setting = self.config.get("epilogue_subtile", None)
    return cast("int | None", setting)
Expand Down
Loading
Loading