4646from torch .utils ._pytree import tree_map
4747
4848from helion ._compat import get_device_name
49+ from helion ._compile_time import get_total_time as get_compile_total_time
50+ from helion ._compile_time import reset as reset_compile_time
4951from helion ._testing import get_nvidia_gpu_model
5052from helion ._utils import counters
5153from helion .autotuner .metrics import AutotuneMetrics
@@ -1023,6 +1025,7 @@ def run_kernel(
10231025 results : list [RunResult ],
10241026 kernel_mappings : dict [str , tuple [str , ...]] | None = None ,
10251027 kernel_metric_mappings : dict [str , dict [str , str ]] | None = None ,
1028+ measure_compile_time : bool = False ,
10261029) -> None :
10271030 """Run a kernel benchmark, handling both single and multiple variants."""
10281031 # Use provided mappings or default to global mappings
@@ -1081,6 +1084,7 @@ def run_kernel(
10811084 operator_args ,
10821085 results ,
10831086 active_metrics ,
1087+ measure_compile_time = measure_compile_time ,
10841088 )
10851089
10861090
@@ -1093,6 +1097,7 @@ def run_kernel_variants(
10931097 operator_args : dict [str , Any ] | None ,
10941098 results : list [RunResult ],
10951099 kernel_metric_mappings : dict [str , dict [str , str ]] | None = None ,
1100+ measure_compile_time : bool = False ,
10961101) -> None :
10971102 """Run kernel variants in the same benchmark run."""
10981103
@@ -1183,6 +1188,9 @@ def run_kernel_variants(
11831188 # pyrefly: ignore [missing-import]
11841189 from tritonbench .utils .triton_op import register_benchmark
11851190
1191+ # Compile time tracking per variant
1192+ variant_compile_times : dict [str , list [float ]] = {}
1193+
11861194 # Register all variants as separate methods
11871195 for module_path , func_name in variants :
11881196 # Import the kernel function
@@ -1207,6 +1215,7 @@ def run_kernel_variants(
12071215 def create_helion_method (
12081216 mod : Any , # noqa: ANN401
12091217 kfunc : Callable [..., Any ],
1218+ compile_time_list : list [float ] | None = None ,
12101219 ) -> Callable [..., Any ]:
12111220 def helion_method (
12121221 self : object ,
@@ -1245,6 +1254,28 @@ def helion_method(
12451254 measured_func_callable = kfunc (self , * args , ** kwargs )
12461255
12471256 assert callable (measured_func_callable )
1257+
1258+ if compile_time_list is not None :
1259+ original = measured_func_callable
1260+ first_call = True
1261+
1262+ def timed_callable () -> object :
1263+ nonlocal first_call
1264+ if first_call :
1265+ first_call = False
1266+ torch .cuda .synchronize ()
1267+ reset_compile_time ()
1268+ try :
1269+ result = original ()
1270+ except Exception :
1271+ compile_time_list .append (get_compile_total_time ())
1272+ raise
1273+ compile_time_list .append (get_compile_total_time ())
1274+ return result
1275+ return original ()
1276+
1277+ return timed_callable
1278+
12481279 return measured_func_callable
12491280
12501281 return helion_method
@@ -1253,6 +1284,12 @@ def helion_method(
12531284 variant_name = func_name
12541285 helion_method_name = f"helion_{ variant_name } "
12551286
1287+ # Set up compile time tracking for this variant
1288+ compile_times : list [float ] | None = None
1289+ if measure_compile_time :
1290+ compile_times = []
1291+ variant_compile_times [func_name ] = compile_times
1292+
12561293 # Use register_benchmark decorator
12571294 decorated_method = register_benchmark (
12581295 operator_name = operator_name ,
@@ -1261,7 +1298,7 @@ def helion_method(
12611298 enabled = True ,
12621299 fwd_only = False ,
12631300 label = helion_method_name ,
1264- )(create_helion_method (module , kernel_func ))
1301+ )(create_helion_method (module , kernel_func , compile_times ))
12651302
12661303 # Set the decorated method on the Operator class
12671304 setattr (Operator , helion_method_name , decorated_method )
@@ -1350,6 +1387,40 @@ def accuracy_fail_hook(
13501387 except Exception :
13511388 logger .exception ("failed to process results" )
13521389
1390+ # Add compile time metrics (per-shape, same format as speedup)
1391+ if measure_compile_time and variant_compile_times :
1392+ # Get shapes from the most recent result for this kernel
1393+ kernel_results = [r for r in results if r .model == kernel_name ]
1394+ shapes = kernel_results [- 1 ].shape if kernel_results else []
1395+ device = get_device_name () or "unknown"
1396+ for func_name , times in variant_compile_times .items ():
1397+ if not times :
1398+ continue
1399+ # Align compile times with shapes (both are in input order)
1400+ if len (times ) != len (shapes ):
1401+ logger .warning (
1402+ f"Compile time count ({ len (times )} ) != shape count "
1403+ f"({ len (shapes )} ) for { kernel_name } /{ func_name } , skipping"
1404+ )
1405+ continue
1406+ metric_name = "helion_compile_time_s"
1407+ if len (variants ) > 1 :
1408+ metric_name = f"helion_{ func_name } _compile_time_s"
1409+ results .append (
1410+ RunResult (
1411+ model = kernel_name ,
1412+ device = device ,
1413+ shape = shapes ,
1414+ metrics = {metric_name : times },
1415+ )
1416+ )
1417+ print (
1418+ f"Compile time for { kernel_name } /{ func_name } : "
1419+ f"{ ', ' .join (f'{ t :.3f} s' for t in times )} "
1420+ f"({ len (times )} shapes)" ,
1421+ file = sys .stderr ,
1422+ )
1423+
13531424 # Force garbage collection multiple times to ensure memory is freed
13541425 for _ in range (3 ):
13551426 gc .collect ()
@@ -1607,6 +1678,12 @@ def main() -> None:
16071678 help = "Export autotune metrics to a JSON file at the given path. "
16081679 "Also set via HELION_AUTOTUNE_METRICS_JSON=<path>." ,
16091680 )
1681+ parser .add_argument (
1682+ "--measure-compile-time" ,
1683+ action = "store_true" ,
1684+ help = "Measure and report Helion kernel compile time (seconds) for each input shape. "
1685+ "Results are included in JSON output as helion_compile_time_s metric." ,
1686+ )
16101687
16111688 # Parse known args to get the kernel name, pass rest to tritonbench
16121689 args , tritonbench_args = parser .parse_known_args ()
@@ -1729,6 +1806,9 @@ def main() -> None:
17291806
17301807 results : list [RunResult ] = []
17311808
1809+ if args .measure_compile_time :
1810+ os .environ ["HELION_MEASURE_COMPILE_TIME" ] = "1"
1811+
17321812 collected_metrics : list [AutotuneMetrics ] = []
17331813 if args .autotune_metrics or args .autotune_metrics_json :
17341814 register_post_autotune_hook (collected_metrics .append )
@@ -1759,6 +1839,7 @@ def main() -> None:
17591839 results ,
17601840 active_kernel_mappings ,
17611841 active_metric_mappings ,
1842+ measure_compile_time = args .measure_compile_time ,
17621843 )
17631844 else :
17641845 print (
@@ -1776,6 +1857,7 @@ def main() -> None:
17761857 results ,
17771858 active_kernel_mappings ,
17781859 active_metric_mappings ,
1860+ measure_compile_time = args .measure_compile_time ,
17791861 )
17801862 else :
17811863 # Run all kernels
@@ -1793,6 +1875,7 @@ def main() -> None:
17931875 results ,
17941876 active_kernel_mappings ,
17951877 active_metric_mappings ,
1878+ measure_compile_time = args .measure_compile_time ,
17961879 )
17971880
17981881 if args .output :
0 commit comments