Skip to content

Commit dd5af93

Browse files
committed
upgrade offload dyn_ptr handling from C++ to mostly safe Rust
1 parent cdbbe9c commit dd5af93

File tree

4 files changed

+70
-106
lines changed

4 files changed

+70
-106
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 61 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use crate::errors::{
4444
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
4545
use crate::llvm::{self, DiagnosticInfo};
4646
use crate::type_::Type;
47-
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
47+
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util};
4848

4949
pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError {
5050
match llvm::last_error() {
@@ -653,83 +653,72 @@ pub(crate) unsafe fn llvm_optimize(
653653
None
654654
};
655655

656-
fn handle_offload(m: &llvm::Module, llcx: &llvm::Context, old_fn: &llvm::Value) {
657-
unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) };
658-
//unsafe {llvm::LLVMDumpModule(m);}
659-
//unsafe {
660-
// // Get the old function type
661-
// let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
662-
// dbg!(&old_fn_ty);
663-
// let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
664-
// dbg!(&old_param_count);
665-
666-
// // Get the old parameter types
667-
// let mut old_param_types = Vec::with_capacity(old_param_count as usize);
668-
// llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
669-
// old_param_types.set_len(old_param_count as usize);
670-
671-
// // Create the new parameter list, with ptr as the first argument
672-
// let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
673-
// let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
674-
// new_param_types.push(ptr_ty);
675-
// for old_param in old_param_types {
676-
// new_param_types.push(old_param);
677-
// }
678-
// dbg!(&new_param_types);
679-
680-
// // Create the new function type
681-
// let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
682-
// let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
683-
// dbg!(&new_fn_ty);
684-
685-
// // Create the new function
686-
// let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
687-
// //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
688-
// let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
689-
// let new_fn_cstr = CString::new(new_fn_name).unwrap();
690-
// let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
691-
// dbg!(&new_fn);
692-
// let a0 = llvm::LLVMGetParam(new_fn, 0);
693-
// llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
694-
// dbg!(&new_fn);
695-
696-
// // Move basic blocks
697-
// let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
698-
// //dbg!(&bb);
699-
// llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
700-
// //while !bb.is_null() {
701-
// // let next = llvm::LLVMGetNextBasicBlock(bb);
702-
// // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
703-
// // bb = next;
704-
// //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
705-
// let old_n = llvm::LLVMCountParams(old_fn);
706-
// for i in 0..old_n {
707-
// let old_arg = llvm::LLVMGetParam(old_fn, i);
708-
// let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
709-
// llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
710-
// }
711-
712-
// // Copy linkage and visibility
713-
// //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
714-
// //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
715-
716-
// // Replace all uses of old_fn with new_fn (RAUW)
717-
// llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
718-
719-
// // Optionally, remove the old function
720-
// llvm::LLVMDeleteFunction(old_fn);
721-
//}
656+
fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
657+
{
658+
let old_fn_ty = cx.get_type_of_global(old_fn);
659+
let old_param_types = cx.func_params_types(old_fn_ty);
660+
let old_param_count = old_param_types.len();
661+
if old_param_count == 0 {
662+
return;
663+
}
664+
665+
let first_param = llvm::get_param(old_fn, 0);
666+
let c_name = llvm::get_value_name(first_param);
667+
let first_arg_name = str::from_utf8(&c_name).unwrap();
668+
// We might call llvm_optimize (and thus this code) multiple times on the same IR,
669+
// but we shouldn't add this helper ptr multiple times.
670+
if first_arg_name == "dyn_ptr" {
671+
return;
672+
}
673+
674+
// Create the new parameter list, with ptr as the first argument
675+
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
676+
new_param_types.push(cx.type_ptr());
677+
for old_param in old_param_types {
678+
new_param_types.push(old_param);
679+
}
680+
681+
// Create the new function type
682+
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
683+
let new_fn_ty = cx.type_func(&new_param_types, ret_ty);
684+
685+
// Create the new function, with a temporary .offload name to avoid a name collision.
686+
let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
687+
let new_fn_name = format!("{}.offload", &old_fn_name);
688+
let new_fn = cx.add_func(&new_fn_name, new_fn_ty);
689+
let a0 = llvm::get_param(new_fn, 0);
690+
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());
691+
692+
// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
693+
// that we don't use the newly added `%dyn_ptr`.
694+
unsafe {
695+
llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn);
696+
}
697+
698+
llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
699+
llvm::set_visibility(new_fn, llvm::get_visibility(old_fn));
700+
701+
// Replace all uses of old_fn with new_fn (RAUW)
702+
unsafe {
703+
llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
704+
}
705+
let name = llvm::get_value_name(old_fn);
706+
unsafe {
707+
llvm::LLVMDeleteFunction(old_fn);
708+
}
709+
// Now we can re-use the old name, without name collision.
710+
llvm::set_value_name(new_fn, &name);
711+
}
722712
}
723713

724714
let consider_offload = config.offload.contains(&config::Offload::Enable);
725715
if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") {
716+
let cx =
717+
SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
726718
for num in 0..9 {
727719
let name = format!("kernel_{num}");
728-
let c_name = CString::new(name).unwrap();
729-
if let Some(kernel) =
730-
unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) }
731-
{
732-
handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel);
720+
if let Some(kernel) = cx.get_function(&name) {
721+
handle_offload(&cx, kernel);
733722
}
734723
}
735724
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,6 @@ unsafe extern "C" {
12241224

12251225
// Operations on basic blocks
12261226
pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
1227-
pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock);
12281227
pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
12291228
C: &'a Context,
12301229
Fn: &'a Value,
@@ -1898,7 +1897,7 @@ unsafe extern "C" {
18981897
) -> &Attribute;
18991898

19001899
// Operations on functions
1901-
pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value);
1900+
// Maps the body of `OldFn` into `NewFn`; the parameter names mirror the C++ definition of
// `LLVMRustOffloadMapper` in RustWrapper.cpp so source and destination cannot be confused.
pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, OldFn: &'a Value, NewFn: &'a Value);
19021901
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
19031902
M: &'a Module,
19041903
Name: *const c_char,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
6868
unsafe { llvm::LLVMVectorType(ty, len as c_uint) }
6969
}
7070

71+
pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value {
72+
let name = SmallCStr::new(name);
73+
unsafe { llvm::LLVMAddFunction(self.llmod(), name.as_ptr(), ty) }
74+
}
75+
7176
pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
7277
unsafe {
7378
let n_args = llvm::LLVMCountParamTypes(ty) as usize;

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -172,31 +172,10 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
172172
llvm::PrintStatistics(OS);
173173
}
174174

175-
extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
175+
extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn, LLVMValueRef NewFn) {
176176
llvm::Module *module = llvm::unwrap(M);
177-
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);
178-
179-
if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
180-
return;
181-
}
182-
183-
// 1. Create new function type with the leading extra %dyn_ptr arg which llvm
184-
// offload requires.
185-
llvm::LLVMContext &ctx = module->getContext();
186-
llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0);
187-
std::vector<llvm::Type *> argTypes;
188-
argTypes.push_back(dynPtrType);
189-
190-
for (auto &arg : oldFn->args()) {
191-
argTypes.push_back(arg.getType());
192-
}
193-
194-
llvm::FunctionType *newFnType = llvm::FunctionType::get(
195-
oldFn->getReturnType(), argTypes, oldFn->isVarArg());
196-
197-
// use a temporary .offload appendix to avoid name clashes
198-
llvm::Function *newFn = llvm::Function::Create(
199-
newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module);
177+
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
178+
llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);
200179

201180
// Map old arguments to new arguments. We skip the first dyn_ptr argument,
202181
// since it can't be used directly by user code.
@@ -212,14 +191,6 @@ extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
212191
llvm::CloneFunctionInto(newFn, oldFn, vmap,
213192
llvm::CloneFunctionChangeType::LocalChangesOnly,
214193
returns);
215-
newFn->setLinkage(oldFn->getLinkage());
216-
newFn->setVisibility(oldFn->getVisibility());
217-
218-
// Replace uses, delete old function, and reset name to the original one.
219-
oldFn->replaceAllUsesWith(newFn);
220-
auto name = oldFn->getName();
221-
oldFn->eraseFromParent();
222-
newFn->setName(name);
223194
}
224195

225196
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,

0 commit comments

Comments
 (0)