Skip to content

Commit 0637c67

Browse files
committed
upgrade offload dyn_ptr handling from C++ to mostly safe Rust
1 parent a31e7cd commit 0637c67

File tree

4 files changed

+68
-42
lines changed

4 files changed

+68
-42
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use crate::errors::{
4545
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
4646
use crate::llvm::{self, DiagnosticInfo};
4747
use crate::type_::llvm_type_ptr;
48-
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
48+
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util};
4949

5050
pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> ! {
5151
match llvm::last_error() {
@@ -658,19 +658,70 @@ pub(crate) unsafe fn llvm_optimize(
658658
None
659659
};
660660

661-
fn handle_offload(m: &llvm::Module, llcx: &llvm::Context, old_fn: &llvm::Value) {
662-
unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) };
661+
fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
662+
let old_fn_ty = cx.get_type_of_global(old_fn);
663+
let old_param_types = cx.func_params_types(old_fn_ty);
664+
let old_param_count = old_param_types.len();
665+
if old_param_count == 0 {
666+
return;
667+
}
668+
669+
let first_param = llvm::get_param(old_fn, 0);
670+
let c_name = llvm::get_value_name(first_param);
671+
let first_arg_name = str::from_utf8(&c_name).unwrap();
672+
// We might call llvm_optimize (and thus this code) multiple times on the same IR,
673+
// but we shouldn't add this helper ptr multiple times.
674+
if first_arg_name == "dyn_ptr" {
675+
return;
676+
}
677+
678+
// Create the new parameter list, with ptr as the first argument
679+
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
680+
new_param_types.push(cx.type_ptr());
681+
for old_param in old_param_types {
682+
new_param_types.push(old_param);
683+
}
684+
685+
// Create the new function type
686+
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
687+
let new_fn_ty = cx.type_func(&new_param_types, ret_ty);
688+
689+
// Create the new function, with a temporary .offload name to avoid a name collision.
690+
let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
691+
let new_fn_name = format!("{}.offload", &old_fn_name);
692+
let new_fn = cx.add_func(&new_fn_name, new_fn_ty);
693+
let a0 = llvm::get_param(new_fn, 0);
694+
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());
695+
696+
// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
697+
// that we don't use the newly added `%dyn_ptr`.
698+
unsafe {
699+
llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn);
700+
}
701+
702+
llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
703+
llvm::set_visibility(new_fn, llvm::get_visibility(old_fn));
704+
705+
// Replace all uses of old_fn with new_fn (RAUW)
706+
unsafe {
707+
llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
708+
}
709+
let name = llvm::get_value_name(old_fn);
710+
unsafe {
711+
llvm::LLVMDeleteFunction(old_fn);
712+
}
713+
// Now we can re-use the old name, without name collision.
714+
llvm::set_value_name(new_fn, &name);
663715
}
664716

665717
let consider_offload = config.offload.contains(&config::Offload::Enable);
666718
if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") {
719+
let cx =
720+
SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
667721
for num in 0..9 {
668722
let name = format!("kernel_{num}");
669-
let c_name = CString::new(name).unwrap();
670-
if let Some(kernel) =
671-
unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) }
672-
{
673-
handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel);
723+
if let Some(kernel) = cx.get_function(&name) {
724+
handle_offload(&cx, kernel);
674725
}
675726
}
676727
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,6 @@ unsafe extern "C" {
11401140

11411141
// Operations on basic blocks
11421142
pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
1143-
pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock);
11441143
pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
11451144
C: &'a Context,
11461145
Fn: &'a Value,
@@ -2013,7 +2012,7 @@ unsafe extern "C" {
20132012
) -> &Attribute;
20142013

20152014
// Operations on functions
2016-
pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value);
2015+
pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, Fn: &'a Value, Fn: &'a Value);
20172016
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
20182017
M: &'a Module,
20192018
Name: *const c_char,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
6868
unsafe { llvm::LLVMVectorType(ty, len as c_uint) }
6969
}
7070

71+
pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value {
72+
let name = SmallCStr::new(name);
73+
unsafe { llvm::LLVMAddFunction(self.llmod(), name.as_ptr(), ty) }
74+
}
75+
7176
pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
7277
unsafe {
7378
let n_args = llvm::LLVMCountParamTypes(ty) as usize;

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -144,31 +144,10 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
144144
llvm::PrintStatistics(OS);
145145
}
146146

147-
extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
147+
extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn, LLVMValueRef NewFn) {
148148
llvm::Module *module = llvm::unwrap(M);
149-
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);
150-
151-
if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
152-
return;
153-
}
154-
155-
// 1. Create new function type with the leading extra %dyn_ptr arg which llvm
156-
// offload requries.
157-
llvm::LLVMContext &ctx = module->getContext();
158-
llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0);
159-
std::vector<llvm::Type *> argTypes;
160-
argTypes.push_back(dynPtrType);
161-
162-
for (auto &arg : oldFn->args()) {
163-
argTypes.push_back(arg.getType());
164-
}
165-
166-
llvm::FunctionType *newFnType = llvm::FunctionType::get(
167-
oldFn->getReturnType(), argTypes, oldFn->isVarArg());
168-
169-
// use a temporary .offload appendix to avoid name clashes
170-
llvm::Function *newFn = llvm::Function::Create(
171-
newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module);
149+
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
150+
llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);
172151

173152
// Map old arguments to new arguments. We skip the first dyn_ptr argument,
174153
// since it can't be used directly by user code.
@@ -184,14 +163,6 @@ extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
184163
llvm::CloneFunctionInto(newFn, oldFn, vmap,
185164
llvm::CloneFunctionChangeType::LocalChangesOnly,
186165
returns);
187-
newFn->setLinkage(oldFn->getLinkage());
188-
newFn->setVisibility(oldFn->getVisibility());
189-
190-
// Replace uses, delete old function, and reset name to the original one.
191-
oldFn->replaceAllUsesWith(newFn);
192-
auto name = oldFn->getName();
193-
oldFn->eraseFromParent();
194-
newFn->setName(name);
195166
}
196167

197168
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,

0 commit comments

Comments
 (0)