From 360b38cceb1d20cffb00057be2078cdf7fa0b25a Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Aug 2025 19:49:40 -0700 Subject: [PATCH] Fix device code generation, to account for an implicit dyn_ptr argument. --- compiler/rustc_codegen_llvm/src/back/lto.rs | 3 +- compiler/rustc_codegen_llvm/src/back/write.rs | 75 ++++++++++++++++++- .../src/builder/gpu_offload.rs | 5 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 6 ++ compiler/rustc_codegen_llvm/src/type_.rs | 5 ++ compiler/rustc_codegen_ssa/src/back/write.rs | 2 + .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 24 +++++++ compiler/rustc_target/src/callconv/mod.rs | 1 + compiler/rustc_target/src/spec/json.rs | 3 + compiler/rustc_target/src/spec/mod.rs | 8 +++ .../src/spec/targets/amdgcn_amd_amdhsa.rs | 3 + .../src/spec/targets/nvptx64_nvidia_cuda.rs | 3 + 12 files changed, 135 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 02b50fa8a6971..b820b992105fd 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -616,7 +616,8 @@ pub(crate) fn run_pass_manager( write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage); } - if enable_gpu && !thin { + // Here we only handle the GPU host (=cpu) code.
+ if enable_gpu && !thin && !cgcx.target_is_like_gpu { let cx = SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx); diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index b582d587d9f8a..5b71d6b6ba8e0 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -43,7 +43,7 @@ use crate::errors::{ use crate::llvm::diagnostic::OptimizationDiagnosticKind::*; use crate::llvm::{self, DiagnosticInfo}; use crate::type_::llvm_type_ptr; -use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util}; +use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util}; pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> ! { match llvm::last_error() { @@ -645,6 +645,79 @@ pub(crate) unsafe fn llvm_optimize( None }; + /// Gives the device-side kernel `old_fn` an extra leading `ptr` parameter named `dyn_ptr`. + /// + /// A new function with the widened signature is created, the old body is cloned into it, all + /// uses of `old_fn` are redirected to it, and it finally takes over the old name. Functions + /// without parameters, or whose first parameter is already named `dyn_ptr`, are left untouched. + fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) { + let old_fn_ty = cx.get_type_of_global(old_fn); + let old_param_types = cx.func_params_types(old_fn_ty); + let old_param_count = old_param_types.len(); + if old_param_count == 0 { + return; + } + + let first_param = llvm::get_param(old_fn, 0); + let c_name = llvm::get_value_name(first_param); + let first_arg_name = str::from_utf8(&c_name).unwrap(); + // We might call llvm_optimize (and thus this code) multiple times on the same IR, + // but we shouldn't add this helper ptr multiple times. + // FIXME(offload): This could break if the user calls their first argument `dyn_ptr`.
+ if first_arg_name == "dyn_ptr" { + return; + } + + // Create the new parameter list, with ptr as the first argument + let mut new_param_types = Vec::with_capacity(old_param_count + 1); + new_param_types.push(cx.type_ptr()); + new_param_types.extend(old_param_types); + + // Create the new function type + let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) }; + let new_fn_ty = cx.type_func(&new_param_types, ret_ty); + + // Create the new function, with a temporary .offload name to avoid a name collision. + let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap(); + let new_fn_name = format!("{old_fn_name}.offload"); + let new_fn = cx.add_func(&new_fn_name, new_fn_ty); + let a0 = llvm::get_param(new_fn, 0); + llvm::set_value_name(a0, b"dyn_ptr"); + + // Here we map the old arguments to the new arguments, with an offset of 1 to make sure + // that we don't use the newly added `%dyn_ptr`. + unsafe { + llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn); + } + + llvm::set_linkage(new_fn, llvm::get_linkage(old_fn)); + llvm::set_visibility(new_fn, llvm::get_visibility(old_fn)); + + // Replace all uses of old_fn with new_fn (RAUW) + unsafe { + llvm::LLVMReplaceAllUsesWith(old_fn, new_fn); + } + let name = llvm::get_value_name(old_fn); + unsafe { + llvm::LLVMDeleteFunction(old_fn); + } + // Now we can re-use the old name, without name collision. + llvm::set_value_name(new_fn, &name); + } + + if cgcx.target_is_like_gpu && config.offload.contains(&config::Offload::Enable) { + let cx = + SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); + // For now we only support up to 10 kernels named kernel_0 ... kernel_9, a follow-up PR is + // introducing a proper offload intrinsic to solve this limitation.
+ for num in 0..=9 { + let name = format!("kernel_{num}"); + if let Some(kernel) = cx.get_function(&name) { + handle_offload(&cx, kernel); + } + } + } + let mut llvm_profiler = cgcx .prof .llvm_recording_enabled() diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 3d55064ea1304..5c2f8f700627e 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -19,6 +19,9 @@ pub(crate) fn handle_gpu_code<'ll>( let mut memtransfer_types = vec![]; let mut region_ids = vec![]; let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); + // This is a temporary hack, we only search for kernel_0 to kernel_9 functions. + // There is a draft PR in progress which will introduce a proper offload intrinsic to remove + // this limitation. - for num in 0..9 { + for num in 0..=9 { let kernel = cx.get_function(&format!("kernel_{num}")); if let Some(kernel) = kernel { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 9a391d57d6fb4..74d268ad5dd2e 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1127,6 +1127,11 @@ unsafe extern "C" { // Operations on functions pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); + pub(crate) fn LLVMAddFunction<'a>( + Mod: &'a Module, + Name: *const c_char, + FunctionTy: &'a Type, + ) -> &'a Value; pub(crate) fn LLVMDeleteFunction(Fn: &Value); // Operations about llvm intrinsics @@ -2017,6 +2022,7 @@ unsafe extern "C" { ) -> &Attribute; // Operations on functions + pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, OldFn: &'a Value, NewFn: &'a Value); pub(crate) fn LLVMRustGetOrInsertFunction<'a>( M: &'a Module, Name: *const c_char, diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index 81bb70c958790..55f053f4fad3f 100644 ---
a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -68,6 +68,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { unsafe { llvm::LLVMVectorType(ty, len as c_uint) } } + pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value { + let name = SmallCStr::new(name); + unsafe { llvm::LLVMAddFunction(self.llmod(), name.as_ptr(), ty) } + } + pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> { unsafe { let n_args = llvm::LLVMCountParamTypes(ty) as usize; diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 368a2e307bb27..edaf65bdb9222 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -342,6 +342,7 @@ pub struct CodegenContext { pub target_arch: String, pub target_is_like_darwin: bool, pub target_is_like_aix: bool, + pub target_is_like_gpu: bool, pub split_debuginfo: rustc_target::spec::SplitDebuginfo, pub split_dwarf_kind: rustc_session::config::SplitDwarfKind, pub pointer_size: Size, @@ -1309,6 +1310,7 @@ fn start_executing_work( target_arch: tcx.sess.target.arch.to_string(), target_is_like_darwin: tcx.sess.target.is_like_darwin, target_is_like_aix: tcx.sess.target.is_like_aix, + target_is_like_gpu: tcx.sess.target.is_like_gpu, split_debuginfo: tcx.sess.split_debuginfo(), split_dwarf_kind: tcx.sess.opts.unstable_opts.split_dwarf_kind, parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend, diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 2d87ea232eea2..df811ddd8d4fc 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -35,6 +35,8 @@ #include "llvm/Support/Signals.h" #include "llvm/Support/Timer.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include
"llvm/Transforms/Utils/ValueMapper.h" #include <iostream> // for raw `write` in the bad-alloc handler @@ -142,6 +144,28 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) { llvm::PrintStatistics(OS); } +extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn, + LLVMValueRef NewFn) { + llvm::Module *module = llvm::unwrap(M); + llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn); + llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn); + + // Map old arguments to new arguments. We skip the first dyn_ptr argument, + // since it can't be used directly by user code. + llvm::ValueToValueMapTy vmap; + auto newArgIt = newFn->arg_begin(); + newArgIt->setName("dyn_ptr"); + ++newArgIt; // skip %dyn_ptr + for (auto &oldArg : oldFn->args()) { + vmap[&oldArg] = &*newArgIt++; + } + + llvm::SmallVector<llvm::ReturnInst *, 8> returns; + llvm::CloneFunctionInto(newFn, oldFn, vmap, + llvm::CloneFunctionChangeType::LocalChangesOnly, + returns); +} + extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name, size_t NameLen) { return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen))); diff --git a/compiler/rustc_target/src/callconv/mod.rs b/compiler/rustc_target/src/callconv/mod.rs index 43e1ca3ef9cee..147b17b24bb57 100644 --- a/compiler/rustc_target/src/callconv/mod.rs +++ b/compiler/rustc_target/src/callconv/mod.rs @@ -578,6 +578,7 @@ impl RiscvInterruptKind { /// /// The signature represented by this type may not match the MIR function signature. /// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`. +/// The std::offload module also adds an additional dyn_ptr argument to the GpuKernel ABI. /// While this difference is rarely relevant, it should still be kept in mind.
/// /// I will do my best to describe this structure, but these diff --git a/compiler/rustc_target/src/spec/json.rs b/compiler/rustc_target/src/spec/json.rs index c25628c3939db..563ba0c4131ae 100644 --- a/compiler/rustc_target/src/spec/json.rs +++ b/compiler/rustc_target/src/spec/json.rs @@ -147,6 +147,7 @@ impl Target { forward!(is_like_darwin); forward!(is_like_solaris); forward!(is_like_windows); + forward!(is_like_gpu); forward!(is_like_msvc); forward!(is_like_wasm); forward!(is_like_android); @@ -337,6 +338,7 @@ impl ToJson for Target { target_option_val!(is_like_darwin); target_option_val!(is_like_solaris); target_option_val!(is_like_windows); + target_option_val!(is_like_gpu); target_option_val!(is_like_msvc); target_option_val!(is_like_wasm); target_option_val!(is_like_android); @@ -556,6 +558,7 @@ struct TargetSpecJson { is_like_darwin: Option<bool>, is_like_solaris: Option<bool>, is_like_windows: Option<bool>, + is_like_gpu: Option<bool>, is_like_msvc: Option<bool>, is_like_wasm: Option<bool>, is_like_android: Option<bool>, diff --git a/compiler/rustc_target/src/spec/mod.rs b/compiler/rustc_target/src/spec/mod.rs index 74048d351802a..5d8ef47efe31b 100644 --- a/compiler/rustc_target/src/spec/mod.rs +++ b/compiler/rustc_target/src/spec/mod.rs @@ -2180,6 +2180,8 @@ pub struct TargetOptions { /// Also indicates whether to use Apple-specific ABI changes, such as extending function /// parameters to 32-bits. pub is_like_darwin: bool, + /// Whether the target is a GPU (e.g. NVIDIA, AMD, Intel). + pub is_like_gpu: bool, /// Whether the target toolchain is like Solaris's. /// Only useful for compiling against Illumos/Solaris, /// as they have a different set of linker flags. Defaults to false.
@@ -2583,6 +2585,7 @@ impl Default for TargetOptions { abi_return_struct_as_int: false, is_like_aix: false, is_like_darwin: false, + is_like_gpu: false, is_like_solaris: false, is_like_windows: false, is_like_msvc: false, @@ -2748,6 +2751,11 @@ impl Target { self.os == "solaris" || self.os == "illumos", "`is_like_solaris` must be set if and only if `os` is `solaris` or `illumos`" ); + check_eq!( + self.is_like_gpu, + self.arch == Arch::Nvptx64 || self.arch == Arch::AmdGpu, + "`is_like_gpu` must be set if and only if `target` is `nvptx64` or `amdgcn`" + ); check_eq!( self.is_like_windows, self.os == "windows" || self.os == "uefi" || self.os == "cygwin", diff --git a/compiler/rustc_target/src/spec/targets/amdgcn_amd_amdhsa.rs b/compiler/rustc_target/src/spec/targets/amdgcn_amd_amdhsa.rs index 07772c7573377..d80a3ffd0c7fd 100644 --- a/compiler/rustc_target/src/spec/targets/amdgcn_amd_amdhsa.rs +++ b/compiler/rustc_target/src/spec/targets/amdgcn_amd_amdhsa.rs @@ -34,6 +34,9 @@ pub(crate) fn target() -> Target { no_builtins: true, simd_types_indirect: false, + // Clearly a GPU + is_like_gpu: true, + // Allow `cdylib` crate type. dynamic_linking: true, only_cdylib: true, diff --git a/compiler/rustc_target/src/spec/targets/nvptx64_nvidia_cuda.rs b/compiler/rustc_target/src/spec/targets/nvptx64_nvidia_cuda.rs index ac2d31a0d61aa..5bbf40b5fadd7 100644 --- a/compiler/rustc_target/src/spec/targets/nvptx64_nvidia_cuda.rs +++ b/compiler/rustc_target/src/spec/targets/nvptx64_nvidia_cuda.rs @@ -42,6 +42,9 @@ pub(crate) fn target() -> Target { // Let the `ptx-linker` to handle LLVM lowering into MC / assembly. obj_is_bitcode: true, + // Clearly a GPU + is_like_gpu: true, + // Convenient and predicable naming scheme. dll_prefix: "".into(), dll_suffix: ".ptx".into(),