-
Notifications
You must be signed in to change notification settings - Fork 137
Expand file tree
/
Copy pathtest_best_available.py
More file actions
1064 lines (898 loc) · 43.2 KB
/
test_best_available.py
File metadata and controls
1064 lines (898 loc) · 43.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import hashlib
import json
import os
from pathlib import Path
import tempfile
import time
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch
import torch
from helion._compiler.backend import TileIRBackend
from helion._compiler.backend import TritonBackend
from helion._testing import DEVICE
from helion.autotuner.base_cache import LooseAutotuneCacheKey
from helion.autotuner.base_search import PopulationBasedSearch
from helion.autotuner.base_search import _normalize_spec_key_str
from helion.autotuner.config_fragment import Category
from helion.autotuner.config_generation import ConfigGeneration
from helion.autotuner.config_spec import BlockSizeSpec
from helion.autotuner.config_spec import ConfigSpec
from helion.autotuner.config_spec import FlattenLoopSpec
from helion.autotuner.config_spec import LoopOrderSpec
from helion.autotuner.config_spec import RangeUnrollFactorSpec
from helion.autotuner.config_spec import ReductionLoopSpec
from helion.autotuner.local_cache import LocalAutotuneCache
from helion.autotuner.local_cache import SavedBestConfig
from helion.autotuner.local_cache import iter_cache_entries
from helion.autotuner.pattern_search import InitialPopulationStrategy
from helion.runtime.config import Config
from helion.runtime.settings import Settings
from helion.runtime.settings import _get_initial_population_strategy
class TestBestAvailable(unittest.TestCase):
    """Exercise the ``from_best_available`` autotuner warm-start plumbing."""

    def test_initial_population_strategy_enum(self):
        """FROM_BEST_AVAILABLE carries the string value ``from_best_available``."""
        member = InitialPopulationStrategy.FROM_BEST_AVAILABLE
        self.assertEqual(member.value, "from_best_available")

    def test_get_initial_population_strategy_best_available(self):
        """HELION_AUTOTUNER_INITIAL_POPULATION=from_best_available is honored."""
        env = {"HELION_AUTOTUNER_INITIAL_POPULATION": "from_best_available"}
        with patch.dict(os.environ, env):
            resolved = _get_initial_population_strategy("from_random")
            self.assertEqual(resolved, InitialPopulationStrategy.FROM_BEST_AVAILABLE)

    def test_get_initial_population_strategy_invalid(self):
        """An unknown env value raises ValueError naming from_best_available."""
        env = {"HELION_AUTOTUNER_INITIAL_POPULATION": "invalid_value"}
        with patch.dict(os.environ, env), self.assertRaises(ValueError) as raised:
            _get_initial_population_strategy("from_random")
        # Checked after the assertRaises block so the assertion really executes.
        self.assertIn("from_best_available", str(raised.exception))

    def test_best_available_max_configs_default(self):
        """autotune_best_available_max_configs defaults to 20."""
        self.assertEqual(Settings().autotune_best_available_max_configs, 20)

    def test_best_available_max_cache_scan_default(self):
        """autotune_best_available_max_cache_scan defaults to 500."""
        self.assertEqual(Settings().autotune_best_available_max_cache_scan, 500)

    def test_cache_entry_to_mutable_flat_config(self):
        """SavedBestConfig.to_mutable_flat_config yields a list copy of flat_config."""
        spec = ConfigSpec(backend=TritonBackend())
        for block_id, hint in ((0, 64), (1, 128)):
            spec.block_sizes.append(
                BlockSizeSpec(
                    block_id=block_id, size_hint=hint, min_size=16, max_size=256
                )
            )
        generation = ConfigGeneration(spec)

        best = Config(block_sizes=[32, 64], num_warps=8, num_stages=3)
        frozen_flat = tuple(generation.flatten(best))
        saved = SavedBestConfig(
            hardware="test",
            specialization_key="test",
            config=best,
            config_spec_hash="",
            flat_config=frozen_flat,
        )

        mutable = saved.to_mutable_flat_config()
        self.assertEqual(mutable, list(frozen_flat))
        # The two block sizes occupy the first two flat slots.
        self.assertEqual(mutable[0], 32)
        self.assertEqual(mutable[1], 64)
        # Scalar keys are located through the generation's key -> indices map.
        warps_at = generation._key_to_flat_indices["num_warps"][0][0]
        self.assertEqual(mutable[warps_at], 8)
        stages_at = generation._key_to_flat_indices["num_stages"][0][0]
        self.assertEqual(mutable[stages_at], 3)

    def test_key_to_flat_indices_mapping(self):
        """_key_to_flat_indices lists in-range indices and sequence flags per key."""
        spec = ConfigSpec(backend=TritonBackend())
        for block_id, hint in ((0, 64), (1, 128)):
            spec.block_sizes.append(
                BlockSizeSpec(
                    block_id=block_id, size_hint=hint, min_size=16, max_size=256
                )
            )
        for block_id in (0, 1):
            spec.flatten_loops.append(FlattenLoopSpec([block_id]))
        generation = ConfigGeneration(spec)
        layout = generation._key_to_flat_indices

        for expected_key in ("block_sizes", "num_warps", "num_stages", "flatten_loops"):
            self.assertIn(expected_key, layout)

        # Two block sizes, one num_warps slot, two flatten_loops entries.
        self.assertEqual(len(layout["block_sizes"][0]), 2)
        self.assertEqual(len(layout["num_warps"][0]), 1)
        self.assertEqual(len(layout["flatten_loops"][0]), 2)

        # The second tuple element is the is_sequence flag:
        # True for BlockIdSequence fields, False for scalars.
        self.assertTrue(layout["block_sizes"][1])
        self.assertTrue(layout["flatten_loops"][1])
        self.assertFalse(layout["num_warps"][1])
        self.assertFalse(layout["num_stages"][1])

        flat_len = len(generation.flat_spec)
        for key, (indices, _is_sequence) in layout.items():
            for idx in indices:
                self.assertGreaterEqual(idx, 0, f"Key {key} has negative index")
                self.assertLess(idx, flat_len, f"Key {key} index out of bounds")

        block_size_positions = (
            position
            for position, fragment in enumerate(generation.flat_spec)
            if fragment.category() == Category.BLOCK_SIZE
        )
        first_block_size_at = next(block_size_positions)
        self.assertEqual(layout["block_sizes"][0][0], first_block_size_at)

        warp_positions = [
            position
            for position, fragment in enumerate(generation.flat_spec)
            if fragment.category() == Category.NUM_WARPS
        ]
        self.assertEqual(layout["num_warps"][0][0], warp_positions[-1])

    def test_flatten_unflatten_roundtrip(self):
        """flatten(unflatten(flat)) reproduces the default flat config."""
        spec = ConfigSpec(backend=TritonBackend())
        for block_id, hint in ((0, 64), (1, 128)):
            spec.block_sizes.append(
                BlockSizeSpec(
                    block_id=block_id, size_hint=hint, min_size=16, max_size=256
                )
            )
        spec.loop_orders.append(LoopOrderSpec([0, 1]))
        spec.flatten_loops.append(FlattenLoopSpec([0]))
        generation = ConfigGeneration(spec)

        baseline = generation.default_flat()
        rebuilt = generation.unflatten(baseline)
        again = generation.flatten(rebuilt)
        self.assertEqual(
            again,
            baseline,
            "flatten(unflatten(default_flat)) != default_flat",
        )

    def test_flatten_with_dropped_keys(self):
        """Regression: a key dropped by normalize() must not shift flat indices.

        For pid_type="flat", normalize() removes num_sm_multiplier from the
        config while flat_spec keeps a slot for it. flatten() must still
        return one value per flat_spec entry and round-trip through
        unflatten().
        """
        spec = ConfigSpec(backend=TritonBackend())
        spec.block_sizes.append(
            BlockSizeSpec(block_id=0, size_hint=64, min_size=16, max_size=256)
        )
        generation = ConfigGeneration(spec)

        # pid_type="flat" is the case in which normalize drops num_sm_multiplier.
        flat_pid_config = Config(
            block_sizes=[64],
            num_warps=4,
            num_stages=2,
            pid_type="flat",
        )
        spec.normalize(flat_pid_config.config)
        self.assertNotIn("num_sm_multiplier", flat_pid_config.config)

        # flatten must neither crash nor produce a mis-sized vector.
        flattened = generation.flatten(flat_pid_config)
        self.assertEqual(len(flattened), len(generation.flat_spec))

        # The flat vector must unflatten back into a usable config.
        recovered = generation.unflatten(flattened)
        self.assertIn("block_sizes", recovered.config)
        self.assertEqual(recovered.config["block_sizes"], [64])

    def test_flatten_with_empty_list_keys(self):
        """Regression: flatten() tolerates empty-list keys absent from the layout.

        normalize() may write ``config["range_warp_specializes"] = []`` back
        into the config even though the flat key layout has no entry for it.
        flatten() has to skip such a key instead of crashing.
        """
        spec = ConfigSpec(backend=TritonBackend())
        spec.block_sizes.append(
            BlockSizeSpec(block_id=0, size_hint=64, min_size=16, max_size=256)
        )
        generation = ConfigGeneration(spec)

        baseline_config = spec.default_config()
        # Inject the spurious empty list by hand, mimicking normalize().
        baseline_config.config["range_warp_specializes"] = []

        flattened = generation.flatten(baseline_config)
        self.assertEqual(len(flattened), len(generation.flat_spec))

    @patch("helion.autotuner.config_spec.supports_maxnreg", return_value=False)
    def test_flatten_unflatten_with_tileir_no_duplicate_keys(self, _mock_maxnreg):
        """With the TileIR backend each layout key appears once and round-trips.

        TileIR replaces the standard num_warps/num_stages fragments, so
        flat_key_layout() must not list either key twice.
        """
        spec = ConfigSpec(backend=TileIRBackend())
        spec.block_sizes.append(
            BlockSizeSpec(block_id=0, size_hint=64, min_size=16, max_size=256)
        )
        generation = ConfigGeneration(spec)

        key_names = [name for name, *_rest in spec.flat_key_layout()]
        for unique_key in ("num_warps", "num_stages"):
            self.assertEqual(key_names.count(unique_key), 1)

        # default_flat -> unflatten -> flatten keeps the vector length stable.
        baseline = generation.default_flat()
        rebuilt = generation.unflatten(baseline)
        again = generation.flatten(rebuilt)
        self.assertEqual(len(again), len(baseline))

        # TileIR-specific keys must show up in the unflattened config.
        for tileir_key in ("num_ctas", "occupancy"):
            self.assertIn(tileir_key, rebuilt.config)

    def test_flatten_multiple_reduction_loops(self):
        """Several reduction loops each get their own flat slot and round-trip.

        ReductionLoopSpec overrides _flat_config() with custom logic, hence
        the dedicated check.
        """
        spec = ConfigSpec(backend=TritonBackend())
        spec.block_sizes.append(
            BlockSizeSpec(block_id=0, size_hint=64, min_size=16, max_size=256)
        )
        for block_id, hint in ((1, 128), (2, 64)):
            spec.reduction_loops.append(
                ReductionLoopSpec(block_id=block_id, size_hint=hint)
            )
        generation = ConfigGeneration(spec)

        # Start from the default config and force explicit reduction sizes.
        tuned = spec.default_config()
        tuned.config["reduction_loops"] = [32, 16]
        spec.normalize(tuned.config)

        flattened = generation.flatten(tuned)
        self.assertEqual(len(flattened), len(generation.flat_spec))

        slots, is_sequence = generation._key_to_flat_indices["reduction_loops"]
        self.assertTrue(is_sequence)
        self.assertEqual(len(slots), 2)
        for slot, expected in zip(slots, (32, 16)):
            self.assertEqual(flattened[slot], expected)

        # unflatten followed by flatten must reproduce the same vector.
        reflattened = generation.flatten(generation.unflatten(flattened))
        self.assertEqual(reflattened, flattened)
class TestCacheMatching(unittest.TestCase):
    """Warm-start lookup tests: which cached ``.best_config`` files get matched."""

    @staticmethod
    def _hash_fingerprint(fingerprint: tuple) -> str:
        """Return the SHA-256 hex digest of ``repr(fingerprint)``."""
        return hashlib.sha256(repr(fingerprint).encode("utf-8")).hexdigest()

    @staticmethod
    def _new_mock_search(fp_hash: str) -> MagicMock:
        """Build the MagicMock search object shared by every test here.

        Only the attributes common to all tests are configured; each test
        installs its own ``_get_current_hardware_and_specialization``.
        """
        search = MagicMock()
        search._skip_cache = False
        search.settings = MagicMock()
        search.settings.autotune_best_available_max_cache_scan = 500
        search.config_spec = MagicMock()
        search.config_spec.structural_fingerprint_hash = MagicMock(
            return_value=fp_hash
        )
        return search

    @staticmethod
    def _cache_dir_patch(cache_dir: str):
        """Return a patcher making get_helion_cache_dir() return *cache_dir*."""
        return patch(
            "helion.autotuner.local_cache.get_helion_cache_dir",
            return_value=Path(cache_dir),
        )

    def _write_best_config(
        self,
        cache_dir: str,
        filename: str,
        hardware: str,
        spec_key: str,
        source_hash: str,
        config_dict: dict,
        mtime_offset: float = 0,
        config_spec_hash: str = "",
        flat_config: list[object] | None = None,
    ) -> None:
        """Write a fake ``.best_config`` JSON file into *cache_dir*.

        The config (and flat_config, when given) are stored as JSON strings
        nested inside the outer JSON document. A non-zero *mtime_offset*
        shifts the file's access/modification time by that many seconds
        relative to now.
        """
        key_fields = {
            "hardware": hardware,
            "specialization_key": spec_key,
            "kernel_source_hash": source_hash,
            "config_spec_hash": config_spec_hash,
        }
        payload: dict[str, object] = {
            "key": {"fields": key_fields},
            "config": json.dumps(config_dict),
        }
        if flat_config is not None:
            payload["flat_config"] = json.dumps(flat_config)

        target = os.path.join(cache_dir, filename)
        with open(target, "w") as handle:
            handle.write(json.dumps(payload))

        if mtime_offset == 0:
            return
        stamp = time.time() + mtime_offset
        os.utime(target, (stamp, stamp))

    def test_find_similar_cached_configs_end_to_end(self):
        """Only same-hardware, same-spec entries come back, newest first."""
        fp_hash = self._hash_fingerprint((("block_sizes", 2, 1, 1),))
        rtx_4090 = "NVIDIA GeForce RTX 4090"
        wanted_spec = "('tensor_spec',)"
        # filename, hardware, spec_key, source_hash, block_sizes, num_warps, mtime
        fixtures = (
            ("match1.best_config", rtx_4090, wanted_spec, "hash1", [64, 128], 4, -10),
            ("match2.best_config", rtx_4090, wanted_spec, "hash2", [32, 64], 8, 0),
            ("diff_hw.best_config", "NVIDIA A100", wanted_spec, "hash3", [128, 256], 4, 0),
            (
                "diff_spec.best_config",
                rtx_4090,
                "('different_spec',)",
                "hash4",
                [16, 32],
                2,
                0,
            ),
        )

        with tempfile.TemporaryDirectory() as cache_dir:
            for name, hw, spec, src_hash, blocks, warps, offset in fixtures:
                self._write_best_config(
                    cache_dir,
                    name,
                    hardware=hw,
                    spec_key=spec,
                    source_hash=src_hash,
                    config_dict={"block_sizes": blocks, "num_warps": warps},
                    mtime_offset=offset,
                    config_spec_hash=fp_hash,
                    flat_config=[*blocks, warps],
                )

            search = self._new_mock_search(fp_hash)
            search._get_current_hardware_and_specialization = MagicMock(
                return_value=(rtx_4090, wanted_spec)
            )

            with self._cache_dir_patch(cache_dir):
                found = PopulationBasedSearch._find_similar_cached_configs(
                    search, max_configs=10
                )

            self.assertEqual(len(found), 2)
            # match2 carries the later mtime, so it is expected first.
            self.assertEqual(found[0].config.config["block_sizes"], [32, 64])
            self.assertEqual(found[1].config.config["block_sizes"], [64, 128])

    def test_find_similar_cached_configs_respects_max_configs(self):
        """No more than max_configs entries are returned."""
        fp_hash = self._hash_fingerprint((("block_sizes", 1, 1),))

        with tempfile.TemporaryDirectory() as cache_dir:
            # Five matching entries, each one second older than the previous.
            for i in range(5):
                block = 32 * (i + 1)
                self._write_best_config(
                    cache_dir,
                    f"match{i}.best_config",
                    hardware="NVIDIA GeForce RTX 4090",
                    spec_key="('tensor_spec',)",
                    source_hash=f"hash{i}",
                    config_dict={"block_sizes": [block], "num_warps": 4},
                    mtime_offset=-i,
                    config_spec_hash=fp_hash,
                    flat_config=[block, 4],
                )

            search = self._new_mock_search(fp_hash)
            search._get_current_hardware_and_specialization = MagicMock(
                return_value=("NVIDIA GeForce RTX 4090", "('tensor_spec',)")
            )

            with self._cache_dir_patch(cache_dir):
                found = PopulationBasedSearch._find_similar_cached_configs(
                    search, max_configs=2
                )

            self.assertEqual(len(found), 2)

    def test_cache_matching_with_code_object_in_spec_key(self):
        """A stored key with a raw code-object repr matches across addresses.

        Simulates the matmul-with-activation-lambda scenario: put() stored
        the raw str() of the specialization key (containing
        ``<code object ... at 0xADDR>``), while the current process sees the
        same function at a different address. _find_similar_cached_configs
        has to normalize both sides for the entry to be found.
        """
        fp_hash = self._hash_fingerprint((("block_sizes", 2, 1, 1),))

        # As written by put(): raw str() carrying one specific address.
        stored_spec_key = (
            "((torch.float16, 'cuda', (2, 2), False, frozenset()), "
            "(torch.float16, 'cuda', (2, 2), False, frozenset()), "
            "(<code object addmm_epilogue at 0x7e56e22f1a70, "
            'file "/home/user/matmul.py", line 244>, '
            "<class 'float'>, <class 'float'>, "
            "(torch.float16, 'cuda', (2, 2), False, frozenset())))"
        )
        # As computed now: identical function, different address.
        current_raw_spec_key = (
            "((torch.float16, 'cuda', (2, 2), False, frozenset()), "
            "(torch.float16, 'cuda', (2, 2), False, frozenset()), "
            "(<code object addmm_epilogue at 0x7fff98761234, "
            'file "/home/user/matmul.py", line 244>, '
            "<class 'float'>, <class 'float'>, "
            "(torch.float16, 'cuda', (2, 2), False, frozenset())))"
        )
        # A key with different closure values: must stay unmatched even once
        # the code object has been stripped.
        stored_different_closure = (
            "((torch.float32, 'cuda', (4, 4), False, frozenset()), "
            "(torch.float32, 'cuda', (4, 4), False, frozenset()), "
            "(<code object addmm_epilogue at 0x7e56e22f1a70, "
            'file "/home/user/matmul.py", line 244>, '
            "<class 'int'>, <class 'int'>, "
            "(torch.float32, 'cuda', (4, 4), False, frozenset())))"
        )
        # The mocked lookup hands back the already-normalized current key.
        current_normalized = _normalize_spec_key_str(current_raw_spec_key)

        with tempfile.TemporaryDirectory() as cache_dir:
            self._write_best_config(
                cache_dir,
                "matmul_activation.best_config",
                hardware="NVIDIA GeForce RTX 5090",
                spec_key=stored_spec_key,
                source_hash="hash1",
                config_dict={"block_sizes": [64, 128], "num_warps": 4},
                config_spec_hash=fp_hash,
                flat_config=[64, 128, 4],
            )
            self._write_best_config(
                cache_dir,
                "matmul_different_closure.best_config",
                hardware="NVIDIA GeForce RTX 5090",
                spec_key=stored_different_closure,
                source_hash="hash2",
                config_dict={"block_sizes": [32, 64], "num_warps": 8},
                config_spec_hash=fp_hash,
                flat_config=[32, 64, 8],
            )

            search = self._new_mock_search(fp_hash)
            search._get_current_hardware_and_specialization = MagicMock(
                return_value=("NVIDIA GeForce RTX 5090", current_normalized)
            )

            with self._cache_dir_patch(cache_dir):
                found = PopulationBasedSearch._find_similar_cached_configs(
                    search, max_configs=10
                )

            # Just the entry whose closure values match may be returned.
            self.assertEqual(len(found), 1)
            self.assertEqual(found[0].config.config["block_sizes"], [64, 128])

    def test_find_similar_matches_with_specialize_extras(self):
        """Entries stored under the base key match when hl.specialize() adds extras.

        The cache stores _base_specialization_key (no extras) whereas the
        kernel's specialization_key() appends hl.specialize() discoveries,
        so the lookup must use the base key to match the stored format.
        """
        fp_hash = self._hash_fingerprint((("block_sizes", 2, 1, 1),))
        base_spec_key = ("tensor_spec",)
        # The full key gains one element, as from hl.specialize(x.size(1)).
        full_spec_key = ("tensor_spec", 256)

        with tempfile.TemporaryDirectory() as cache_dir:
            # Stored under the base key, the way local_cache.py writes it.
            self._write_best_config(
                cache_dir,
                "specialize.best_config",
                hardware="NVIDIA GeForce RTX 4090",
                spec_key=str(base_spec_key),
                source_hash="hash1",
                config_dict={"block_sizes": [64, 128], "num_warps": 4},
                config_spec_hash=fp_hash,
                flat_config=[64, 128, 4],
            )

            search = self._new_mock_search(fp_hash)
            search.args = [torch.tensor([1.0], device=DEVICE)]

            # Kernel whose base key differs from its full key.
            kernel = MagicMock()
            kernel._base_specialization_key = MagicMock(return_value=base_spec_key)
            kernel.specialization_key = MagicMock(return_value=full_spec_key)
            search.kernel.kernel = kernel

            def real_hardware_and_specialization():
                # Route through the REAL implementation, bound to the mock.
                return PopulationBasedSearch._get_current_hardware_and_specialization(
                    search
                )

            search._get_current_hardware_and_specialization = (
                real_hardware_and_specialization
            )

            device_name_patch = patch(
                "helion.autotuner.base_search.get_device_name",
                return_value="NVIDIA GeForce RTX 4090",
            )
            with self._cache_dir_patch(cache_dir), device_name_patch:
                found = PopulationBasedSearch._find_similar_cached_configs(
                    search, max_configs=10
                )

            self.assertEqual(len(found), 1)
            self.assertEqual(found[0].config.config["block_sizes"], [64, 128])
class TestIterCacheEntries(unittest.TestCase):
    """Checks for the module-level ``iter_cache_entries()`` API of local_cache."""

    def _write_cache_file(
        self,
        cache_dir: str,
        filename: str,
        hardware: str,
        spec_key: str,
        config_dict: dict,
        mtime_offset: float = 0,
    ) -> None:
        """Write a minimal cache JSON file into *cache_dir*.

        The config is stored as a JSON string nested in the outer document.
        A non-zero *mtime_offset* shifts the file's access/modification time
        by that many seconds relative to now.
        """
        record = {
            "key": {
                "fields": {
                    "hardware": hardware,
                    "specialization_key": spec_key,
                }
            },
            "config": json.dumps(config_dict),
        }
        destination = os.path.join(cache_dir, filename)
        with open(destination, "w") as out:
            out.write(json.dumps(record))

        if mtime_offset == 0:
            return
        stamp = time.time() + mtime_offset
        os.utime(destination, (stamp, stamp))

    def test_newest_first_ordering(self):
        """Entries come back ordered from newest to oldest mtime."""
        with tempfile.TemporaryDirectory() as cache_dir:
            for name, block, offset in (
                ("old.best_config", 32, -10),
                ("new.best_config", 64, 0),
            ):
                self._write_cache_file(
                    cache_dir,
                    name,
                    "HW",
                    "spec",
                    {"block_sizes": [block], "num_warps": 4},
                    mtime_offset=offset,
                )

            found = list(iter_cache_entries(Path(cache_dir)))

            self.assertEqual(len(found), 2)
            self.assertEqual(found[0].config.config["block_sizes"], [64])
            self.assertEqual(found[1].config.config["block_sizes"], [32])

    def test_corrupt_json_skipped(self):
        """A file that is not valid JSON is skipped rather than raising."""
        with tempfile.TemporaryDirectory() as cache_dir:
            # One well-formed entry ...
            self._write_cache_file(
                cache_dir,
                "valid.best_config",
                "HW",
                "spec",
                {"block_sizes": [64], "num_warps": 4},
            )
            # ... next to one unparsable file.
            broken = Path(os.path.join(cache_dir, "corrupt.best_config"))
            broken.write_text("not valid json {{{")

            found = list(iter_cache_entries(Path(cache_dir)))

            self.assertEqual(len(found), 1)
            self.assertEqual(found[0].hardware, "HW")

    def test_max_scan_limits_results(self):
        """max_scan caps the number of entries produced."""
        with tempfile.TemporaryDirectory() as cache_dir:
            for i in range(5):
                self._write_cache_file(
                    cache_dir,
                    f"entry{i}.best_config",
                    "HW",
                    "spec",
                    {"block_sizes": [32 * (i + 1)], "num_warps": 4},
                    mtime_offset=-i,
                )

            found = list(iter_cache_entries(Path(cache_dir), max_scan=2))

            self.assertEqual(len(found), 2)

    def test_nonexistent_directory(self):
        """A directory that does not exist produces no entries."""
        missing = Path("/nonexistent/path")
        found = list(iter_cache_entries(missing))
        self.assertEqual(len(found), 0)

    def test_fields_parsed_correctly(self):
        """hardware, specialization_key and config survive the parse intact."""
        with tempfile.TemporaryDirectory() as cache_dir:
            self._write_cache_file(
                cache_dir,
                "entry.best_config",
                hardware="NVIDIA RTX 5090",
                spec_key="('my_spec',)",
                config_dict={"block_sizes": [128], "num_warps": 8},
            )

            found = list(iter_cache_entries(Path(cache_dir)))

            self.assertEqual(len(found), 1)
            only = found[0]
            self.assertEqual(only.hardware, "NVIDIA RTX 5090")
            self.assertEqual(only.specialization_key, "('my_spec',)")
            self.assertEqual(only.config.config["block_sizes"], [128])
            self.assertEqual(only.config.config["num_warps"], 8)
class TestSpecKeyNormalization(unittest.TestCase):
    """Checks for ``_normalize_spec_key_str()`` and the raw key stored by put()."""

    def test_code_object_repr_stripped(self):
        """A code-object repr is removed while the remaining fields survive."""
        normalized = _normalize_spec_key_str(
            "(<code object <lambda> at 0x7cdd123, file \"foo.py\", line 322>, (torch.float16, 'cuda'))"
        )
        self.assertNotIn("<code object", normalized)
        for kept in ("torch.float16", "'cuda'"):
            self.assertIn(kept, normalized)

    def test_nested_code_objects_stripped(self):
        """A code-object repr nested inside an inner tuple is removed too."""
        normalized = _normalize_spec_key_str(
            "((<code object helper at 0xabc, file \"x.py\", line 10>, 'inner'), 'outer')"
        )
        self.assertNotIn("<code object", normalized)
        for kept in ("'inner'", "'outer'"):
            self.assertIn(kept, normalized)

    def test_tensor_closure_info_preserved(self):
        """A key without code objects passes through unchanged."""
        untouched = "((torch.float16, 'cuda', (1024,), (1,), frozenset()),)"
        self.assertEqual(_normalize_spec_key_str(untouched), untouched)

    def test_end_to_end_matching(self):
        """Keys differing only in the code-object address normalize identically."""
        # What a cache file would hold: raw str() including an address.
        stored = "(<code object <lambda> at 0x7cdd1234abcd, file \"matmul.py\", line 42>, (torch.float16, 'cuda', (1024,), (1,), frozenset()))"
        # What the running process would compute: another address.
        current = "(<code object <lambda> at 0x7fff9876fedc, file \"matmul.py\", line 42>, (torch.float16, 'cuda', (1024,), (1,), frozenset()))"
        normalized_stored = _normalize_spec_key_str(stored)
        normalized_current = _normalize_spec_key_str(current)
        self.assertEqual(normalized_stored, normalized_current)

    def test_put_stores_raw_spec_key(self):
        """put() writes the un-normalized key, code-object repr included."""

        def dummy_fn():
            pass

        with tempfile.TemporaryDirectory() as tmpdir:
            cache_path = Path(tmpdir) / "test.best_config"

            cache = MagicMock()
            cache.key = LooseAutotuneCacheKey(
                specialization_key=(dummy_fn.__code__, "tensor_spec"),
                extra_results=(),
                kernel_source_hash="abc123",
                hardware="test_hw",
                runtime_name="1.0",
                backend="triton",
            )
            cache._get_local_cache_path.return_value = cache_path
            cache.kernel.backend_cache_key.return_value = None
            # flatten() must hand back something JSON-serializable.
            generation = cache.kernel.config_spec.create_config_generation.return_value
            generation.flatten.return_value = [64, 4]

            LocalAutotuneCache.put(cache, Config(block_sizes=[64], num_warps=4))

            written = json.loads(cache_path.read_text())
            stored_key = written["key"]["fields"]["specialization_key"]
            # The stored value is the raw str(), so the repr is still there.
            self.assertIn("<code object", stored_key)
            self.assertIn("tensor_spec", stored_key)
class TestStructuralFingerprint(unittest.TestCase):
    """Checks for ``ConfigSpec.structural_fingerprint()``."""

    @staticmethod
    def _triton_spec(*size_hints):
        """Return a Triton ConfigSpec with one BlockSizeSpec per hint.

        Block ids are assigned 0, 1, 2, ... in argument order.
        """
        spec = ConfigSpec(backend=TritonBackend())
        for block_id, hint in enumerate(size_hints):
            spec.block_sizes.append(BlockSizeSpec(block_id=block_id, size_hint=hint))
        return spec

    def test_different_block_sizes_count(self):
        """Two block sizes vs. three block sizes fingerprint differently."""
        two_blocks = self._triton_spec(64, 128)
        three_blocks = self._triton_spec(64, 128, 256)
        self.assertNotEqual(
            two_blocks.structural_fingerprint(),
            three_blocks.structural_fingerprint(),
        )

    def test_same_structure_same_fingerprint(self):
        """Size hints and min/max bounds do not affect the fingerprint."""
        plain = self._triton_spec(64, 128)

        bounded = ConfigSpec(backend=TritonBackend())
        for block_id, hint, low, high in ((0, 32, 8, 512), (1, 256, 16, 1024)):
            bounded.block_sizes.append(
                BlockSizeSpec(
                    block_id=block_id, size_hint=hint, min_size=low, max_size=high
                )
            )

        self.assertEqual(
            plain.structural_fingerprint(),
            bounded.structural_fingerprint(),
        )

    def test_different_range_fields_count(self):
        """One vs. two range_unroll_factors entries fingerprint differently."""
        single = self._triton_spec(64)
        single.range_unroll_factors.append(RangeUnrollFactorSpec([0]))

        double = self._triton_spec(64)
        for block_id in (0, 1):
            double.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))

        self.assertNotEqual(
            single.structural_fingerprint(),
            double.structural_fingerprint(),
        )

    def test_loop_orders_block_ids_length(self):
        """Loop orders over 2 vs. 3 block ids fingerprint differently."""
        short_order = self._triton_spec(64)
        short_order.loop_orders.append(LoopOrderSpec([0, 1]))

        long_order = self._triton_spec(64)
        long_order.loop_orders.append(LoopOrderSpec([0, 1, 2]))

        self.assertNotEqual(
            short_order.structural_fingerprint(),
            long_order.structural_fingerprint(),
        )
class TestHardwareDetection(unittest.TestCase):
    """Tests for hardware detection from kernel arguments."""

    @staticmethod
    def _search_with_inner_kernel(args):
        """Return a mock search with ``args`` and a nested ``kernel.kernel``.

        The inner kernel's ``specialization_key`` is a mock returning ``("spec",)``.
        """
        search = MagicMock()
        search.args = args

        inner_kernel = MagicMock()
        inner_kernel.specialization_key = MagicMock(return_value=("spec",))
        outer_kernel = MagicMock()
        outer_kernel.kernel = inner_kernel
        search.kernel = outer_kernel
        return search

    def _assert_hardware_detected(self, hardware):
        """Assert that ``hardware`` is a non-empty string."""
        self.assertIsNotNone(hardware)
        self.assertIsInstance(hardware, str)
        self.assertGreater(len(hardware), 0)

    def test_hardware_detection_direct_tensor(self):
        """Hardware is detected when a tensor is passed directly as an argument."""
        device_tensor = torch.zeros(10, device=DEVICE)
        search = self._search_with_inner_kernel([device_tensor])

        detected = PopulationBasedSearch._get_current_hardware_and_specialization(
            search
        )
        hardware = detected[0]

        self._assert_hardware_detected(hardware)

    def test_hardware_detection_list_of_tensors(self):
        """Hardware is detected from list[0] tensor (matches cache behavior)."""
        device_tensor = torch.zeros(10, device=DEVICE)
        search = self._search_with_inner_kernel(
            [[device_tensor, "other_data"], "scalar_arg"]
        )

        detected = PopulationBasedSearch._get_current_hardware_and_specialization(
            search
        )
        hardware = detected[0]

        self._assert_hardware_detected(hardware)

    def test_hardware_detection_generic_adapter_no_inner_kernel(self):
        """Generic adapters without a .kernel attribute yield a None spec_key."""
        search = MagicMock()
        search.args = [1, 2, "string"]
        # spec=[] gives the mock no attributes, so accessing .kernel on it fails.
        search.kernel = MagicMock(spec=[])

        _, spec_key = PopulationBasedSearch._get_current_hardware_and_specialization(
            search
        )

        self.assertIsNone(spec_key)
class TestGenerateBestAvailablePopulation(unittest.TestCase):
"""Tests for _generate_best_available_population_flat orchestration."""
def _make_config_gen(self):
    """Return a ConfigGeneration over a Triton-backed spec with two block sizes.

    Blocks 0 and 1 use size hints 64 and 128; both are created with
    ``min_size=16`` and ``max_size=256``.
    """
    spec = ConfigSpec(backend=TritonBackend())
    for block_id, hint in enumerate((64, 128)):
        spec.block_sizes.append(
            BlockSizeSpec(block_id=block_id, size_hint=hint, min_size=16, max_size=256)
        )
    return ConfigGeneration(spec)
def _make_mock_search(self, config_gen, cached_configs):
    """Build a mock PopulationBasedSearch whose cache lookup returns fixed entries.

    Each item of ``cached_configs`` may be a SavedBestConfig, which is used
    as-is, or a Config, which is wrapped in a SavedBestConfig with placeholder
    metadata and a flat config computed by ``config_gen.flatten``.
    The mock's ``_find_similar_cached_configs`` returns the resulting list.
    """

    def as_saved_entry(item):
        # Already-wrapped entries pass through untouched.
        if isinstance(item, SavedBestConfig):
            return item
        return SavedBestConfig(
            hardware="test",
            specialization_key="test",
            config=item,
            config_spec_hash="",
            flat_config=tuple(config_gen.flatten(item)),
        )

    saved_entries = [as_saved_entry(item) for item in cached_configs]

    search = MagicMock()
    search.config_gen = config_gen
    search.settings = Settings()

    logger = MagicMock()
    logger.debug = MagicMock()
    search.log = logger

    search._find_similar_cached_configs = MagicMock(return_value=saved_entries)
    return search
def test_default_only_when_no_cached(self):
    """With no cached configs, the population is just the default flat config."""
    gen = self._make_config_gen()
    search = self._make_mock_search(gen, cached_configs=[])

    population = PopulationBasedSearch._generate_best_available_population_flat(
        search
    )

    # Exactly one member, and it is the generator's default flat config.
    self.assertEqual(len(population), 1)
    self.assertEqual(population[0], gen.default_flat())
def test_cached_configs_added(self):
"""Cached configs are added after default."""
config_gen = self._make_config_gen()
cached = [
Config(block_sizes=[32, 64], num_warps=8, num_stages=2),
Config(block_sizes=[128, 256], num_warps=2, num_stages=4),
]
mock_search = self._make_mock_search(config_gen, cached)
result = PopulationBasedSearch._generate_best_available_population_flat(
mock_search
)
self.assertEqual(len(result), 3) # 1 default + 2 cached