4646from torch .utils ._pytree import tree_map
4747
4848from helion ._compat import get_device_name
49+ from helion ._compile_time import get_total_time as get_compile_total_time
50+ from helion ._compile_time import reset as reset_compile_time
4951from helion ._testing import get_nvidia_gpu_model
5052from helion ._utils import counters
5153from helion .autotuner .metrics import AutotuneMetrics
@@ -1023,6 +1025,7 @@ def run_kernel(
10231025 results : list [RunResult ],
10241026 kernel_mappings : dict [str , tuple [str , ...]] | None = None ,
10251027 kernel_metric_mappings : dict [str , dict [str , str ]] | None = None ,
1028+ measure_compile_time : bool = False ,
10261029) -> None :
10271030 """Run a kernel benchmark, handling both single and multiple variants."""
10281031 # Use provided mappings or default to global mappings
@@ -1081,6 +1084,7 @@ def run_kernel(
10811084 operator_args ,
10821085 results ,
10831086 active_metrics ,
1087+ measure_compile_time = measure_compile_time ,
10841088 )
10851089
10861090
@@ -1093,6 +1097,7 @@ def run_kernel_variants(
10931097 operator_args : dict [str , Any ] | None ,
10941098 results : list [RunResult ],
10951099 kernel_metric_mappings : dict [str , dict [str , str ]] | None = None ,
1100+ measure_compile_time : bool = False ,
10961101) -> None :
10971102 """Run kernel variants in the same benchmark run."""
10981103
@@ -1183,6 +1188,9 @@ def run_kernel_variants(
11831188 # pyrefly: ignore [missing-import]
11841189 from tritonbench .utils .triton_op import register_benchmark
11851190
1191+ # Compile time tracking per variant
1192+ variant_compile_times : dict [str , list [float ]] = {}
1193+
11861194 # Register all variants as separate methods
11871195 for module_path , func_name in variants :
11881196 # Import the kernel function
@@ -1207,6 +1215,7 @@ def run_kernel_variants(
12071215 def create_helion_method (
12081216 mod : Any , # noqa: ANN401
12091217 kfunc : Callable [..., Any ],
1218+ compile_time_list : list [float ] | None = None ,
12101219 ) -> Callable [..., Any ]:
12111220 def helion_method (
12121221 self : object ,
@@ -1245,6 +1254,29 @@ def helion_method(
12451254 measured_func_callable = kfunc (self , * args , ** kwargs )
12461255
12471256 assert callable (measured_func_callable )
1257+
1258+ if compile_time_list is not None :
1259+ original = measured_func_callable
1260+ first_call = True
1261+ ct_list = compile_time_list
1262+
1263+ def timed_callable () -> object :
1264+ nonlocal first_call
1265+ if first_call :
1266+ first_call = False
1267+ torch .cuda .synchronize ()
1268+ reset_compile_time ()
1269+ try :
1270+ result = original ()
1271+ except Exception :
1272+ ct_list .append (get_compile_total_time ())
1273+ raise
1274+ ct_list .append (get_compile_total_time ())
1275+ return result
1276+ return original ()
1277+
1278+ return timed_callable
1279+
12481280 return measured_func_callable
12491281
12501282 return helion_method
@@ -1253,6 +1285,12 @@ def helion_method(
12531285 variant_name = func_name
12541286 helion_method_name = f"helion_{ variant_name } "
12551287
1288+ # Set up compile time tracking for this variant
1289+ compile_times : list [float ] | None = None
1290+ if measure_compile_time :
1291+ compile_times = []
1292+ variant_compile_times [func_name ] = compile_times
1293+
12561294 # Use register_benchmark decorator
12571295 decorated_method = register_benchmark (
12581296 operator_name = operator_name ,
@@ -1261,7 +1299,7 @@ def helion_method(
12611299 enabled = True ,
12621300 fwd_only = False ,
12631301 label = helion_method_name ,
1264- )(create_helion_method (module , kernel_func ))
1302+ )(create_helion_method (module , kernel_func , compile_times ))
12651303
12661304 # Set the decorated method on the Operator class
12671305 setattr (Operator , helion_method_name , decorated_method )
@@ -1350,6 +1388,40 @@ def accuracy_fail_hook(
13501388 except Exception :
13511389 logger .exception ("failed to process results" )
13521390
1391+ # Add compile time metrics (per-shape, same format as speedup)
1392+ if measure_compile_time and variant_compile_times :
1393+ # Get shapes from the most recent result for this kernel
1394+ kernel_results = [r for r in results if r .model == kernel_name ]
1395+ shapes = kernel_results [- 1 ].shape if kernel_results else []
1396+ device = get_device_name () or "unknown"
1397+ for func_name , times in variant_compile_times .items ():
1398+ if not times :
1399+ continue
1400+ # Align compile times with shapes (both are in input order)
1401+ if len (times ) != len (shapes ):
1402+ logger .warning (
1403+ f"Compile time count ({ len (times )} ) != shape count "
1404+ f"({ len (shapes )} ) for { kernel_name } /{ func_name } , skipping"
1405+ )
1406+ continue
1407+ metric_name = "helion_compile_time_s"
1408+ if len (variants ) > 1 :
1409+ metric_name = f"helion_{ func_name } _compile_time_s"
1410+ results .append (
1411+ RunResult (
1412+ model = kernel_name ,
1413+ device = device ,
1414+ shape = shapes ,
1415+ metrics = {metric_name : times },
1416+ )
1417+ )
1418+ print (
1419+ f"Compile time for { kernel_name } /{ func_name } : "
1420+ f"{ ', ' .join (f'{ t :.3f} s' for t in times )} "
1421+ f"({ len (times )} shapes)" ,
1422+ file = sys .stderr ,
1423+ )
1424+
13531425 # Force garbage collection multiple times to ensure memory is freed
13541426 for _ in range (3 ):
13551427 gc .collect ()
@@ -1607,6 +1679,12 @@ def main() -> None:
16071679 help = "Export autotune metrics to a JSON file at the given path. "
16081680 "Also set via HELION_AUTOTUNE_METRICS_JSON=<path>." ,
16091681 )
1682+ parser .add_argument (
1683+ "--measure-compile-time" ,
1684+ action = "store_true" ,
1685+ help = "Measure and report Helion kernel compile time (seconds) for each input shape. "
1686+ "Results are included in JSON output as helion_compile_time_s metric." ,
1687+ )
16101688
16111689 # Parse known args to get the kernel name, pass rest to tritonbench
16121690 args , tritonbench_args = parser .parse_known_args ()
@@ -1729,6 +1807,9 @@ def main() -> None:
17291807
17301808 results : list [RunResult ] = []
17311809
1810+ if args .measure_compile_time :
1811+ os .environ ["HELION_MEASURE_COMPILE_TIME" ] = "1"
1812+
17321813 collected_metrics : list [AutotuneMetrics ] = []
17331814 if args .autotune_metrics or args .autotune_metrics_json :
17341815 register_post_autotune_hook (collected_metrics .append )
@@ -1759,6 +1840,7 @@ def main() -> None:
17591840 results ,
17601841 active_kernel_mappings ,
17611842 active_metric_mappings ,
1843+ measure_compile_time = args .measure_compile_time ,
17621844 )
17631845 else :
17641846 print (
@@ -1776,6 +1858,7 @@ def main() -> None:
17761858 results ,
17771859 active_kernel_mappings ,
17781860 active_metric_mappings ,
1861+ measure_compile_time = args .measure_compile_time ,
17791862 )
17801863 else :
17811864 # Run all kernels
@@ -1793,6 +1876,7 @@ def main() -> None:
17931876 results ,
17941877 active_kernel_mappings ,
17951878 active_metric_mappings ,
1879+ measure_compile_time = args .measure_compile_time ,
17961880 )
17971881
17981882 if args .output :
0 commit comments