Skip to content

Commit 7f38361

Browse files
python3kgaeXiang Li
andauthored
Support Windows in driver.compile_module (#242)
This change add support for windows using msvc for dirver.compile_module. Switched to sys.base_prefix and sys.version_info to build py_include_dir py_lib because sysconfig.get_config_var("LDVERSION") returns None for Windows. Add code to compile the launcher for Windows with msvc. Co-authored-by: Xiang Li <xiagli@microsoft.com>
1 parent 0286bbc commit 7f38361

1 file changed

Lines changed: 41 additions & 26 deletions

File tree

backend/driver.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import tempfile
33
import sysconfig
44

5-
import os, subprocess, tempfile
5+
import os, subprocess, tempfile, platform
66
import importlib.util
7-
import sysconfig
7+
import sys
88

99
from pathlib import Path
1010

@@ -217,19 +217,15 @@ def _generate_launcher(constants, signature, kernel_name):
217217

218218

219219
def compile_module(launcher_src, kernel_placeholder_name):
220-
# This function was renamed and made public in Python 3.10
221-
if hasattr(sysconfig, 'get_default_scheme'):
222-
scheme = sysconfig.get_default_scheme()
220+
py_version = sys.version_info
221+
if platform.system() == "Windows":
222+
py_include_dir = os.path.join(sys.base_prefix, 'include')
223+
py_lib_dir = os.path.join(sys.base_prefix, 'libs')
224+
py_lib = '{name}{major}{minor}.lib'.format(name="python", major=py_version.major, minor=py_version.minor)
223225
else:
224-
scheme = sysconfig._get_default_scheme()
225-
# 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
226-
# path changes to include 'local'. This change is required to use triton with system-wide python.
227-
if scheme == 'posix_local':
228-
scheme = 'posix_prefix'
229-
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
230-
py_lib_dir = sysconfig.get_config_var("LIBDIR")
231-
py_version = sysconfig.get_config_var("LDVERSION")
232-
py_lib = '{name}{py_version}'.format(name="python", py_version=py_version)
226+
py_include_dir = os.path.join(sys.base_prefix, 'include', f'python{sys.version_info.major}.{sys.version_info.minor}')
227+
py_lib_dir = os.path.join(sys.base_prefix, 'lib')
228+
py_lib = '{name}{major}.{minor}'.format(name="python", major=py_version.major, minor=py_version.minor)
233229
cpu_backend_path = Path(__file__).resolve().parent
234230
include_dir = os.path.join(cpu_backend_path, "include")
235231

@@ -248,28 +244,47 @@ def launch(
248244
key = hashlib.md5(src.encode("utf-8") + kernel_obj).hexdigest()
249245
cache = get_cache_manager(key)
250246
name = "__triton_shared_ref_cpu_kernel_launcher"
251-
filename = f"{name}.so"
247+
248+
if platform.system() == "Windows":
249+
filename = f"{name}.pyd"
250+
else:
251+
filename = f"{name}.so"
252252
cache_path = cache.get_file(filename)
253253

254254
if cache_path is None:
255255
with tempfile.TemporaryDirectory() as tmpdir:
256-
obj_path = os.path.join(tmpdir, "kernel.o")
257-
launcher_src_path = os.path.join(tmpdir, "main.cxx")
258-
so_path = os.path.join(tmpdir, "kernel.so")
259-
Path(obj_path).write_bytes(kernel_obj)
260-
Path(launcher_src_path).write_text(src)
261-
# Compile it together.
262-
subprocess.check_call([
263-
"g++", "-std=c++17", launcher_src_path, obj_path,
264-
f"-I{py_include_dir}", f"-I{include_dir}", f"-L{py_lib_dir}",
265-
"-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
266-
])
256+
if platform.system() == "Windows":
257+
obj_path = os.path.join(tmpdir, "kernel.obj")
258+
launcher_src_path = os.path.join(tmpdir, "main.cxx")
259+
so_path = os.path.join(tmpdir, "kernel.pyd")
260+
Path(obj_path).write_bytes(kernel_obj)
261+
Path(launcher_src_path).write_text(src)
262+
# Compile it together.
263+
subprocess.check_call([
264+
"cl", "/LD", "/std:c++17", launcher_src_path, obj_path,
265+
f"-I{py_include_dir}", f"-I{include_dir}", "/link", f"/LIBPATH:{py_lib_dir}",
266+
"/link", f"{py_lib}", f"/OUT:{so_path}"
267+
])
268+
else:
269+
obj_path = os.path.join(tmpdir, "kernel.o")
270+
launcher_src_path = os.path.join(tmpdir, "main.cxx")
271+
so_path = os.path.join(tmpdir, "kernel.so")
272+
Path(obj_path).write_bytes(kernel_obj)
273+
Path(launcher_src_path).write_text(src)
274+
# Compile it together.
275+
subprocess.check_call([
276+
"g++", "-std=c++17", launcher_src_path, obj_path,
277+
f"-I{py_include_dir}", f"-I{include_dir}", f"-L{py_lib_dir}",
278+
"-shared", f"-l{py_lib}", "-fPIC", "-o", so_path
279+
])
267280

268281
with open(so_path, "rb") as f:
269282
cache_path = cache.put(f.read(), filename, binary=True)
270283

271284
# Load and launch the compiled kernel.
272285
spec = importlib.util.spec_from_file_location(name, cache_path)
286+
if spec is None:
287+
raise RuntimeError(f"Cannot find {name} module in {cache_path}")
273288
mod = importlib.util.module_from_spec(spec)
274289
spec.loader.exec_module(mod)
275290
return mod.launch(gridX, gridY, gridZ,

0 commit comments

Comments
 (0)