@@ -6,6 +6,46 @@ Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0")
66
77using ArgParse
88
9+ using Libdl
10+
11+ # adapted from `cudaRuntimeGetVersion` in CUDA_Runtime_jll
12+ function cuDriverGetVersion (library_handle)
13+ function_handle = Libdl. dlsym (library_handle, " cuDriverGetVersion" ; throw_error= false )
14+ if function_handle === nothing
15+ @debug " CUDA Driver library seems invalid (does not contain 'cuDriverGetVersion')"
16+ return nothing
17+ end
18+ version_ref = Ref {Cint} ()
19+ status = ccall (function_handle, Cint, (Ptr{Cint},), version_ref)
20+ if status != 0
21+ @debug " Call to 'cuDriverGetVersion' failed with status $(status) "
22+ return nothing
23+ end
24+ major, ver = divrem (version_ref[], 1000 )
25+ minor, patch = divrem (ver, 10 )
26+ version = VersionNumber (major, minor, patch)
27+ @debug " Detected CUDA Driver version $(version) "
28+ return version
29+ end
30+
31+ function get_cuda_version ()
32+ cuname = if Sys. iswindows ()
33+ Libdl. find_library (" nvcuda" )
34+ else
35+ Libdl. find_library ([" libcuda.so.1" , " libcuda.so" ])
36+ end
37+
38+ if cuname == " "
39+ return nothing
40+ end
41+
42+ handle = Libdl. dlopen (cuname)
43+ current_cuda_version = cuDriverGetVersion (handle)
44+ path = Libdl. dlpath (handle)
45+ Libdl. dlclose (handle)
46+ return current_cuda_version
47+ end
48+
949s = ArgParseSettings ()
1050# ! format: off
1151@add_arg_table! s begin
@@ -78,21 +118,32 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
78118build_kind = parsed_args[" debug" ] ? " dbg" : " opt"
79119
80120build_backend = parsed_args[" backend" ]
81- @assert build_backend in (" auto" , " cpu" , " cuda" )
82-
83- if build_backend == " auto"
84- build_backend = try
85- run (Cmd (` nvidia-smi` ))
86- " cuda"
87- catch
88- " cpu"
121+
122+ if build_backend == " auto" || build_backend == " cuda"
123+ cuda_ver = get_cuda_version ()
124+ @show cuda_ver
125+ if cuda_ver === nothing
126+ if build_backend == " cuda"
127+ throw (AssertionError (" Could not detect cuda version, but requested cuda with auto version build" ))
128+ end
129+ build_backend = " cpu"
130+ else
131+ if Int (get_cuda_version (). major) == 13
132+ build_backend = " cuda13"
133+ else
134+ build_backend = " cuda12"
135+ end
89136 end
90137end
91138
92- arg = if build_backend == " cuda"
93- " --config=cuda"
139+ arg = if build_backend == " cuda12"
140+ " --config=cuda12"
141+ elseif build_backend == " cuda13"
142+ " --config=cuda13"
94143elseif build_backend == " cpu"
95144 " "
145+ else
146+ throw (AssertionError (" Unknown backend `$build_backend `" ))
96147end
97148
98149bazel_cmd = if ! isnothing (Sys. which (" bazelisk" ))
0 commit comments