Skip to content

Commit 6abf884

Browse files
authored
build: added build for blackwell (#459)
1 parent 1294c6c commit 6abf884

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ endif()
7575
# Build TORCH_CUDA_ARCH_LIST
7676
set(TORCH_CUDA_ARCH_LIST "")
7777
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
78-
if(CUDA_ARCH MATCHES "^([0-9])([0-9])*$")
78+
if(CUDA_ARCH MATCHES "^([1-9][0-9]*)([0-9]a?)$")
7979
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
8080
elseif(CUDA_ARCH STREQUAL "native")
8181
set(TORCH_ARCH "Auto")

scripts/build_scalellm.sh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,11 @@ command -v ccache >/dev/null && ccache -M 25Gi
1111
command -v ccache >/dev/null && ccache -z
1212

1313
# build
14-
cmake -G Ninja -S . -B build -DCMAKE_CUDA_ARCHITECTURES="80;89;90"
14+
cmake -G Ninja -S . -B build -DCMAKE_CUDA_ARCHITECTURES="80;89;90a;100a;120a"
1515
cmake --build build --target scalellm --config Release -j$(nproc)
1616

1717
# show ccache statistics if ccache is installed
1818
command -v ccache >/dev/null && ccache -vs
1919

2020
# install
21-
cmake --install build --prefix ./app
21+
cmake --install build --prefix ./app

setup.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,9 @@
99
import sysconfig
1010
from pathlib import Path
1111
from typing import List
12-
from jinja2 import Template
1312

13+
from jinja2 import Template
14+
from packaging.version import Version
1415
from setuptools import Extension, setup
1516
from setuptools.command.build_ext import build_ext
1617

@@ -33,6 +34,17 @@ def get_nccl_root():
3334
return None
3435

3536

37+
def get_cuda_version():
38+
import torch.utils.cpp_extension as torch_cpp_ext
39+
40+
if torch_cpp_ext.CUDA_HOME is None:
41+
nvcc = "nvcc"
42+
else:
43+
nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc")
44+
txt = subprocess.check_output([nvcc, "--version"], text=True)
45+
return Version(re.findall(r"release (\d+\.\d+),", txt)[0])
46+
47+
3648
def get_base_dir():
3749
return os.path.abspath(os.path.dirname(__file__))
3850

@@ -55,7 +67,7 @@ def get_scalellm_version():
5567

5668
if not version:
5769
raise RuntimeError("Unable to find version string.")
58-
70+
5971
version_suffix = os.getenv("SCALELLM_VERSION_SUFFIX")
6072
if version_suffix:
6173
version += version_suffix
@@ -158,8 +170,11 @@ def build_extension(self, ext: CMakeExtension):
158170
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
159171
build_type = "Debug" if debug else "Release"
160172

161-
# python directories
162-
cuda_architectures = "80;89;90"
173+
cuda_version = get_cuda_version()
174+
cuda_architectures = "80;89;90a"
175+
if cuda_version >= Version("12.8"):
176+
# blackwell needs cuda 12.8
177+
cuda_architectures += ";100a;120a"
163178
cmake_args = [
164179
"-G",
165180
"Ninja", # Ninja is much faster than make
@@ -247,7 +262,6 @@ def build_extension(self, ext: CMakeExtension):
247262
"License :: OSI Approved :: Apache Software License",
248263
"Topic :: Scientific/Engineering",
249264
"Topic :: Scientific/Engineering :: Artificial Intelligence",
250-
251265
],
252266
packages=["scalellm", "scalellm/serve", "scalellm/_C", "examples"],
253267
ext_modules=[CMakeExtension("_C", "scalellm/")],

0 commit comments

Comments (0)