@@ -115,36 +115,50 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
115115# --@local_config_cuda//:cuda_compiler=nvcc
116116# --crosstool_top="@local_config_cuda//crosstool:toolchain"
117117
118- build_kind = parsed_args[" debug" ] ? " dbg" : " opt"
118+ abstract type AbstractBackend end
119+ struct CPUBackend <: AbstractBackend end
120+ struct CUDABackend <: AbstractBackend
121+ version:: VersionNumber
122+ CUDABackend (ver:: VersionNumber ) = new (VersionNumber (ver. major))
123+ end
119124
120- build_backend = parsed_args[" backend" ]
125+ function parse_build_backend (str:: String ):: AbstractBackend
126+ if str == " cpu"
127+ return CPUBackend ()
128+ elseif str == " cuda12"
129+ return CUDABackend (v " 12" )
130+ elseif str == " cuda13"
131+ return CUDABackend (v " 13" )
132+ end
121133
122- if build_backend == " auto" || build_backend == " cuda"
123- cuda_ver = get_cuda_version ()
124- @show cuda_ver
125- if cuda_ver === nothing
126- if build_backend == " cuda"
127- throw (
128- AssertionError (
129- " Could not detect cuda version, but requested cuda with auto version build" ,
130- ),
131- )
132- end
133- build_backend = " cpu"
134- else
135- if Int (get_cuda_version (). major) == 13
136- build_backend = " cuda13"
134+ if str in (" auto" , " cuda" )
135+ cuda_ver = get_cuda_version ()
136+ if isnothing (cuda_ver)
137+ if str == " cuda"
138+ throw (
139+ AssertionError (
140+ " Could not detect cuda version, but requested cuda with auto version build" ,
141+ ),
142+ )
143+ end
144+ return CPUBackend ()
137145 else
138- build_backend = " cuda12 "
146+ return CUDABackend ( get_cuda_version ())
139147 end
148+ else
149+ error (" Unknown backend '$(str) '" )
140150 end
141151end
142152
143- arg = if build_backend == " cuda12"
153+ build_kind = parsed_args[" debug" ] ? " dbg" : " opt"
154+
155+ build_backend = parse_build_backend (parsed_args[" backend" ])
156+
157+ arg = if build_backend == CUDABackend (v " 12" )
144158 " --config=cuda12"
145- elseif build_backend == " cuda13 "
159+ elseif build_backend == CUDABackend ( v " 13 " )
146160 " --config=cuda13"
147- elseif build_backend == " cpu "
161+ elseif build_backend == CPUBackend ()
148162 " "
149163else
150164 throw (AssertionError (" Unknown backend `$build_backend `" ))
@@ -231,6 +245,7 @@ push!(build_cmd_list, "--copt=-Wno-private-header")
231245push! (build_cmd_list, " --color=$(parsed_args[" color" ]) " )
232246push! (build_cmd_list, " :libReactantExtra.so" )
233247
248+ @info " About to run Bazel" build_cmd_list
234249run (Cmd (Cmd (build_cmd_list); dir= source_dir))
235250
236251# Discover built libraries
@@ -241,25 +256,35 @@ end
241256lib_path = joinpath (source_dir, " bazel-bin" , only (built_libs))
242257isfile (lib_path) || error (" Could not find library $lib_path in build directory" )
243258
244- if build_backend == " cuda "
259+ if build_backend isa CUDABackend
245260 for path in (
246261 joinpath (" bin" , " ptxas" ),
247262 joinpath (" bin" , " fatbinary" ),
248263 joinpath (" nvvm" , " libdevice" , " libdevice.10.bc" ),
249264 )
250265 full_path = joinpath (source_dir, " bazel-bin" , " cuda" , path)
251266 if ! Base. Filesystem. ispath (full_path)
252- Base. Filesystem. mkpath (dirname (full_path))
253- Base. Filesystem. symlink (
254- joinpath (
255- source_dir,
256- " bazel-bin" ,
257- " libReactantExtra.so.runfiles" ,
258- " cuda_nvcc" ,
259- path,
260- ),
261- full_path,
267+ source = joinpath (
268+ source_dir,
269+ " bazel-bin" ,
270+ " libReactantExtra.so.runfiles" ,
271+ # libdevice's directory was moved in CUDA 13, before was in same
272+ # dir as ptxas and fatbinary
273+ if contains (basename (path), " libdevice" ) && build_backend. version >= v " 13"
274+ " cuda_nvvm"
275+ else
276+ " cuda_nvcc"
277+ end ,
278+ path,
262279 )
280+ if ! Base. Filesystem. ispath (source)
281+ error (
282+ " File $(source) does not exist, are you sure it is where you expect it to be?" ,
283+ )
284+ end
285+ Base. Filesystem. mkpath (dirname (full_path))
286+ @info " Symlinking $(full_path) -> $(source) "
287+ Base. Filesystem. symlink (source, full_path)
263288 end
264289 end
265290end
0 commit comments