Skip to content

Commit d2681f5

Browse files
committed
Fix local jll more
1 parent 750e3e4 commit d2681f5

File tree

1 file changed

+61
-10
lines changed

1 file changed

+61
-10
lines changed

deps/build_local.jl

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,46 @@ Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0")
66

77
using ArgParse
88

9+
using Libdl
10+
11+
# adapted from `cudaRuntimeGetVersion` in CUDA_Runtime_jll
12+
function cuDriverGetVersion(library_handle)
13+
function_handle = Libdl.dlsym(library_handle, "cuDriverGetVersion"; throw_error=false)
14+
if function_handle === nothing
15+
@debug "CUDA Driver library seems invalid (does not contain 'cuDriverGetVersion')"
16+
return nothing
17+
end
18+
version_ref = Ref{Cint}()
19+
status = ccall(function_handle, Cint, (Ptr{Cint},), version_ref)
20+
if status != 0
21+
@debug "Call to 'cuDriverGetVersion' failed with status $(status)"
22+
return nothing
23+
end
24+
major, ver = divrem(version_ref[], 1000)
25+
minor, patch = divrem(ver, 10)
26+
version = VersionNumber(major, minor, patch)
27+
@debug "Detected CUDA Driver version $(version)"
28+
return version
29+
end
30+
31+
function get_cuda_version()
32+
cuname = if Sys.iswindows()
33+
Libdl.find_library("nvcuda")
34+
else
35+
Libdl.find_library(["libcuda.so.1", "libcuda.so"])
36+
end
37+
38+
if cuname == ""
39+
return nothing
40+
end
41+
42+
handle = Libdl.dlopen(cuname)
43+
current_cuda_version = cuDriverGetVersion(handle)
44+
path = Libdl.dlpath(handle)
45+
Libdl.dlclose(handle)
46+
return current_cuda_version
47+
end
48+
949
s = ArgParseSettings()
1050
#! format: off
1151
@add_arg_table! s begin
@@ -78,21 +118,32 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
78118
build_kind = parsed_args["debug"] ? "dbg" : "opt"
79119

80120
build_backend = parsed_args["backend"]
81-
@assert build_backend in ("auto", "cpu", "cuda")
82-
83-
if build_backend == "auto"
84-
build_backend = try
85-
run(Cmd(`nvidia-smi`))
86-
"cuda"
87-
catch
88-
"cpu"
121+
122+
if build_backend == "auto" || build_backend == "cuda"
123+
cuda_ver = get_cuda_version()
124+
@show cuda_ver
125+
if cuda_ver === nothing
126+
if build_backend == "cuda"
127+
throw(AssertionError("Could not detect cuda version, but requested cuda with auto version build"))
128+
end
129+
build_backend = "cpu"
130+
else
131+
if Int(get_cuda_version().major) == 13
132+
build_backend = "cuda13"
133+
else
134+
build_backend = "cuda12"
135+
end
89136
end
90137
end
91138

92-
arg = if build_backend == "cuda"
93-
"--config=cuda"
139+
arg = if build_backend == "cuda12"
140+
"--config=cuda12"
141+
elseif build_backend == "cuda13"
142+
"--config=cuda13"
94143
elseif build_backend == "cpu"
95144
""
145+
else
146+
throw(AssertionError("Unknown backend `$build_backend`"))
96147
end
97148

98149
bazel_cmd = if !isnothing(Sys.which("bazelisk"))

0 commit comments

Comments
 (0)