Skip to content

Commit 6c6fcc9

Browse files
authored
[deps] Make CUDA build more robust (#1782)
* actually run code for symlinking ptxas, fatbinary, libdevice (the condition `build_backend == "cuda"` was now always false) * check the source of the symlinks exists, instead of creating broken symlinks, if that part of the code ever runs (which at the moment it doesn't because of point above) * rework backend handling, to use more structure instead of string comparisons (which lead to issues like the first point)
1 parent 0fd57e1 commit 6c6fcc9

File tree

1 file changed

+57
-32
lines changed

1 file changed

+57
-32
lines changed

deps/build_local.jl

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -115,36 +115,50 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
115115
# --@local_config_cuda//:cuda_compiler=nvcc
116116
# --crosstool_top="@local_config_cuda//crosstool:toolchain"
117117

118-
build_kind = parsed_args["debug"] ? "dbg" : "opt"
118+
abstract type AbstractBackend end
119+
struct CPUBackend <: AbstractBackend end
120+
struct CUDABackend <: AbstractBackend
121+
version::VersionNumber
122+
CUDABackend(ver::VersionNumber) = new(VersionNumber(ver.major))
123+
end
119124

120-
build_backend = parsed_args["backend"]
125+
function parse_build_backend(str::String)::AbstractBackend
126+
if str == "cpu"
127+
return CPUBackend()
128+
elseif str == "cuda12"
129+
return CUDABackend(v"12")
130+
elseif str == "cuda13"
131+
return CUDABackend(v"13")
132+
end
121133

122-
if build_backend == "auto" || build_backend == "cuda"
123-
cuda_ver = get_cuda_version()
124-
@show cuda_ver
125-
if cuda_ver === nothing
126-
if build_backend == "cuda"
127-
throw(
128-
AssertionError(
129-
"Could not detect cuda version, but requested cuda with auto version build",
130-
),
131-
)
132-
end
133-
build_backend = "cpu"
134-
else
135-
if Int(get_cuda_version().major) == 13
136-
build_backend = "cuda13"
134+
if str in ("auto", "cuda")
135+
cuda_ver = get_cuda_version()
136+
if isnothing(cuda_ver)
137+
if str == "cuda"
138+
throw(
139+
AssertionError(
140+
"Could not detect cuda version, but requested cuda with auto version build",
141+
),
142+
)
143+
end
144+
return CPUBackend()
137145
else
138-
build_backend = "cuda12"
146+
return CUDABackend(get_cuda_version())
139147
end
148+
else
149+
error("Unknown backend '$(str)'")
140150
end
141151
end
142152

143-
arg = if build_backend == "cuda12"
153+
build_kind = parsed_args["debug"] ? "dbg" : "opt"
154+
155+
build_backend = parse_build_backend(parsed_args["backend"])
156+
157+
arg = if build_backend == CUDABackend(v"12")
144158
"--config=cuda12"
145-
elseif build_backend == "cuda13"
159+
elseif build_backend == CUDABackend(v"13")
146160
"--config=cuda13"
147-
elseif build_backend == "cpu"
161+
elseif build_backend == CPUBackend()
148162
""
149163
else
150164
throw(AssertionError("Unknown backend `$build_backend`"))
@@ -231,6 +245,7 @@ push!(build_cmd_list, "--copt=-Wno-private-header")
231245
push!(build_cmd_list, "--color=$(parsed_args["color"])")
232246
push!(build_cmd_list, ":libReactantExtra.so")
233247

248+
@info "About to run Bazel" build_cmd_list
234249
run(Cmd(Cmd(build_cmd_list); dir=source_dir))
235250

236251
# Discover built libraries
@@ -241,25 +256,35 @@ end
241256
lib_path = joinpath(source_dir, "bazel-bin", only(built_libs))
242257
isfile(lib_path) || error("Could not find library $lib_path in build directory")
243258

244-
if build_backend == "cuda"
259+
if build_backend isa CUDABackend
245260
for path in (
246261
joinpath("bin", "ptxas"),
247262
joinpath("bin", "fatbinary"),
248263
joinpath("nvvm", "libdevice", "libdevice.10.bc"),
249264
)
250265
full_path = joinpath(source_dir, "bazel-bin", "cuda", path)
251266
if !Base.Filesystem.ispath(full_path)
252-
Base.Filesystem.mkpath(dirname(full_path))
253-
Base.Filesystem.symlink(
254-
joinpath(
255-
source_dir,
256-
"bazel-bin",
257-
"libReactantExtra.so.runfiles",
258-
"cuda_nvcc",
259-
path,
260-
),
261-
full_path,
267+
source = joinpath(
268+
source_dir,
269+
"bazel-bin",
270+
"libReactantExtra.so.runfiles",
271+
# libdevice's directory was moved in CUDA 13, before was in same
272+
# dir as ptxas and fatbinary
273+
if contains(basename(path), "libdevice") && build_backend.version >= v"13"
274+
"cuda_nvvm"
275+
else
276+
"cuda_nvcc"
277+
end,
278+
path,
262279
)
280+
if !Base.Filesystem.ispath(source)
281+
error(
282+
"File $(source) does not exist, are you sure it is where you expect it to be?",
283+
)
284+
end
285+
Base.Filesystem.mkpath(dirname(full_path))
286+
@info "Symlinking $(full_path) -> $(source)"
287+
Base.Filesystem.symlink(source, full_path)
263288
end
264289
end
265290
end

0 commit comments

Comments
 (0)