99import sysconfig
1010from pathlib import Path
1111from typing import List
12- from jinja2 import Template
1312
13+ from jinja2 import Template
14+ from packaging .version import Version
1415from setuptools import Extension , setup
1516from setuptools .command .build_ext import build_ext
1617
@@ -33,6 +34,17 @@ def get_nccl_root():
3334 return None
3435
3536
37+ def get_cuda_version ():
38+ import torch .utils .cpp_extension as torch_cpp_ext
39+
40+ if torch_cpp_ext .CUDA_HOME is None :
41+ nvcc = "nvcc"
42+ else :
43+ nvcc = os .path .join (torch_cpp_ext .CUDA_HOME , "bin/nvcc" )
44+ txt = subprocess .check_output ([nvcc , "--version" ], text = True )
45+ return Version (re .findall (r"release (\d+\.\d+)," , txt )[0 ])
46+
47+
3648def get_base_dir ():
3749 return os .path .abspath (os .path .dirname (__file__ ))
3850
@@ -55,7 +67,7 @@ def get_scalellm_version():
5567
5668 if not version :
5769 raise RuntimeError ("Unable to find version string." )
58-
70+
5971 version_suffix = os .getenv ("SCALELLM_VERSION_SUFFIX" )
6072 if version_suffix :
6173 version += version_suffix
@@ -158,8 +170,11 @@ def build_extension(self, ext: CMakeExtension):
158170 debug = int (os .environ .get ("DEBUG" , 0 )) if self .debug is None else self .debug
159171 build_type = "Debug" if debug else "Release"
160172
161- # python directories
162- cuda_architectures = "80;89;90"
173+ cuda_version = get_cuda_version ()
174+ cuda_architectures = "80;89;90a"
175+ if cuda_version >= Version ("12.8" ):
176+ # blackwell needs cuda 12.8
177+ cuda_architectures += ";100a;120a"
163178 cmake_args = [
164179 "-G" ,
165180 "Ninja" , # Ninja is much faster than make
@@ -247,7 +262,6 @@ def build_extension(self, ext: CMakeExtension):
247262 "License :: OSI Approved :: Apache Software License" ,
248263 "Topic :: Scientific/Engineering" ,
249264 "Topic :: Scientific/Engineering :: Artificial Intelligence" ,
250-
251265 ],
252266 packages = ["scalellm" , "scalellm/serve" , "scalellm/_C" , "examples" ],
253267 ext_modules = [CMakeExtension ("_C" , "scalellm/" )],
0 commit comments