Skip to content

Commit 661e68f

Browse files
committed
[Benchmark] Add compile time measurement to CI benchmarks
Add --measure-compile-time flag to the benchmark runner that captures per-shape Helion kernel compile time using the existing CompileTimeTracker infrastructure. Results are reported as helion_compile_time_s in the PyTorch HUD v3 JSON format, enabling continuous monitoring and regression detection on the dashboard.
1 parent b7b43b5 commit 661e68f

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

.github/workflows/benchmark.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ jobs:
194194
${{ inputs.env-vars }} HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 python benchmarks/run.py \
195195
--op $kernel \
196196
--metrics speedup,accuracy \
197+
--measure-compile-time \
197198
--latency-measure-mode triton_do_bench \
198199
--cudagraph \
199200
--only $IMPLS \

benchmarks/run.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from torch.utils._pytree import tree_map
4747

4848
from helion._compat import get_device_name
49+
from helion._compile_time import get_total_time as get_compile_total_time
50+
from helion._compile_time import reset as reset_compile_time
4951
from helion._testing import get_nvidia_gpu_model
5052
from helion._utils import counters
5153
from helion.autotuner.metrics import AutotuneMetrics
@@ -1023,6 +1025,7 @@ def run_kernel(
10231025
results: list[RunResult],
10241026
kernel_mappings: dict[str, tuple[str, ...]] | None = None,
10251027
kernel_metric_mappings: dict[str, dict[str, str]] | None = None,
1028+
measure_compile_time: bool = False,
10261029
) -> None:
10271030
"""Run a kernel benchmark, handling both single and multiple variants."""
10281031
# Use provided mappings or default to global mappings
@@ -1081,6 +1084,7 @@ def run_kernel(
10811084
operator_args,
10821085
results,
10831086
active_metrics,
1087+
measure_compile_time=measure_compile_time,
10841088
)
10851089

10861090

@@ -1093,6 +1097,7 @@ def run_kernel_variants(
10931097
operator_args: dict[str, Any] | None,
10941098
results: list[RunResult],
10951099
kernel_metric_mappings: dict[str, dict[str, str]] | None = None,
1100+
measure_compile_time: bool = False,
10961101
) -> None:
10971102
"""Run kernel variants in the same benchmark run."""
10981103

@@ -1183,6 +1188,9 @@ def run_kernel_variants(
11831188
# pyrefly: ignore [missing-import]
11841189
from tritonbench.utils.triton_op import register_benchmark
11851190

1191+
# Compile time tracking per variant
1192+
variant_compile_times: dict[str, list[float]] = {}
1193+
11861194
# Register all variants as separate methods
11871195
for module_path, func_name in variants:
11881196
# Import the kernel function
@@ -1207,6 +1215,7 @@ def run_kernel_variants(
12071215
def create_helion_method(
12081216
mod: Any, # noqa: ANN401
12091217
kfunc: Callable[..., Any],
1218+
compile_time_list: list[float] | None = None,
12101219
) -> Callable[..., Any]:
12111220
def helion_method(
12121221
self: object,
@@ -1245,6 +1254,28 @@ def helion_method(
12451254
measured_func_callable = kfunc(self, *args, **kwargs)
12461255

12471256
assert callable(measured_func_callable)
1257+
1258+
if compile_time_list is not None:
1259+
original = measured_func_callable
1260+
first_call = True
1261+
1262+
def timed_callable() -> object:
1263+
nonlocal first_call
1264+
if first_call:
1265+
first_call = False
1266+
torch.cuda.synchronize()
1267+
reset_compile_time()
1268+
try:
1269+
result = original()
1270+
except Exception:
1271+
compile_time_list.append(get_compile_total_time())
1272+
raise
1273+
compile_time_list.append(get_compile_total_time())
1274+
return result
1275+
return original()
1276+
1277+
return timed_callable
1278+
12481279
return measured_func_callable
12491280

12501281
return helion_method
@@ -1253,6 +1284,12 @@ def helion_method(
12531284
variant_name = func_name
12541285
helion_method_name = f"helion_{variant_name}"
12551286

1287+
# Set up compile time tracking for this variant
1288+
compile_times: list[float] | None = None
1289+
if measure_compile_time:
1290+
compile_times = []
1291+
variant_compile_times[func_name] = compile_times
1292+
12561293
# Use register_benchmark decorator
12571294
decorated_method = register_benchmark(
12581295
operator_name=operator_name,
@@ -1261,7 +1298,7 @@ def helion_method(
12611298
enabled=True,
12621299
fwd_only=False,
12631300
label=helion_method_name,
1264-
)(create_helion_method(module, kernel_func))
1301+
)(create_helion_method(module, kernel_func, compile_times))
12651302

12661303
# Set the decorated method on the Operator class
12671304
setattr(Operator, helion_method_name, decorated_method)
@@ -1350,6 +1387,40 @@ def accuracy_fail_hook(
13501387
except Exception:
13511388
logger.exception("failed to process results")
13521389

1390+
# Add compile time metrics (per-shape, same format as speedup)
1391+
if measure_compile_time and variant_compile_times:
1392+
# Get shapes from the most recent result for this kernel
1393+
kernel_results = [r for r in results if r.model == kernel_name]
1394+
shapes = kernel_results[-1].shape if kernel_results else []
1395+
device = get_device_name() or "unknown"
1396+
for func_name, times in variant_compile_times.items():
1397+
if not times:
1398+
continue
1399+
# Align compile times with shapes (both are in input order)
1400+
if len(times) != len(shapes):
1401+
logger.warning(
1402+
f"Compile time count ({len(times)}) != shape count "
1403+
f"({len(shapes)}) for {kernel_name}/{func_name}, skipping"
1404+
)
1405+
continue
1406+
metric_name = "helion_compile_time_s"
1407+
if len(variants) > 1:
1408+
metric_name = f"helion_{func_name}_compile_time_s"
1409+
results.append(
1410+
RunResult(
1411+
model=kernel_name,
1412+
device=device,
1413+
shape=shapes,
1414+
metrics={metric_name: times},
1415+
)
1416+
)
1417+
print(
1418+
f"Compile time for {kernel_name}/{func_name}: "
1419+
f"{', '.join(f'{t:.3f}s' for t in times)} "
1420+
f"({len(times)} shapes)",
1421+
file=sys.stderr,
1422+
)
1423+
13531424
# Force garbage collection multiple times to ensure memory is freed
13541425
for _ in range(3):
13551426
gc.collect()
@@ -1607,6 +1678,12 @@ def main() -> None:
16071678
help="Export autotune metrics to a JSON file at the given path. "
16081679
"Also set via HELION_AUTOTUNE_METRICS_JSON=<path>.",
16091680
)
1681+
parser.add_argument(
1682+
"--measure-compile-time",
1683+
action="store_true",
1684+
help="Measure and report Helion kernel compile time (seconds) for each input shape. "
1685+
"Results are included in JSON output as helion_compile_time_s metric.",
1686+
)
16101687

16111688
# Parse known args to get the kernel name, pass rest to tritonbench
16121689
args, tritonbench_args = parser.parse_known_args()
@@ -1729,6 +1806,9 @@ def main() -> None:
17291806

17301807
results: list[RunResult] = []
17311808

1809+
if args.measure_compile_time:
1810+
os.environ["HELION_MEASURE_COMPILE_TIME"] = "1"
1811+
17321812
collected_metrics: list[AutotuneMetrics] = []
17331813
if args.autotune_metrics or args.autotune_metrics_json:
17341814
register_post_autotune_hook(collected_metrics.append)
@@ -1759,6 +1839,7 @@ def main() -> None:
17591839
results,
17601840
active_kernel_mappings,
17611841
active_metric_mappings,
1842+
measure_compile_time=args.measure_compile_time,
17621843
)
17631844
else:
17641845
print(
@@ -1776,6 +1857,7 @@ def main() -> None:
17761857
results,
17771858
active_kernel_mappings,
17781859
active_metric_mappings,
1860+
measure_compile_time=args.measure_compile_time,
17791861
)
17801862
else:
17811863
# Run all kernels
@@ -1793,6 +1875,7 @@ def main() -> None:
17931875
results,
17941876
active_kernel_mappings,
17951877
active_metric_mappings,
1878+
measure_compile_time=args.measure_compile_time,
17961879
)
17971880

17981881
if args.output:

helion/_compile_time.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def _print_line(self, name: str, elapsed: float, total: float, indent: int) -> N
207207
file=sys.stderr,
208208
)
209209

210+
def get_total_time(self) -> float:
211+
"""Get total top-level compile time in seconds."""
212+
return sum(self._timings.get(name, 0.0) for name in self._TOP_LEVEL)
213+
210214
def reset(self) -> None:
211215
"""Reset all timing data."""
212216
with self._timer_lock:
@@ -298,3 +302,8 @@ def print_report() -> None:
298302
def reset() -> None:
299303
"""Reset all timing data."""
300304
get_tracker().reset()
305+
306+
307+
def get_total_time() -> float:
308+
"""Get total top-level compile time from the global tracker."""
309+
return get_tracker().get_total_time()

0 commit comments

Comments (0)