-
Notifications
You must be signed in to change notification settings - Fork 137
Expand file tree
/
Copy pathaot_cache.py
More file actions
1239 lines (1061 loc) · 46.2 KB
/
aot_cache.py
File metadata and controls
1239 lines (1061 loc) · 46.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
AOT (Ahead-of-Time) Autotuning Cache Implementation
====================================================
This module provides a cache implementation for AOT autotuning workflows that:
1. Collects tuned configs for each shape during benchmark runs
2. Measures all configs across all shapes
3. Generates heuristics using decision trees to select optimal configs
4. Supports multiple hardware architectures
The workflow is:
1. collect_tuned_configs: Tune each shape, record (kernel, shape, config) triples
2. measure_configs: Measure each shape with all observed configs
3. Generate heuristics to select configs based on performance goals
4. evaluate: Validate performance goals are achieved
"""
from __future__ import annotations
import csv
import dataclasses
from dataclasses import dataclass
import functools
import hashlib
import importlib
import importlib.util
import inspect
import json
import logging
import operator
import os
from pathlib import Path
import sys
import traceback
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Literal
import torch
from ..experimental.aot_kernel import _flatten_key_value
from ..experimental.aot_kernel import extract_key_features
from ..experimental.aot_kernel import extract_shape_features
from ..runtime.config import Config
from .base_cache import AutotuneCacheBase
from .base_cache import BoundKernelInMemoryCacheKey
from .base_cache import LooseAutotuneCacheKey
if TYPE_CHECKING:
from collections.abc import Sequence
from .base_search import BaseSearch
log: logging.Logger = logging.getLogger(__name__)
# Compute capability lists for fallback (newest to oldest)
_CUDA_COMPUTE_CAPS: list[str] = [
"sm100",
"sm90",
"sm89",
"sm87",
"sm86",
"sm80",
"sm75",
"sm72",
"sm70",
]
_ROCM_ARCHS: list[str] = [
"gfx950",
"gfx942",
"gfx941",
"gfx940",
"gfx90a",
"gfx908",
"gfx906",
"gfx900",
]
@dataclasses.dataclass(frozen=True)
class HardwareInfo:
"""
Hardware information for cache keys and heuristic file discovery.
Attributes:
device_kind: Device type ('cuda', 'rocm', 'xpu')
hardware_name: Device name (e.g., 'NVIDIA H100', 'gfx90a')
runtime_version: Runtime version (e.g., '12.4', 'gfx90a')
compute_capability: Compute capability for heuristics (e.g., 'sm90', 'gfx90a')
"""
device_kind: str
hardware_name: str
runtime_version: str
compute_capability: str
@property
def hardware_id(self) -> str:
"""Get a unique identifier string for this hardware."""
safe_name = self.hardware_name.replace(" ", "_")
return f"{self.device_kind}_{safe_name}_{self.runtime_version}"
def get_compatible_compute_ids(self) -> list[str]:
"""
Get a list of compatible compute IDs for fallback, ordered from current to oldest.
For CUDA/ROCm, returns the current compute capability followed by all older
compatible architectures. This allows using heuristics tuned on older hardware
when newer hardware-specific heuristics aren't available.
"""
if self.device_kind == "cuda":
arch_list = _CUDA_COMPUTE_CAPS
elif self.device_kind == "rocm":
arch_list = _ROCM_ARCHS
else:
return [self.compute_capability]
try:
current_idx = arch_list.index(self.compute_capability)
return arch_list[current_idx:]
except ValueError:
return [self.compute_capability, *arch_list]
@functools.cache
def get_hardware_info(device: torch.device | None = None) -> HardwareInfo:
    """
    Detect and return hardware information for the current or given device.

    Single source of truth for hardware detection, shared by the local cache
    and the AOT cache. Results are memoized per device via functools.cache.

    Args:
        device: Device to inspect. When None (or not an xpu/cuda device), the
            first CUDA device is probed if CUDA is available.

    Returns:
        HardwareInfo describing the device for caching and heuristic lookup.

    Raises:
        RuntimeError: if no CUDA, ROCm, or XPU device is available.
    """
    # Intel XPU path: only taken when an xpu device was explicitly requested
    # and the xpu backend is present and available.
    wants_xpu = device is not None and device.type == "xpu"
    if wants_xpu and getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
        props = torch.xpu.get_device_properties(device)
        return HardwareInfo(
            device_kind="xpu",
            hardware_name=props.name,
            runtime_version=props.driver_version,
            compute_capability=props.name,  # XPU doesn't have compute capability
        )
    # NVIDIA / AMD path: both report through the torch.cuda API.
    if torch.cuda.is_available():
        if device is not None and device.type == "cuda":
            dev = device
        else:
            dev = torch.device("cuda:0")
        props = torch.cuda.get_device_properties(dev)
        if torch.version.cuda is not None:
            return HardwareInfo(
                device_kind="cuda",
                hardware_name=props.name,
                runtime_version=str(torch.version.cuda),
                compute_capability=f"sm{props.major}{props.minor}",
            )
        if torch.version.hip is not None:
            return HardwareInfo(
                device_kind="rocm",
                hardware_name=props.gcnArchName,
                runtime_version=torch.version.hip,
                compute_capability=props.gcnArchName,
            )
    raise RuntimeError(
        "No supported GPU device found. Helion requires CUDA, ROCm, or XPU."
    )
# Environment variables controlling AOT behavior
AOT_MODE_ENV = "HELION_AOT_MODE"
AOT_DATA_DIR_ENV = "HELION_AOT_DATA_DIR"
# Overrides the heuristic search path (useful when comparing heuristics)
HEURISTIC_DIR_ENV = "HELION_HEURISTIC_DIR"
# Enables verbose output in evaluate mode (quiet by default)
AOT_VERBOSE_ENV = "HELION_AOT_VERBOSE"

AOTMode = Literal["collect", "measure", "evaluate", "compile", "disabled"]

# The closed set of accepted mode strings (mirrors the AOTMode Literal).
_VALID_AOT_MODES: tuple[str, ...] = (
    "collect",
    "measure",
    "evaluate",
    "compile",
    "disabled",
)


def get_aot_mode() -> AOTMode:
    """Read the current AOT mode from HELION_AOT_MODE (default: 'evaluate').

    Raises:
        ValueError: if the environment variable holds an unrecognized mode.
    """
    mode = os.environ.get(AOT_MODE_ENV, "evaluate").lower()
    if mode not in _VALID_AOT_MODES:
        raise ValueError(
            f"Invalid {AOT_MODE_ENV} value: {mode}. "
            "Must be one of: collect, measure, evaluate, compile, disabled"
        )
    return mode  # type: ignore[return-value]
def is_aot_verbose() -> bool:
    """Return True when HELION_AOT_VERBOSE requests verbose AOT output.

    Evaluate mode is quiet by default (heuristics are simply applied);
    set HELION_AOT_VERBOSE=1 to see what the AOT cache is doing.
    """
    flag = os.environ.get(AOT_VERBOSE_ENV, "").lower()
    return flag in ("1", "true", "yes")
def get_aot_data_dir() -> Path:
    """Return the AOT data directory (HELION_AOT_DATA_DIR, else ./.helion_aot)."""
    configured = os.environ.get(AOT_DATA_DIR_ENV)
    if configured is None:
        return Path.cwd() / ".helion_aot"
    return Path(configured)
# Cache for heuristic file lookups. Keyed by (source file, kernel name,
# data dir) so calls with different fallback parameters don't collide.
_heuristic_file_cache: dict[str, Path | None] = {}


def find_heuristic_file(
    kernel_source_file: str | Path,
    kernel_name: str | None = None,
    data_dir: Path | None = None,
) -> Path | None:
    """
    Find the heuristic file for a kernel.

    This is the single source of truth for heuristic file discovery, used by both
    AOTKeyFunction and AOTAutotuneCache.

    Search order:
    1. HELION_HEURISTIC_DIR env var (if set) - for comparing different heuristics
    2. Next to kernel source file: _<filename>_<device>_<compute>.py
    3. Fallback to older compute capabilities within the same device family
    4. AOT data directory: heuristic_<kernel_name>.py (fallback)

    Args:
        kernel_source_file: Path to the kernel's source file
        kernel_name: Optional kernel name for fallback lookup
        data_dir: Optional AOT data directory for fallback lookup

    Returns:
        Path to heuristic file if found, None otherwise
    """
    # BUGFIX: the memo key previously used only kernel_source_file, but the
    # search result also depends on kernel_name and data_dir. A lookup made
    # without kernel_name could cache a miss (None) that later lookups with
    # a kernel_name would incorrectly reuse. Include all inputs in the key.
    cache_key = f"{kernel_source_file}|{kernel_name}|{data_dir}"
    if cache_key in _heuristic_file_cache:
        return _heuristic_file_cache[cache_key]
    source_path = Path(kernel_source_file)
    base_name = source_path.stem
    hw = get_hardware_info()
    compatible_computes = hw.get_compatible_compute_ids()
    candidates: list[Path] = []
    # 1. Check HELION_HEURISTIC_DIR override
    if (heuristic_dir := os.environ.get(HEURISTIC_DIR_ENV)) is not None:
        heuristic_dir_path = Path(heuristic_dir)
        for compat_compute in compatible_computes:
            candidates.append(
                heuristic_dir_path
                / f"_helion_aot_{base_name}_{hw.device_kind}_{compat_compute}.py"
            )
        if kernel_name:
            candidates.append(heuristic_dir_path / f"heuristic_{kernel_name}.py")
    # 2. Check next to kernel source file with compute capability fallback
    for compat_compute in compatible_computes:
        heuristic_name = f"_helion_aot_{base_name}_{hw.device_kind}_{compat_compute}.py"
        candidates.append(source_path.parent / heuristic_name)
    # 3. Check AOT data directory (fallback)
    if data_dir is not None and kernel_name is not None:
        candidates.append(data_dir / f"heuristic_{kernel_name}.py")
    # Find first existing file; a miss (None) is cached as well.
    result: Path | None = None
    for candidate in candidates:
        if candidate.exists():
            log.debug(f"Found heuristic file: {candidate}")
            result = candidate
            break
    _heuristic_file_cache[cache_key] = result
    return result
def clear_heuristic_cache() -> None:
    """Clear the module-level heuristic file-lookup cache (useful for testing,
    or after heuristic files have been added/removed on disk)."""
    _heuristic_file_cache.clear()
def load_kernel_source_files(data_dir: Path, hardware_id: str) -> dict[str, str]:
    """
    Load kernel-name -> source-file mappings from the tuned configs JSON.

    Standalone helper used by aot_runner.py during heuristic generation.

    Args:
        data_dir: Directory containing the tuned configs file
        hardware_id: Hardware ID used in the filename

    Returns:
        Dict mapping kernel_name -> source_file_path; empty when the file is
        missing or unreadable.
    """
    configs_file = data_dir / f"tuned_configs_{hardware_id}.json"
    if not configs_file.exists():
        return {}
    try:
        mapping: dict[str, str] = {}
        data = json.loads(configs_file.read_text())
        for kernel_name, configs in data.items():
            # First entry with a recorded source file wins.
            source = next(
                (
                    cfg["kernel_source_file"]
                    for cfg in configs
                    if cfg.get("kernel_source_file")
                ),
                None,
            )
            if source is not None:
                mapping[kernel_name] = source
        return mapping
    except Exception as e:
        log.warning(f"Failed to load kernel source files: {e}")
        return {}
@dataclass
class ShapeKey:
    """Identifies one (kernel, specialization, hardware) combination."""

    kernel_name: str
    specialization_key: tuple[Any, ...]
    hardware_id: str

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable representation of this key."""
        return {
            "kernel_name": self.kernel_name,
            "specialization_key": _serialize_tuple(self.specialization_key),
            "hardware_id": self.hardware_id,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ShapeKey:
        """Reconstruct a ShapeKey from to_dict() output."""
        return cls(
            kernel_name=data["kernel_name"],
            specialization_key=_deserialize_tuple(data["specialization_key"]),
            hardware_id=data["hardware_id"],
        )

    def stable_hash(self) -> str:
        """Return a 16-hex-char hash, stable across processes (sorted-key JSON)."""
        payload = json.dumps(self.to_dict(), sort_keys=True).encode()
        return hashlib.sha256(payload).hexdigest()[:16]
def compute_tensor_hash(tensor: torch.Tensor) -> str:
    """Return the first 8 hex chars of the SHA256 of a tensor's raw bytes.

    The tensor is normalized first (made contiguous, moved to CPU, detached
    from autograd) so the hash reflects the values regardless of layout,
    device, or grad state. bfloat16 is widened to float32 because numpy
    cannot represent it.
    """
    t = tensor
    if not t.is_contiguous():
        t = t.contiguous()
    if t.device.type != "cpu":
        t = t.cpu()
    if t.requires_grad:
        t = t.detach()
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return hashlib.sha256(t.numpy().tobytes()).hexdigest()[:8]
@dataclass
class TunedConfig:
    """A tuned configuration with its benchmark results."""
    # The kernel Config that was tuned.
    config: Config
    # Identifies the (kernel, specialization, hardware) this config belongs to.
    shape_key: ShapeKey
    # Best observed timing in milliseconds; None until measured.
    timing_ms: float | None = None
    # Source file defining the kernel (from __code__.co_filename); used for
    # heuristic file lookup.
    kernel_source_file: str | None = None
    # Numeric shape features extracted for the heuristic/ML model.
    shape_features: dict[str, Any] | None = None
    # SHA256 hashes (first 8 chars) for correctness verification:
    # [0] = input tensor hashes before kernel runs
    # [1] = input tensor hashes after kernel runs (to detect in-place modifications)
    # [2] = output tensor hashes
    tensor_hashes: list[list[str]] | None = None
class AOTAutotuneCache(AutotuneCacheBase):
    """
    Cache implementation for AOT autotuning workflows.
    Behavior depends on the HELION_AOT_MODE environment variable:
    - collect: Tune each shape individually, record results
    - measure: Measure each shape with all observed configs
    - evaluate: Use heuristics to select configs, validate performance
    - disabled: Fall through to underlying autotuner (default)
    When collect_fn/measure_fn are set on the kernel:
    - collect_fn: In collect mode, only these inputs are autotuned
    - measure_fn: In measure mode, only these inputs are measured
    - One-shot: If both set in collect mode, runs both phases in one invocation
    """

    # Tracks which AOT modes have been announced to avoid repeated stderr messages.
    # Class-level so announcements happen only once per mode across all instances.
    _mode_announced: ClassVar[set[str]] = set()
    # Class-level caches for heuristic lookup (shared across instances).
    # Maps heuristic file path -> loaded module object.
    _heuristic_modules: ClassVar[dict[Path, Any]] = {}
    # Maps (kernel_source_file, kernel_name, shape_features_hash) -> Config.
    # Using source file ensures kernels with same name in different modules don't collide.
    _heuristic_results: ClassVar[dict[tuple[str, str, str], Config]] = {}
    # Tracks which kernels have shown the "no heuristic" warning (to avoid spam).
    _no_heuristic_warned: ClassVar[set[str]] = set()
    # Tracks which kernels have already been compiled in compile mode, so the
    # standalone Triton file is generated at most once per kernel.
    _compiled_kernels: ClassVar[set[str]] = set()
@classmethod
def clear_caches(cls) -> None:
"""Clear all class-level caches (heuristic modules and results)."""
cls._heuristic_modules.clear()
cls._heuristic_results.clear()
cls._no_heuristic_warned.clear()
cls._compiled_kernels.clear()
clear_heuristic_cache() # Clear module-level cache
cls._mode_announced.clear()
log.debug("Cleared AOTAutotuneCache caches")
    def __init__(self, autotuner: BaseSearch) -> None:
        """Initialize the AOT cache for one kernel invocation.

        Side effects: creates the data directory if missing, loads any
        previously saved tuned configs, and prints a one-time per-mode
        announcement to stderr (suppressed in evaluate/compile unless verbose).
        """
        super().__init__(autotuner)
        self.mode = get_aot_mode()
        self.hardware_id = get_hardware_info().hardware_id
        self.data_dir = get_aot_data_dir()
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Previously tuned configs keyed by kernel name, loaded from disk.
        self._tuned_configs: dict[str, list[TunedConfig]] = self._load_tuned_configs()
        # Identifies this invocation's (kernel, specialization, hardware).
        self.shape_key = self._create_shape_key()
        self._verbose = is_aot_verbose()
        # Look up optional collect_fn/measure_fn from the Kernel object
        # These are set by @aot_kernel() decorator
        self._collect_fn = getattr(self.kernel.kernel, "_aot_collect_fn", None)
        self._measure_fn = getattr(self.kernel.kernel, "_aot_measure_fn", None)
        # Announce mode once per mode type (quiet in evaluate mode unless verbose)
        should_announce = (
            self.mode != "disabled"
            and self.mode not in AOTAutotuneCache._mode_announced
            and (self.mode not in ("evaluate", "compile") or self._verbose)
        )
        if should_announce:
            print(
                f"[AOT] Mode: {self.mode}, Data dir: {self.data_dir}, "
                f"Hardware: {self.hardware_id}",
                file=sys.stderr,
            )
            num_configs = sum(len(v) for v in self._tuned_configs.values())
            if num_configs > 0:
                print(f"[AOT] Loaded {num_configs} existing configs", file=sys.stderr)
            AOTAutotuneCache._mode_announced.add(self.mode)
@property
def _configs_file(self) -> Path:
"""Path to the tuned configs JSON file."""
return self.data_dir / f"tuned_configs_{self.hardware_id}.json"
@property
def _measurements_file(self) -> Path:
"""Path to the measurements CSV file."""
return self.data_dir / f"measurements_{self.hardware_id}.csv"
def _load_tuned_configs(self) -> dict[str, list[TunedConfig]]:
"""Load tuned configs from disk."""
if not self._configs_file.exists():
return {}
try:
data = json.loads(self._configs_file.read_text())
result: dict[str, list[TunedConfig]] = {}
for kernel_name, configs in data.items():
result[kernel_name] = [
TunedConfig(
config=Config(**cfg["config"]),
shape_key=ShapeKey.from_dict(cfg["shape_key"]),
timing_ms=cfg.get("timing_ms"),
kernel_source_file=cfg.get("kernel_source_file"),
shape_features=cfg.get("shape_features"),
tensor_hashes=cfg.get("tensor_hashes"),
)
for cfg in configs
]
return result
except Exception as e:
log.warning(f"Failed to load tuned configs: {e}")
return {}
def _save_tuned_configs(self) -> None:
"""Save tuned configs to disk."""
data: dict[str, list[dict[str, Any]]] = {}
for kernel_name, config_list in self._tuned_configs.items():
data[kernel_name] = [
{
"config": dict(cfg.config),
"shape_key": cfg.shape_key.to_dict(),
"timing_ms": cfg.timing_ms,
"kernel_source_file": cfg.kernel_source_file,
"shape_features": cfg.shape_features,
"tensor_hashes": cfg.tensor_hashes,
}
for cfg in config_list
]
self._configs_file.write_text(json.dumps(data, indent=2))
def _add_tuned_config(
self,
kernel_name: str,
config: Config,
shape_key: ShapeKey,
timing_ms: float | None = None,
kernel_source_file: str | None = None,
shape_features: dict[str, Any] | None = None,
tensor_hashes: list[list[str]] | None = None,
) -> None:
"""Add a tuned config for a kernel/shape combination."""
if kernel_name not in self._tuned_configs:
self._tuned_configs[kernel_name] = []
shape_hash = shape_key.stable_hash()
config_dict = dict(config)
# Check if this exact config already exists for this shape
for existing in self._tuned_configs[kernel_name]:
if (
existing.shape_key.stable_hash() == shape_hash
and dict(existing.config) == config_dict
):
# Update if we have better timing
if timing_ms is not None:
if existing.timing_ms is None or timing_ms < existing.timing_ms:
existing.timing_ms = timing_ms
if kernel_source_file is not None:
existing.kernel_source_file = kernel_source_file
if shape_features is not None:
existing.shape_features = shape_features
if tensor_hashes is not None:
existing.tensor_hashes = tensor_hashes
return
self._tuned_configs[kernel_name].append(
TunedConfig(
config=config,
shape_key=shape_key,
timing_ms=timing_ms,
kernel_source_file=kernel_source_file,
shape_features=shape_features,
tensor_hashes=tensor_hashes,
)
)
def _get_all_configs_for_kernel(self, kernel_name: str) -> list[Config]:
"""Get all unique configs observed for a kernel."""
if kernel_name not in self._tuned_configs:
return []
seen: set[str] = set()
result: list[Config] = []
for tc in self._tuned_configs[kernel_name]:
config_hash = hashlib.sha256(
json.dumps(dict(tc.config), sort_keys=True).encode()
).hexdigest()
if config_hash not in seen:
seen.add(config_hash)
result.append(tc.config)
return result
def _save_measurement(
self,
kernel_name: str,
shape_key: ShapeKey,
config: Config,
timing_ms: float,
shape_features: dict[str, Any],
) -> None:
"""Save a measurement to CSV."""
config_hash = hashlib.sha256(
json.dumps(dict(config), sort_keys=True).encode()
).hexdigest()[:16]
row = {
"kernel_name": kernel_name,
"shape_hash": shape_key.stable_hash(),
"config_hash": config_hash,
"config": json.dumps(dict(config)),
"shape_features": json.dumps(shape_features),
"timing_ms": timing_ms,
}
file_exists = self._measurements_file.exists()
with open(self._measurements_file, "a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=row.keys())
if not file_exists:
writer.writeheader()
writer.writerow(row)
def _create_shape_key(self) -> ShapeKey:
"""Create a shape key for the current kernel invocation."""
return ShapeKey(
kernel_name=self.kernel.kernel.name,
specialization_key=self.kernel.kernel.specialization_key(self.args),
hardware_id=self.hardware_id,
)
def _extract_shape_features(
self, args: Sequence[object] | None = None
) -> dict[str, Any]:
"""Extract numeric features from the shape for ML model.
If a user key function is provided, extracts features from the
flattened key output instead of raw args.
"""
if args is None:
args = self.args
# Check if user provided a key function
user_key = getattr(self.kernel.kernel, "_aot_user_key", None)
if user_key is not None:
# Extract features from flattened key output
key_value = user_key(*args)
return extract_key_features(key_value)
# Use single source of truth from aot_kernel module
return extract_shape_features(args)
def get(self) -> Config | None:
"""Get a cached config based on current mode."""
if self.mode == "collect":
# In collect mode, check if we already have a config for this exact shape
kernel_name = self.kernel.kernel.name
configs = self._tuned_configs.get(kernel_name, [])
for tc in configs:
if tc.shape_key.stable_hash() == self.shape_key.stable_hash():
log.info(f"AOT collect: Using existing config for {kernel_name}")
return tc.config
return None # Need to tune
if self.mode == "measure":
# In measure mode, we don't use cache - we measure all configs
return None
if self.mode == "compile":
# In compile mode: use heuristic + generate standalone Triton code
self._maybe_run_compile()
# For disabled/evaluate/compile modes: try heuristic, fall back to default config
# (never trigger autotuning for aot_kernel)
config = self._get_heuristic_config()
if config is not None:
return config
# No heuristic available - use default config with warning (once per kernel)
kernel_name = self.kernel.kernel.name
from .. import exc
if kernel_name not in AOTAutotuneCache._no_heuristic_warned:
AOTAutotuneCache._no_heuristic_warned.add(kernel_name)
if exc.NoAOTHeuristicWarning not in self.autotuner.settings.ignore_warnings:
print(
f"[AOT] Warning: No heuristic found for '{kernel_name}'. "
f"Using default config. "
f"Use `python -m helion.experimental.aot_runner` to generate tuned configs.",
file=sys.stderr,
)
return self.autotuner.config_spec.default_config()
def _compute_tensor_hashes(
self, tensors: Sequence[object] | None = None
) -> list[str]:
"""Compute hashes for tensors. Non-tensors get "n/a"."""
if tensors is None:
tensors = self.args
return [
compute_tensor_hash(arg) if isinstance(arg, torch.Tensor) else "n/a"
for arg in tensors
]
    def put(self, config: Config, timing_ms: float | None = None) -> None:
        """Store a tuned config based on current mode.

        In collect mode this also runs the compiled kernel once to capture
        tensor hashes -- inputs before and after the run (to detect in-place
        mutation) plus outputs -- for later correctness verification, then
        persists everything to the configs JSON file. All other modes are
        no-ops or fall through silently.
        """
        if self.mode == "disabled":
            return
        if self.mode == "collect":
            kernel_name = self.kernel.kernel.name
            kernel_source_file = self.kernel.kernel.__code__.co_filename
            shape_features = self._extract_shape_features()
            # Hash inputs, run kernel, hash inputs again and outputs
            input_hashes = self._compute_tensor_hashes()
            fn = self.kernel.compile_config(config)
            outputs = fn(*self.args)
            input_after_hashes = self._compute_tensor_hashes()
            # Normalize the kernel result to a tuple so it can be hashed
            # uniformly regardless of return arity.
            if outputs is None:
                outputs = ()
            elif not isinstance(outputs, (tuple, list)):
                outputs = (outputs,)
            output_hashes = self._compute_tensor_hashes(outputs)
            tensor_hashes = [input_hashes, input_after_hashes, output_hashes]
            self._add_tuned_config(
                kernel_name=kernel_name,
                config=config,
                shape_key=self.shape_key,
                timing_ms=timing_ms,
                kernel_source_file=kernel_source_file,
                shape_features=shape_features,
                tensor_hashes=tensor_hashes,
            )
            # Persist immediately so a crash later doesn't lose this result.
            self._save_tuned_configs()
            print(
                f"[AOT collect] Saved config for kernel={kernel_name} "
                f"shape_hash={self.shape_key.stable_hash()[:8]} "
                f"hashes={tensor_hashes} "
                f"to {self._configs_file}",
                file=sys.stderr,
            )
            log.info(
                f"AOT collect: Saved config for {kernel_name} "
                f"shape={self.shape_key.stable_hash()}"
            )
    def measure_all_configs(self) -> list[tuple[Config, float]]:
        """
        Measure all known configs for the current shape.

        Benchmarks every config previously observed for this kernel (across
        all shapes), appending each successful measurement to the CSV file.
        Failures and infinite timings are reported to stderr but skipped.

        Returns list of (config, timing_ms) pairs for configs that succeeded.
        """
        self.autotuner._prepare()
        kernel_name = self.kernel.kernel.name
        all_configs = self._get_all_configs_for_kernel(kernel_name)
        if not all_configs:
            log.warning(f"No configs found for kernel {kernel_name}")
            return []
        print(
            f"[AOT measure] Testing {len(all_configs)} configs for {kernel_name} "
            f"shape_hash={self.shape_key.stable_hash()[:8]}",
            file=sys.stderr,
        )
        results: list[tuple[Config, float]] = []
        shape_features = self._extract_shape_features()
        # Temporarily disable subprocess precompile for direct benchmark calls
        old_precompile = self.autotuner.settings.autotune_precompile
        self.autotuner.settings.autotune_precompile = None
        # Set up provider resources if needed (normally done inside autotune())
        benchmark_provider = self.autotuner.benchmark_provider
        benchmark_provider.setup()
        try:
            for i, config in enumerate(all_configs):
                try:
                    # Benchmark this config
                    result = self.autotuner.benchmark(config)
                    timing = result.perf
                    if timing < float("inf"):
                        results.append((config, timing))
                        # Save measurement
                        self._save_measurement(
                            kernel_name=kernel_name,
                            shape_key=self.shape_key,
                            config=config,
                            timing_ms=timing,
                            shape_features=shape_features,
                        )
                        print(
                            f"[AOT measure] Config {i + 1}/{len(all_configs)}: {timing:.4f}ms",
                            file=sys.stderr,
                        )
                    else:
                        # inf timing signals an invalid/unrunnable config.
                        print(
                            f"[AOT measure] Config {i + 1}/{len(all_configs)}: invalid (inf timing)",
                            file=sys.stderr,
                        )
                except Exception as e:
                    error_msg = str(e) or type(e).__name__
                    tb = traceback.format_exc()
                    print(
                        f"[AOT measure] Config {i + 1}/{len(all_configs)}: failed - {error_msg}",
                        file=sys.stderr,
                    )
                    # Print last few lines of traceback for debugging
                    tb_lines = tb.strip().split("\n")
                    if len(tb_lines) > 4:
                        print(f" Traceback: ...{tb_lines[-3]}", file=sys.stderr)
                        print(f" {tb_lines[-2]}", file=sys.stderr)
                    log.debug(f"Failed to benchmark config {config}: {e}\n{tb}")
        finally:
            # Restore settings
            self.autotuner.settings.autotune_precompile = old_precompile
            benchmark_provider.cleanup()
        print(
            f"[AOT measure] Completed: {len(results)}/{len(all_configs)} configs succeeded",
            file=sys.stderr,
        )
        return results
def _find_heuristic_file(self) -> Path | None:
"""Find the heuristic file for this kernel using shared lookup."""
kernel_name = self.kernel.kernel.name
kernel_source_file = self.kernel.kernel.__code__.co_filename
return find_heuristic_file(
kernel_source_file,
kernel_name=kernel_name,
data_dir=self.data_dir,
)
    def _get_heuristic_config(
        self, args: Sequence[object] | None = None
    ) -> Config | None:
        """
        Use the heuristic to select a config.

        Looks for autotune_<kernel>(*args) function in the heuristic file.
        Results are memoized at class level per (source file, kernel, shape
        features hash); loaded heuristic modules are cached per file path.

        Args:
            args: Optional arguments to use. If None, uses self.args.

        For CUDA/ROCm, if heuristics for the current compute capability aren't found,
        we try older compatible architectures (e.g., sm80 heuristics on sm90 hardware).

        Returns:
            The selected Config, or None when no heuristic file exists, the
            module lacks the autotune function, or loading/calling it fails.
        """
        heuristic_file = self._find_heuristic_file()
        if heuristic_file is None:
            return None
        if args is None:
            args = self.args
        kernel_name = self.kernel.kernel.name
        kernel_source_file = self.kernel.kernel.__code__.co_filename
        # Compute cache key based on shape features
        shape_features = self._extract_shape_features(args)
        shape_hash = hashlib.sha256(
            json.dumps(shape_features, sort_keys=True).encode()
        ).hexdigest()[:16]
        # Check if we already have a cached result for this kernel+shape
        cache_key = (kernel_source_file, kernel_name, shape_hash)
        if cache_key in AOTAutotuneCache._heuristic_results:
            log.debug(
                f"Using cached heuristic result for {kernel_name} shape={shape_hash}"
            )
            return AOTAutotuneCache._heuristic_results[cache_key]
        try:
            # Load heuristic module from cache or import fresh
            if heuristic_file in AOTAutotuneCache._heuristic_modules:
                module = AOTAutotuneCache._heuristic_modules[heuristic_file]
            else:
                spec = importlib.util.spec_from_file_location(
                    "heuristic", heuristic_file
                )
                if spec is None or spec.loader is None:
                    return None
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                AOTAutotuneCache._heuristic_modules[heuristic_file] = module
                log.debug(f"Loaded heuristic module: {heuristic_file}")
            # Call autotune_<kernel>(*args) to get the config
            # If there's a user key, we need to pass flattened key values, not raw args
            config: Config | None = None
            autotune_fn = getattr(module, f"autotune_{kernel_name}", None)
            if autotune_fn is not None:
                user_key = getattr(self.kernel.kernel, "_aot_user_key", None)
                if user_key is not None:
                    # User key: pass flattened key values to heuristic
                    key_value = user_key(*args)
                    flat_key = _flatten_key_value(key_value)
                    config_dict = autotune_fn(*flat_key)
                else:
                    # No user key: pass raw args to heuristic
                    config_dict = autotune_fn(*args)
                config = Config(**config_dict)
            # Cache the result
            if config is not None:
                AOTAutotuneCache._heuristic_results[cache_key] = config
                log.debug(
                    f"Cached heuristic result for {kernel_name} shape={shape_hash}"
                )
            return config
        except Exception as e:
            # Best-effort: a broken heuristic falls back to the default config
            # path in get() rather than crashing the caller.
            log.warning(f"Failed to load heuristic from {heuristic_file}: {e}")
            return None
    def _maybe_run_compile(self) -> None:
        """
        In compile mode, generate Triton code for all heuristic-selected
        configs and write a standalone ``.py`` file with zero Helion deps.

        Runs at most once per kernel (tracked by ``_compiled_kernels``).
        A missing heuristic file or unextractable config list is logged and
        skipped; individual configs that fail to compile are replaced with a
        stub that raises at call time, so the standalone file stays importable.
        """
        kernel_name = self.kernel.kernel.name
        if kernel_name in AOTAutotuneCache._compiled_kernels:
            return
        # Mark as attempted up front so failures are not retried every call.
        AOTAutotuneCache._compiled_kernels.add(kernel_name)
        heuristic_file = self._find_heuristic_file()
        if heuristic_file is None:
            log.warning(
                "No heuristic for '%s', skipping standalone compile", kernel_name
            )
            return
        # -- load heuristic module ------------------------------------------
        if heuristic_file in AOTAutotuneCache._heuristic_modules:
            module = AOTAutotuneCache._heuristic_modules[heuristic_file]
        else:
            spec = importlib.util.spec_from_file_location("heuristic", heuristic_file)
            if spec is None or spec.loader is None:
                return
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            AOTAutotuneCache._heuristic_modules[heuristic_file] = module
        # -- extract selected configs ---------------------------------------
        # nearest_neighbor backend: module-level CONFIGS
        # decision_tree backend: _C = [...] inside autotune_<kernel>
        configs_list: list[dict[str, object]] | None = getattr(module, "CONFIGS", None)
        if configs_list is None:
            configs_list = self._parse_configs_from_autotune(module, kernel_name)
        if configs_list is None:
            log.warning("Cannot extract configs from heuristic for '%s'", kernel_name)
            return
        # -- generate Triton code for each config --------------------------
        triton_codes: list[str] = []
        for i, config_dict in enumerate(configs_list):
            config = Config(**config_dict)  # pyrefly: ignore [bad-argument-type]
            try:
                triton_codes.append(self.kernel.to_triton_code(config))
            except Exception:
                log.warning(
                    "Config %d failed to compile for '%s'",
                    i,
                    kernel_name,
                    exc_info=True,
                )
                # Placeholder keeps the emitted file valid; raises if selected.
                triton_codes.append(
                    f"def {kernel_name}(*args, **kwargs):\n"
                    f"    raise RuntimeError('Config {i} failed to compile')\n"
                )
        # -- emit standalone file -------------------------------------------
        from ..experimental.aot_compile import generate_standalone_file

        out_path = generate_standalone_file(
            kernel_name=kernel_name,
            triton_codes=triton_codes,
            heuristic_code=heuristic_file.read_text(),
            output_dir=self.data_dir,
            kernel_source_file=self.kernel.kernel.__code__.co_filename,
        )
        print(f"[AOT] Standalone: {out_path}", file=sys.stderr)
@staticmethod
def _parse_configs_from_autotune(
module: object, kernel_name: str
) -> list[dict[str, object]] | None:
"""Extract the ``_C`` config list from ``autotune_<kernel>``."""
autotune_fn = getattr(module, f"autotune_{kernel_name}", None)
if autotune_fn is None:
return None
try:
src = inspect.getsource(autotune_fn)
except OSError:
return None
start = src.find("_C = [")
if start < 0:
return None
start += len("_C = ")
depth = 0
end = start
for i in range(start, len(src)):
if src[i] == "[":
depth += 1
elif src[i] == "]":
depth -= 1
if depth == 0:
end = i + 1
break