22import tempfile
33import sysconfig
44
5- import os , subprocess , tempfile
5+ import os , subprocess , tempfile , platform
66import importlib .util
7- import sysconfig
7+ import sys
88
99from pathlib import Path
1010
@@ -217,19 +217,15 @@ def _generate_launcher(constants, signature, kernel_name):
217217
218218
219219def compile_module (launcher_src , kernel_placeholder_name ):
220- # This function was renamed and made public in Python 3.10
221- if hasattr (sysconfig , 'get_default_scheme' ):
222- scheme = sysconfig .get_default_scheme ()
220+ py_version = sys .version_info
221+ if platform .system () == "Windows" :
222+ py_include_dir = os .path .join (sys .base_prefix , 'include' )
223+ py_lib_dir = os .path .join (sys .base_prefix , 'libs' )
224+ py_lib = '{name}{major}{minor}.lib' .format (name = "python" , major = py_version .major , minor = py_version .minor )
223225 else :
224- scheme = sysconfig ._get_default_scheme ()
225- # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
226- # path changes to include 'local'. This change is required to use triton with system-wide python.
227- if scheme == 'posix_local' :
228- scheme = 'posix_prefix'
229- py_include_dir = sysconfig .get_paths (scheme = scheme )["include" ]
230- py_lib_dir = sysconfig .get_config_var ("LIBDIR" )
231- py_version = sysconfig .get_config_var ("LDVERSION" )
232- py_lib = '{name}{py_version}' .format (name = "python" , py_version = py_version )
226+ py_include_dir = os .path .join (sys .base_prefix , 'include' , f'python{ sys .version_info .major } .{ sys .version_info .minor } ' )
227+ py_lib_dir = os .path .join (sys .base_prefix , 'lib' )
228+ py_lib = '{name}{major}.{minor}' .format (name = "python" , major = py_version .major , minor = py_version .minor )
233229 cpu_backend_path = Path (__file__ ).resolve ().parent
234230 include_dir = os .path .join (cpu_backend_path , "include" )
235231
@@ -248,28 +244,47 @@ def launch(
248244 key = hashlib .md5 (src .encode ("utf-8" ) + kernel_obj ).hexdigest ()
249245 cache = get_cache_manager (key )
250246 name = "__triton_shared_ref_cpu_kernel_launcher"
251- filename = f"{ name } .so"
247+
248+ if platform .system () == "Windows" :
249+ filename = f"{ name } .pyd"
250+ else :
251+ filename = f"{ name } .so"
252252 cache_path = cache .get_file (filename )
253253
254254 if cache_path is None :
255255 with tempfile .TemporaryDirectory () as tmpdir :
256- obj_path = os .path .join (tmpdir , "kernel.o" )
257- launcher_src_path = os .path .join (tmpdir , "main.cxx" )
258- so_path = os .path .join (tmpdir , "kernel.so" )
259- Path (obj_path ).write_bytes (kernel_obj )
260- Path (launcher_src_path ).write_text (src )
261- # Compile it together.
262- subprocess .check_call ([
263- "g++" , "-std=c++17" , launcher_src_path , obj_path ,
264- f"-I{ py_include_dir } " , f"-I{ include_dir } " , f"-L{ py_lib_dir } " ,
265- "-shared" , f"-l{ py_lib } " , "-fPIC" , "-o" , so_path
266- ])
256+ if platform .system () == "Windows" :
257+ obj_path = os .path .join (tmpdir , "kernel.obj" )
258+ launcher_src_path = os .path .join (tmpdir , "main.cxx" )
259+ so_path = os .path .join (tmpdir , "kernel.pyd" )
260+ Path (obj_path ).write_bytes (kernel_obj )
261+ Path (launcher_src_path ).write_text (src )
262+ # Compile it together.
263+ subprocess .check_call ([
264+ "cl" , "/LD" , "/std:c++17" , launcher_src_path , obj_path ,
265+ f"-I{ py_include_dir } " , f"-I{ include_dir } " , "/link" , f"/LIBPATH:{ py_lib_dir } " ,
266+ "/link" , f"{ py_lib } " , f"/OUT:{ so_path } "
267+ ])
268+ else :
269+ obj_path = os .path .join (tmpdir , "kernel.o" )
270+ launcher_src_path = os .path .join (tmpdir , "main.cxx" )
271+ so_path = os .path .join (tmpdir , "kernel.so" )
272+ Path (obj_path ).write_bytes (kernel_obj )
273+ Path (launcher_src_path ).write_text (src )
274+ # Compile it together.
275+ subprocess .check_call ([
276+ "g++" , "-std=c++17" , launcher_src_path , obj_path ,
277+ f"-I{ py_include_dir } " , f"-I{ include_dir } " , f"-L{ py_lib_dir } " ,
278+ "-shared" , f"-l{ py_lib } " , "-fPIC" , "-o" , so_path
279+ ])
267280
268281 with open (so_path , "rb" ) as f :
269282 cache_path = cache .put (f .read (), filename , binary = True )
270283
271284 # Load and launch the compiled kernel.
272285 spec = importlib .util .spec_from_file_location (name , cache_path )
286+ if spec is None :
287+ raise RuntimeError (f"Cannot find { name } module in { cache_path } " )
273288 mod = importlib .util .module_from_spec (spec )
274289 spec .loader .exec_module (mod )
275290 return mod .launch (gridX , gridY , gridZ ,
0 commit comments