Skip to content

Commit 32b136a

Browse files
committed
[Benchmark] Add compile time measurement to CI benchmarks
Add a --measure-compile-time flag to the benchmark runner that captures per-shape Helion kernel compile time using the existing CompileTimeTracker infrastructure. Results are reported in the PyTorch HUD v3 JSON format as helion_compile_time_s (or helion_<variant>_compile_time_s when a kernel has more than one variant), enabling continuous monitoring and regression detection on the dashboard.
1 parent b7b43b5 commit 32b136a

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

.github/workflows/benchmark.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ jobs:
194194
${{ inputs.env-vars }} HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 python benchmarks/run.py \
195195
--op $kernel \
196196
--metrics speedup,accuracy \
197+
--measure-compile-time \
197198
--latency-measure-mode triton_do_bench \
198199
--cudagraph \
199200
--only $IMPLS \

benchmarks/run.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from torch.utils._pytree import tree_map
4747

4848
from helion._compat import get_device_name
49+
from helion._compile_time import get_total_time as get_compile_total_time
50+
from helion._compile_time import reset as reset_compile_time
4951
from helion._testing import get_nvidia_gpu_model
5052
from helion._utils import counters
5153
from helion.autotuner.metrics import AutotuneMetrics
@@ -1023,6 +1025,7 @@ def run_kernel(
10231025
results: list[RunResult],
10241026
kernel_mappings: dict[str, tuple[str, ...]] | None = None,
10251027
kernel_metric_mappings: dict[str, dict[str, str]] | None = None,
1028+
measure_compile_time: bool = False,
10261029
) -> None:
10271030
"""Run a kernel benchmark, handling both single and multiple variants."""
10281031
# Use provided mappings or default to global mappings
@@ -1081,6 +1084,7 @@ def run_kernel(
10811084
operator_args,
10821085
results,
10831086
active_metrics,
1087+
measure_compile_time=measure_compile_time,
10841088
)
10851089

10861090

@@ -1093,6 +1097,7 @@ def run_kernel_variants(
10931097
operator_args: dict[str, Any] | None,
10941098
results: list[RunResult],
10951099
kernel_metric_mappings: dict[str, dict[str, str]] | None = None,
1100+
measure_compile_time: bool = False,
10961101
) -> None:
10971102
"""Run kernel variants in the same benchmark run."""
10981103

@@ -1183,6 +1188,9 @@ def run_kernel_variants(
11831188
# pyrefly: ignore [missing-import]
11841189
from tritonbench.utils.triton_op import register_benchmark
11851190

1191+
# Compile time tracking per variant
1192+
variant_compile_times: dict[str, list[float]] = {}
1193+
11861194
# Register all variants as separate methods
11871195
for module_path, func_name in variants:
11881196
# Import the kernel function
@@ -1207,6 +1215,7 @@ def run_kernel_variants(
12071215
def create_helion_method(
12081216
mod: Any, # noqa: ANN401
12091217
kfunc: Callable[..., Any],
1218+
compile_time_list: list[float] | None = None,
12101219
) -> Callable[..., Any]:
12111220
def helion_method(
12121221
self: object,
@@ -1245,6 +1254,29 @@ def helion_method(
12451254
measured_func_callable = kfunc(self, *args, **kwargs)
12461255

12471256
assert callable(measured_func_callable)
1257+
1258+
if compile_time_list is not None:
1259+
original = measured_func_callable
1260+
first_call = True
1261+
ct_list = compile_time_list
1262+
1263+
def timed_callable() -> object:
1264+
nonlocal first_call
1265+
if first_call:
1266+
first_call = False
1267+
torch.cuda.synchronize()
1268+
reset_compile_time()
1269+
try:
1270+
result = original()
1271+
except Exception:
1272+
ct_list.append(get_compile_total_time())
1273+
raise
1274+
ct_list.append(get_compile_total_time())
1275+
return result
1276+
return original()
1277+
1278+
return timed_callable
1279+
12481280
return measured_func_callable
12491281

12501282
return helion_method
@@ -1253,6 +1285,12 @@ def helion_method(
12531285
variant_name = func_name
12541286
helion_method_name = f"helion_{variant_name}"
12551287

1288+
# Set up compile time tracking for this variant
1289+
compile_times: list[float] | None = None
1290+
if measure_compile_time:
1291+
compile_times = []
1292+
variant_compile_times[func_name] = compile_times
1293+
12561294
# Use register_benchmark decorator
12571295
decorated_method = register_benchmark(
12581296
operator_name=operator_name,
@@ -1261,7 +1299,7 @@ def helion_method(
12611299
enabled=True,
12621300
fwd_only=False,
12631301
label=helion_method_name,
1264-
)(create_helion_method(module, kernel_func))
1302+
)(create_helion_method(module, kernel_func, compile_times))
12651303

12661304
# Set the decorated method on the Operator class
12671305
setattr(Operator, helion_method_name, decorated_method)
@@ -1350,6 +1388,40 @@ def accuracy_fail_hook(
13501388
except Exception:
13511389
logger.exception("failed to process results")
13521390

1391+
# Add compile time metrics (per-shape, same format as speedup)
1392+
if measure_compile_time and variant_compile_times:
1393+
# Get shapes from the most recent result for this kernel
1394+
kernel_results = [r for r in results if r.model == kernel_name]
1395+
shapes = kernel_results[-1].shape if kernel_results else []
1396+
device = get_device_name() or "unknown"
1397+
for func_name, times in variant_compile_times.items():
1398+
if not times:
1399+
continue
1400+
# Compile times and shapes are both in input order, so they pair up by index; skip this variant if the counts differ
1401+
if len(times) != len(shapes):
1402+
logger.warning(
1403+
f"Compile time count ({len(times)}) != shape count "
1404+
f"({len(shapes)}) for {kernel_name}/{func_name}, skipping"
1405+
)
1406+
continue
1407+
metric_name = "helion_compile_time_s"
1408+
if len(variants) > 1:
1409+
metric_name = f"helion_{func_name}_compile_time_s"
1410+
results.append(
1411+
RunResult(
1412+
model=kernel_name,
1413+
device=device,
1414+
shape=shapes,
1415+
metrics={metric_name: times},
1416+
)
1417+
)
1418+
print(
1419+
f"Compile time for {kernel_name}/{func_name}: "
1420+
f"{', '.join(f'{t:.3f}s' for t in times)} "
1421+
f"({len(times)} shapes)",
1422+
file=sys.stderr,
1423+
)
1424+
13531425
# Force garbage collection multiple times to ensure memory is freed
13541426
for _ in range(3):
13551427
gc.collect()
@@ -1607,6 +1679,12 @@ def main() -> None:
16071679
help="Export autotune metrics to a JSON file at the given path. "
16081680
"Also set via HELION_AUTOTUNE_METRICS_JSON=<path>.",
16091681
)
1682+
parser.add_argument(
1683+
"--measure-compile-time",
1684+
action="store_true",
1685+
help="Measure and report Helion kernel compile time (seconds) for each input shape. "
1686+
"Results are included in JSON output as helion_compile_time_s metric.",
1687+
)
16101688

16111689
# Parse known args to get the kernel name, pass rest to tritonbench
16121690
args, tritonbench_args = parser.parse_known_args()
@@ -1729,6 +1807,9 @@ def main() -> None:
17291807

17301808
results: list[RunResult] = []
17311809

1810+
if args.measure_compile_time:
1811+
os.environ["HELION_MEASURE_COMPILE_TIME"] = "1"
1812+
17321813
collected_metrics: list[AutotuneMetrics] = []
17331814
if args.autotune_metrics or args.autotune_metrics_json:
17341815
register_post_autotune_hook(collected_metrics.append)
@@ -1759,6 +1840,7 @@ def main() -> None:
17591840
results,
17601841
active_kernel_mappings,
17611842
active_metric_mappings,
1843+
measure_compile_time=args.measure_compile_time,
17621844
)
17631845
else:
17641846
print(
@@ -1776,6 +1858,7 @@ def main() -> None:
17761858
results,
17771859
active_kernel_mappings,
17781860
active_metric_mappings,
1861+
measure_compile_time=args.measure_compile_time,
17791862
)
17801863
else:
17811864
# Run all kernels
@@ -1793,6 +1876,7 @@ def main() -> None:
17931876
results,
17941877
active_kernel_mappings,
17951878
active_metric_mappings,
1879+
measure_compile_time=args.measure_compile_time,
17961880
)
17971881

17981882
if args.output:

helion/_compile_time.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def _print_line(self, name: str, elapsed: float, total: float, indent: int) -> N
207207
file=sys.stderr,
208208
)
209209

210+
def get_total_time(self) -> float:
211+
"""Get total top-level compile time in seconds."""
212+
return sum(self._timings.get(name, 0.0) for name in self._TOP_LEVEL)
213+
210214
def reset(self) -> None:
211215
"""Reset all timing data."""
212216
with self._timer_lock:
@@ -298,3 +302,8 @@ def print_report() -> None:
298302
def reset() -> None:
299303
"""Reset all timing data."""
300304
get_tracker().reset()
305+
306+
307+
def get_total_time() -> float:
308+
"""Get total top-level compile time from the global tracker."""
309+
return get_tracker().get_total_time()

0 commit comments

Comments
 (0)