Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ jobs:
${{ inputs.env-vars }} HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 python benchmarks/run.py \
--op $kernel \
--metrics speedup,accuracy \
--measure-compile-time \
--latency-measure-mode triton_do_bench \
--cudagraph \
--only $IMPLS \
Expand Down
86 changes: 85 additions & 1 deletion benchmarks/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from torch.utils._pytree import tree_map

from helion._compat import get_device_name
from helion._compile_time import get_total_time as get_compile_total_time
from helion._compile_time import reset as reset_compile_time
from helion._testing import get_nvidia_gpu_model
from helion._utils import counters
from helion.autotuner.metrics import AutotuneMetrics
Expand Down Expand Up @@ -1023,6 +1025,7 @@ def run_kernel(
results: list[RunResult],
kernel_mappings: dict[str, tuple[str, ...]] | None = None,
kernel_metric_mappings: dict[str, dict[str, str]] | None = None,
measure_compile_time: bool = False,
) -> None:
"""Run a kernel benchmark, handling both single and multiple variants."""
# Use provided mappings or default to global mappings
Expand Down Expand Up @@ -1081,6 +1084,7 @@ def run_kernel(
operator_args,
results,
active_metrics,
measure_compile_time=measure_compile_time,
)


Expand All @@ -1093,6 +1097,7 @@ def run_kernel_variants(
operator_args: dict[str, Any] | None,
results: list[RunResult],
kernel_metric_mappings: dict[str, dict[str, str]] | None = None,
measure_compile_time: bool = False,
) -> None:
"""Run kernel variants in the same benchmark run."""

Expand Down Expand Up @@ -1183,6 +1188,9 @@ def run_kernel_variants(
# pyrefly: ignore [missing-import]
from tritonbench.utils.triton_op import register_benchmark

# Compile time tracking per variant
variant_compile_times: dict[str, list[float]] = {}

# Register all variants as separate methods
for module_path, func_name in variants:
# Import the kernel function
Expand All @@ -1207,6 +1215,7 @@ def run_kernel_variants(
def create_helion_method(
mod: Any, # noqa: ANN401
kfunc: Callable[..., Any],
compile_time_list: list[float] | None = None,
) -> Callable[..., Any]:
def helion_method(
self: object,
Expand Down Expand Up @@ -1245,6 +1254,29 @@ def helion_method(
measured_func_callable = kfunc(self, *args, **kwargs)

assert callable(measured_func_callable)

if compile_time_list is not None:
original = measured_func_callable
first_call = True
ct_list = compile_time_list

def timed_callable() -> object:
nonlocal first_call
if first_call:
first_call = False
torch.cuda.synchronize()
reset_compile_time()
try:
result = original()
except Exception:
ct_list.append(get_compile_total_time())
raise
ct_list.append(get_compile_total_time())
return result
return original()

return timed_callable

return measured_func_callable

return helion_method
Expand All @@ -1253,6 +1285,12 @@ def helion_method(
variant_name = func_name
helion_method_name = f"helion_{variant_name}"

# Set up compile time tracking for this variant
compile_times: list[float] | None = None
if measure_compile_time:
compile_times = []
variant_compile_times[func_name] = compile_times

# Use register_benchmark decorator
decorated_method = register_benchmark(
operator_name=operator_name,
Expand All @@ -1261,7 +1299,7 @@ def helion_method(
enabled=True,
fwd_only=False,
label=helion_method_name,
)(create_helion_method(module, kernel_func))
)(create_helion_method(module, kernel_func, compile_times))

# Set the decorated method on the Operator class
setattr(Operator, helion_method_name, decorated_method)
Expand Down Expand Up @@ -1350,6 +1388,40 @@ def accuracy_fail_hook(
except Exception:
logger.exception("failed to process results")

# Add compile time metrics (per-shape, same format as speedup)
if measure_compile_time and variant_compile_times:
# Get shapes from the most recent result for this kernel
kernel_results = [r for r in results if r.model == kernel_name]
shapes = kernel_results[-1].shape if kernel_results else []
device = get_device_name() or "unknown"
for func_name, times in variant_compile_times.items():
if not times:
continue
# Align compile times with shapes (both are in input order)
if len(times) != len(shapes):
logger.warning(
f"Compile time count ({len(times)}) != shape count "
f"({len(shapes)}) for {kernel_name}/{func_name}, skipping"
)
continue
metric_name = "helion_compile_time_s"
if len(variants) > 1:
metric_name = f"helion_{func_name}_compile_time_s"
results.append(
RunResult(
model=kernel_name,
device=device,
shape=shapes,
metrics={metric_name: times},
)
)
print(
f"Compile time for {kernel_name}/{func_name}: "
f"{', '.join(f'{t:.3f}s' for t in times)} "
f"({len(times)} shapes)",
file=sys.stderr,
)

# Force garbage collection multiple times to ensure memory is freed
for _ in range(3):
gc.collect()
Expand Down Expand Up @@ -1607,6 +1679,12 @@ def main() -> None:
help="Export autotune metrics to a JSON file at the given path. "
"Also set via HELION_AUTOTUNE_METRICS_JSON=<path>.",
)
parser.add_argument(
"--measure-compile-time",
action="store_true",
help="Measure and report Helion kernel compile time (seconds) for each input shape. "
"Results are included in JSON output as helion_compile_time_s metric.",
)

# Parse known args to get the kernel name, pass rest to tritonbench
args, tritonbench_args = parser.parse_known_args()
Expand Down Expand Up @@ -1729,6 +1807,9 @@ def main() -> None:

results: list[RunResult] = []

if args.measure_compile_time:
os.environ["HELION_MEASURE_COMPILE_TIME"] = "1"

collected_metrics: list[AutotuneMetrics] = []
if args.autotune_metrics or args.autotune_metrics_json:
register_post_autotune_hook(collected_metrics.append)
Expand Down Expand Up @@ -1759,6 +1840,7 @@ def main() -> None:
results,
active_kernel_mappings,
active_metric_mappings,
measure_compile_time=args.measure_compile_time,
)
else:
print(
Expand All @@ -1776,6 +1858,7 @@ def main() -> None:
results,
active_kernel_mappings,
active_metric_mappings,
measure_compile_time=args.measure_compile_time,
)
else:
# Run all kernels
Expand All @@ -1793,6 +1876,7 @@ def main() -> None:
results,
active_kernel_mappings,
active_metric_mappings,
measure_compile_time=args.measure_compile_time,
)

if args.output:
Expand Down
9 changes: 9 additions & 0 deletions helion/_compile_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ def _print_line(self, name: str, elapsed: float, total: float, indent: int) -> N
file=sys.stderr,
)

def get_total_time(self) -> float:
    """Get total top-level compile time in seconds.

    Sums the recorded timing for each phase name listed in
    ``self._TOP_LEVEL``; phases with no recorded entry contribute 0.0.

    Returns:
        Total elapsed compile time across all top-level phases, in seconds.
    """
    # NOTE(review): reads self._timings without holding self._timer_lock,
    # while reset() below mutates under the lock — confirm whether this can
    # race with a concurrent reset()/timing update, or whether callers are
    # expected to invoke it only after compilation has quiesced.
    return sum(self._timings.get(name, 0.0) for name in self._TOP_LEVEL)

def reset(self) -> None:
"""Reset all timing data."""
with self._timer_lock:
Expand Down Expand Up @@ -298,3 +302,8 @@ def print_report() -> None:
def reset() -> None:
    """Clear every timing record held by the process-wide tracker."""
    tracker = get_tracker()
    tracker.reset()


def get_total_time() -> float:
    """Return the process-wide tracker's total top-level compile time (seconds)."""
    tracker = get_tracker()
    return tracker.get_total_time()
Loading