Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,8 @@ pub(crate) fn run_pass_manager(
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage);
}

if enable_gpu && !thin {
// Here we only handle the GPU host (=cpu) code.
if enable_gpu && !thin && !cgcx.target_is_like_gpu {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
Expand Down
70 changes: 69 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::errors::{
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
use crate::llvm::{self, DiagnosticInfo};
use crate::type_::llvm_type_ptr;
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util};

pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> ! {
match llvm::last_error() {
Expand Down Expand Up @@ -645,6 +645,74 @@ pub(crate) unsafe fn llvm_optimize(
None
};

fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
let old_fn_ty = cx.get_type_of_global(old_fn);
let old_param_types = cx.func_params_types(old_fn_ty);
let old_param_count = old_param_types.len();
if old_param_count == 0 {
return;
}

let first_param = llvm::get_param(old_fn, 0);
let c_name = llvm::get_value_name(first_param);
let first_arg_name = str::from_utf8(&c_name).unwrap();
// We might call llvm_optimize (and thus this code) multiple times on the same IR,
// but we shouldn't add this helper ptr multiple times.
// FIXME(offload): This could break if the user calls his first argument `dyn_ptr`.
if first_arg_name == "dyn_ptr" {
return;
}

// Create the new parameter list, with ptr as the first argument
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
new_param_types.push(cx.type_ptr());
new_param_types.extend(old_param_types);

// Create the new function type
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
let new_fn_ty = cx.type_func(&new_param_types, ret_ty);

// Create the new function, with a temporary .offload name to avoid a name collision.
let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
let new_fn_name = format!("{}.offload", &old_fn_name);
let new_fn = cx.add_func(&new_fn_name, new_fn_ty);
let a0 = llvm::get_param(new_fn, 0);
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());

// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
// that we don't use the newly added `%dyn_ptr`.
unsafe {
llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn);
}

llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
llvm::set_visibility(new_fn, llvm::get_visibility(old_fn));

// Replace all uses of old_fn with new_fn (RAUW)
unsafe {
llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
}
let name = llvm::get_value_name(old_fn);
unsafe {
llvm::LLVMDeleteFunction(old_fn);
}
// Now we can re-use the old name, without name collision.
llvm::set_value_name(new_fn, &name);
}

if cgcx.target_is_like_gpu && config.offload.contains(&config::Offload::Enable) {
let cx =
SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
// For now we only support up to 10 kernels named kernel_0 ... kernel_9, a follow-up PR is
// introducing a proper offload intrinsic to solve this limitation.
for num in 0..9 {
let name = format!("kernel_{num}");
if let Some(kernel) = cx.get_function(&name) {
handle_offload(&cx, kernel);
}
}
}

let mut llvm_profiler = cgcx
.prof
.llvm_recording_enabled()
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub(crate) fn handle_gpu_code<'ll>(
let mut memtransfer_types = vec![];
let mut region_ids = vec![];
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
// This is a temporary hack, we only search for kernel_0 to kernel_9 functions.
// There is a draft PR in progress which will introduce a proper offload intrinsic to remove
// this limitation.
for num in 0..9 {
let kernel = cx.get_function(&format!("kernel_{num}"));
if let Some(kernel) = kernel {
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,11 @@ unsafe extern "C" {

// Operations on functions
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
pub(crate) fn LLVMAddFunction<'a>(
Mod: &'a Module,
Name: *const c_char,
FunctionTy: &'a Type,
) -> &'a Value;
pub(crate) fn LLVMDeleteFunction(Fn: &Value);

// Operations about llvm intrinsics
Expand Down Expand Up @@ -2017,6 +2022,7 @@ unsafe extern "C" {
) -> &Attribute;

// Operations on functions
pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, Fn: &'a Value, Fn: &'a Value);
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
M: &'a Module,
Name: *const c_char,
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_llvm/src/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
unsafe { llvm::LLVMVectorType(ty, len as c_uint) }
}

pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value {
let name = SmallCStr::new(name);
unsafe { llvm::LLVMAddFunction(self.llmod(), name.as_ptr(), ty) }
}

pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
unsafe {
let n_args = llvm::LLVMCountParamTypes(ty) as usize;
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ pub struct CodegenContext<B: WriteBackendMethods> {
pub target_arch: String,
pub target_is_like_darwin: bool,
pub target_is_like_aix: bool,
pub target_is_like_gpu: bool,
pub split_debuginfo: rustc_target::spec::SplitDebuginfo,
pub split_dwarf_kind: rustc_session::config::SplitDwarfKind,
pub pointer_size: Size,
Expand Down Expand Up @@ -1309,6 +1310,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
target_arch: tcx.sess.target.arch.to_string(),
target_is_like_darwin: tcx.sess.target.is_like_darwin,
target_is_like_aix: tcx.sess.target.is_like_aix,
target_is_like_gpu: tcx.sess.target.is_like_gpu,
split_debuginfo: tcx.sess.split_debuginfo(),
split_dwarf_kind: tcx.sess.opts.unstable_opts.split_dwarf_kind,
parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend,
Expand Down
24 changes: 24 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "llvm/Support/Signals.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <iostream>

// for raw `write` in the bad-alloc handler
Expand Down Expand Up @@ -142,6 +144,28 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
llvm::PrintStatistics(OS);
}

extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn,
LLVMValueRef NewFn) {
llvm::Module *module = llvm::unwrap(M);
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);

// Map old arguments to new arguments. We skip the first dyn_ptr argument,
// since it can't be used directly by user code.
llvm::ValueToValueMapTy vmap;
auto newArgIt = newFn->arg_begin();
newArgIt->setName("dyn_ptr");
++newArgIt; // skip %dyn_ptr
for (auto &oldArg : oldFn->args()) {
vmap[&oldArg] = &*newArgIt++;
}

llvm::SmallVector<llvm::ReturnInst *, 8> returns;
llvm::CloneFunctionInto(newFn, oldFn, vmap,
llvm::CloneFunctionChangeType::LocalChangesOnly,
returns);
}

extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
size_t NameLen) {
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_target/src/callconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ impl RiscvInterruptKind {
///
/// The signature represented by this type may not match the MIR function signature.
/// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`.
/// The std::offload module also adds an addition dyn_ptr argument to the GpuKernel ABI.
/// While this difference is rarely relevant, it should still be kept in mind.
///
/// I will do my best to describe this structure, but these
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_target/src/spec/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ impl Target {
forward!(is_like_darwin);
forward!(is_like_solaris);
forward!(is_like_windows);
forward!(is_like_gpu);
forward!(is_like_msvc);
forward!(is_like_wasm);
forward!(is_like_android);
Expand Down Expand Up @@ -337,6 +338,7 @@ impl ToJson for Target {
target_option_val!(is_like_darwin);
target_option_val!(is_like_solaris);
target_option_val!(is_like_windows);
target_option_val!(is_like_gpu);
target_option_val!(is_like_msvc);
target_option_val!(is_like_wasm);
target_option_val!(is_like_android);
Expand Down Expand Up @@ -556,6 +558,7 @@ struct TargetSpecJson {
is_like_darwin: Option<bool>,
is_like_solaris: Option<bool>,
is_like_windows: Option<bool>,
is_like_gpu: Option<bool>,
is_like_msvc: Option<bool>,
is_like_wasm: Option<bool>,
is_like_android: Option<bool>,
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_target/src/spec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,8 @@ pub struct TargetOptions {
/// Also indicates whether to use Apple-specific ABI changes, such as extending function
/// parameters to 32-bits.
pub is_like_darwin: bool,
/// Whether the target is a GPU (e.g. NVIDIA, AMD, Intel).
pub is_like_gpu: bool,
/// Whether the target toolchain is like Solaris's.
/// Only useful for compiling against Illumos/Solaris,
/// as they have a different set of linker flags. Defaults to false.
Expand Down Expand Up @@ -2583,6 +2585,7 @@ impl Default for TargetOptions {
abi_return_struct_as_int: false,
is_like_aix: false,
is_like_darwin: false,
is_like_gpu: false,
is_like_solaris: false,
is_like_windows: false,
is_like_msvc: false,
Expand Down Expand Up @@ -2748,6 +2751,11 @@ impl Target {
self.os == "solaris" || self.os == "illumos",
"`is_like_solaris` must be set if and only if `os` is `solaris` or `illumos`"
);
check_eq!(
self.is_like_gpu,
self.arch == Arch::Nvptx64 || self.arch == Arch::AmdGpu,
"`is_like_gpu` must be set if and only if `target` is `nvptx64` or `amdgcn`"
);
check_eq!(
self.is_like_windows,
self.os == "windows" || self.os == "uefi" || self.os == "cygwin",
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_target/src/spec/targets/amdgcn_amd_amdhsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub(crate) fn target() -> Target {
no_builtins: true,
simd_types_indirect: false,

// Clearly a GPU
is_like_gpu: true,

// Allow `cdylib` crate type.
dynamic_linking: true,
only_cdylib: true,
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_target/src/spec/targets/nvptx64_nvidia_cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ pub(crate) fn target() -> Target {
// Let the `ptx-linker` to handle LLVM lowering into MC / assembly.
obj_is_bitcode: true,

// Clearly a GPU
is_like_gpu: true,

// Convenient and predicable naming scheme.
dll_prefix: "".into(),
dll_suffix: ".ptx".into(),
Expand Down
Loading