@@ -11,6 +11,7 @@
 import os
 import re
 import tempfile
+from dataclasses import replace
 from pathlib import Path
 import subprocess
 from typing import Any, Dict, no_type_check
@@ -30,21 +31,36 @@
 # mypy compiler.py hexagon_executor.py hexagon_launcher_base.py torch_mlir_hexagon_launcher.py triton_hexagon_launcher.py --follow-untyped-imports --check-untyped-defs
 
 
-def _get_triton_shared_opt_path() -> str:
-    path = os.getenv("TRITON_SHARED_OPT_PATH", "")
+def _get_triton_shared_opt_path(device_type: str) -> str:
+    path = os.getenv(
+        "TRITON_SHARED_OPT_PATH",
+        "",
+    )
     if path == "":
         raise Exception("TRITON_SHARED_OPT_PATH is not set.")
-    return path
 
+    bin_path = Path(path).resolve()
+
+    if not bin_path.exists() or not bin_path.is_file():
+        raise FileNotFoundError(
+            f"Could not find 'triton-shared-opt' at expected location: {bin_path}"
+        )
+    if not os.access(bin_path, os.X_OK):
+        raise PermissionError(
+            f"'triton-shared-opt' exists but is not executable: {bin_path}"
+        )
+
+    return str(bin_path)
 
-def ttir_to_ttsharedir(mod):
+
+def ttir_to_ttsharedir(mod: str, options):
     # Get Triton-MLIR as string
     ttir_code = str(mod)
     with tempfile.TemporaryDirectory() as tmpdir:
         src_path = os.path.join(tmpdir, "tt.mlir")
         dst_path = os.path.join(tmpdir, "ttshared.mlir")
         Path(src_path).write_text(ttir_code)
-        triton_shared_opt_path = _get_triton_shared_opt_path()
+        triton_shared_opt_path = _get_triton_shared_opt_path(options.device_type)
         subprocess.check_call(
             [
                 triton_shared_opt_path,
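
Note on the hunk above: the stricter lookup now fails fast with a specific exception instead of surfacing a confusing error later inside subprocess.check_call. A minimal sketch of the behavior, assuming an illustrative install path:

    import os
    # Hypothetical location of a local triton-shared build.
    os.environ["TRITON_SHARED_OPT_PATH"] = "/opt/triton-shared/bin/triton-shared-opt"
    opt = _get_triton_shared_opt_path("hexagon")
    # Raises FileNotFoundError if the binary is absent, PermissionError if it
    # exists but lacks the execute bit (checked via os.access(..., os.X_OK)).
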
@@ -90,6 +106,10 @@ def ttsharedir_to_obj(mod: str, options, metadata={}) -> bytes:
     options_map = {k: str(v) for k, v in (options.__dict__).items()}
     # TODO: Move setting benchmarking iterations when an additional stage for shared object creation is part of the compilation pipeline.
     metadata["iterations"] = options_map["iterations"]
+    metadata["scratch"] = options_map["scratch"]
+    metadata["enableMultiThreading"] = options_map["enableMultiThreading"]
+    metadata["enableThreadedDispatch"] = options_map["enableThreadedDispatch"]
+    metadata["enableLWP"] = options_map["enableLWP"]
 
     # TODO: The lowering pipeline needs to be refactored, similar to other Triton backends, to
     # have a dynamic pipeline filtered by options, with each pass represented by a pybind function.
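
A caveat on the hunk above: options_map stringifies every option value, so the four new metadata entries arrive as strings, not as ints/bools. A hypothetical downstream reader would have to convert them back, e.g.:

    # Sketch of a consumer of this metadata dict (names assumed).
    iterations = int(metadata["iterations"])
    scratch = int(metadata["scratch"])
    use_threads = metadata["enableMultiThreading"] == "True"  # str(True) == "True"
    use_lwp = metadata["enableLWP"] == "True"
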
@@ -138,13 +158,31 @@ def hash(self):
         return f"{version}-{self.target}"
 
     def parse_options(self, opts) -> Any:
-        assert self.target.backend == "hexagon"
-        args = {
-            k: opts[k]
-            for k in HexagonOptions.__dataclass_fields__.keys()
-            if k in opts
-        }
-        return HexagonOptions(**args)
+        assert self.target.backend == "hexagon"
+        args = {
+            k: opts[k] for k in HexagonOptions.__dataclass_fields__.keys() if k in opts
+        }
+        hexagon_opts = HexagonOptions(**args)
+
+        # When external VTCM scratch is enabled (scratch > 0), automatically
+        # configure flags for correct SPMD behavior at compile time so that
+        # both the compiled IR and the generated wrapper are consistent:
+        # - Disable enableMultiThreading and enableConvertToHexagonmem so that
+        #   VTCMPool does not concurrently allocate VTCM alongside the external pool.
+        # - Enable enableVTCMTiling so tiling targets the external scratch.
+        # - Enable enableThreadedDispatch so instances run in parallel on
+        #   real qurt hardware threads (handled at wrapper generation time).
+        # Users do not need to set these flags manually when scratch > 0.
+        if hexagon_opts.scratch > 0:
+            hexagon_opts = replace(
+                hexagon_opts,
+                enableMultiThreading=False,
+                enableConvertToHexagonmem=False,
+                enableVTCMTiling=True,
+                enableThreadedDispatch=True,
+            )
+
+        return hexagon_opts
 
     @staticmethod
     def make_ttir(mod, metadata, opt):
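
The parse_options change above silently overrides user-supplied flags whenever scratch > 0. dataclasses.replace returns a new options instance rather than mutating in place, which suggests HexagonOptions is a frozen dataclass. A sketch of the resulting state, assuming a backend instance and that HexagonOptions provides defaults for the remaining fields:

    opts = backend.parse_options({"scratch": 65536, "enableMultiThreading": True})
    assert opts.enableMultiThreading is False       # overridden despite the input
    assert opts.enableConvertToHexagonmem is False
    assert opts.enableVTCMTiling is True
    assert opts.enableThreadedDispatch is True
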
@@ -161,13 +199,13 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         passes.common.add_cse(pm)
-        pm.run(mod)
+        pm.run(mod, "make_ttir")
         return mod
 
     # May need to add num_warps
     def add_stages(self, stages, options, language):
         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttsharedir"] = lambda src, metadata: ttir_to_ttsharedir(src)
+        stages["ttsharedir"] = lambda src, metadata: ttir_to_ttsharedir(src, options)
         if options.htp_kernel_gen:
             if options.target_artifact == "llir":
                 stages["llir"] = lambda src, metadata: ttsharedir_to_llir(
@@ -190,6 +228,8 @@ def add_stages(self, stages, options, language):
                     src, options, metadata
                 )
         else:  # Default compilation pipeline
+            assert options.device_type == "hexagon"
+
             stages["o"] = lambda src, metadata: ttsharedir_to_obj(
                 src, options, metadata
             )
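
In the two hunks above, ttir_to_ttsharedir now receives options through the stage closure, so the (src, metadata) signature seen by the compilation driver is unchanged. Roughly how the generic Triton driver (assumed here, not part of this diff) walks the stages dict:

    metadata = {}
    module = src
    for name, stage in stages.items():  # e.g. "ttir", "ttsharedir", "o"
        module = stage(module, metadata)
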
@@ -240,6 +280,10 @@ def pack_metadata(self, metadata):
             metadata.name,
             metadata.return_types,
             metadata.iterations,
+            metadata.scratch,
+            metadata.enableMultiThreading,
+            metadata.enableThreadedDispatch,
+            metadata.enableLWP,
         )
 
     def get_module_map(self) -> Dict[str, ModuleType]:
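
Since pack_metadata emits a plain tuple, the four new trailing fields change its arity; any consumer that unpacks it positionally (the launcher wrapper, presumably) must be updated in lockstep, e.g.:

    # "packed" stands for the tuple returned by pack_metadata (hypothetical
    # consumer side); field order must mirror the return statement above.
    (*head, iterations, scratch,
     enable_multi_threading, enable_threaded_dispatch, enable_lwp) = packed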