forked from thu-pacman/chitu
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
118 lines (96 loc) · 3.72 KB
/
setup.py
File metadata and controls
118 lines (96 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# SPDX-FileCopyrightText: 2025 Qingcheng.AI
#
# SPDX-License-Identifier: Apache-2.0
import os
import glob
import setuptools
from setuptools import Extension, setup, find_packages
from setuptools.command.build_py import build_py
import packaging.version
from Cython.Build import cythonize
try:
import torch
except ImportError:
raise RuntimeError(
"torch is required to build chitu. Please install torch (with the correct CUDA version) before installing chitu.\n"
"For example: pip install torch --index-url https://download.pytorch.org/whl/cu124"
)
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Validate the build environment with an explicit exception instead of `assert`:
# assertions are removed when Python runs with -O, which would silently skip
# this check. RuntimeError matches the torch availability check in this file.
if packaging.version.parse(setuptools.__version__) < packaging.version.parse(
    "62.3.0"
):
    raise RuntimeError(
        "setuptools>=62.3.0 is required for `**` wildcard in package_data."
    )
import csrc.setup_build as operators
from get_requires import install_requires, extras_require
# We use CUDAExtension instead of CMake for native sources, because many of the non-NVIDIA GPUs have
# their custom CUDAExtension, but not their custom CMake support.
# Native operator extensions are skipped (empty list) when any of these
# build flags is set to "1"; otherwise they come from csrc.setup_build.
ext_modules = (
    []
    if any(
        os.environ.get(flag, "0") == "1"
        for flag in ("CHITU_ASCEND_BUILD", "CHITU_HYGON_BUILD", "CHITU_MOORE_BUILD")
    )
    else operators.get_extensions()
)
# Path suffixes of Python sources that are excluded from Cython compilation.
cython_unsafe_files = [
    # Triton kernels inside
    *glob.glob("chitu/ops/triton_ops/**/*.py", recursive=True),
    *glob.glob("chitu/moe/experts/*.py"),
    # plum inside
    "chitu/moe/batched_routed_activation.py",
    "chitu/muxi_utils.py",
    "chitu/native_layout.py",
    # Any file whose path ends with "__main__.py"
    "__main__.py",
]
def is_cython_unsafe(path):
    """Return True if ``path`` ends with any entry of ``cython_unsafe_files``.

    ``path`` may be any object; it is converted with ``str()`` before the
    suffix comparison.
    """
    # str.endswith accepts a tuple of candidate suffixes and checks them all.
    return str(path).endswith(tuple(cython_unsafe_files))
def find_py_modules(directory):
    """Collect dotted module names for the ``.py`` files under ``directory``.

    Files for which ``is_cython_unsafe`` returns True are left out. Each name
    is the file path (as produced by ``os.walk``) without its extension and
    with ``os.sep`` replaced by ``.``.
    """
    discovered = []
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            if not filename.endswith(".py"):
                continue
            source_path = os.path.join(root, filename)
            if is_cython_unsafe(source_path):
                continue
            stem = os.path.splitext(source_path)[0]
            discovered.append(stem.replace(os.sep, "."))
    return discovered
def create_cython_extensions(directory):
    """Build one ``Extension`` per module returned by ``find_py_modules``.

    Each extension is named after the dotted module and has a single source:
    the module name mapped back to a ``.py`` file path.
    """
    return [
        Extension(dotted_name, [dotted_name.replace(".", os.sep) + ".py"])
        for dotted_name in find_py_modules(directory)
    ]
class SkipBuildPy(build_py):
    """``build_py`` variant that only keeps the Cython-unsafe modules."""

    def find_package_modules(self, package, package_dir):
        """Return the parent's (package, module, file) triples, filtered.

        Only triples whose file satisfies ``is_cython_unsafe`` are kept; all
        other modules are dropped from the ``build_py`` step.
        """
        candidates = super().find_package_modules(package, package_dir)
        return [
            (pkg_name, mod_name, src_path)
            for (pkg_name, mod_name, src_path) in candidates
            if is_cython_unsafe(src_path)
        ]
# Stock build_py by default; replaced by SkipBuildPy when Cython is enabled.
my_build_py = build_py

if os.environ.get("CHITU_WITH_CYTHON", "0") != "0":
    # An unset or empty CHITU_SETUP_JOBS yields nthreads=0.
    nthreads = int(os.environ.get("CHITU_SETUP_JOBS") or 0)
    ext_modules += cythonize(create_cython_extensions("chitu"), nthreads=nthreads)
    my_build_py = SkipBuildPy

# Mirror CHITU_SETUP_JOBS into MAX_JOBS whenever it is set.
# NOTE(review): presumably read by the torch extension build -- confirm.
if "CHITU_SETUP_JOBS" in os.environ:
    os.environ["MAX_JOBS"] = os.environ["CHITU_SETUP_JOBS"]  # type: ignore
# The information here can also be placed in setup.cfg - better separation of
# logic and declaration, and simpler if you include description/version in a file.
setup(
    # Distribution identity.
    name="chitu",
    version="0.5.1",
    python_requires=">=3.10",
    # Pure-Python packages and the YAML config files shipped with them.
    packages=find_packages(),
    package_data={"chitu": ["config/**/*.yaml"]},
    # Dependencies loaded from get_requires.
    install_requires=install_requires,
    extras_require=extras_require,
    # Extension modules and the build commands selected earlier in this file.
    ext_modules=ext_modules,
    cmdclass={"build_py": my_build_py, "build_ext": BuildExtension},
)