Skip to content

Commit ee30fc9

Browse files
committed
upgrade offload dyn_ptr handling from C++ to mostly safe Rust
1 parent 8a40683 commit ee30fc9

File tree

4 files changed

+70
-106
lines changed

4 files changed

+70
-106
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 61 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use crate::errors::{
4545
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
4646
use crate::llvm::{self, DiagnosticInfo};
4747
use crate::type_::llvm_type_ptr;
48-
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
48+
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util};
4949

5050
pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> ! {
5151
match llvm::last_error() {
@@ -658,83 +658,72 @@ pub(crate) unsafe fn llvm_optimize(
658658
None
659659
};
660660

661-
fn handle_offload(m: &llvm::Module, llcx: &llvm::Context, old_fn: &llvm::Value) {
662-
unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) };
663-
//unsafe {llvm::LLVMDumpModule(m);}
664-
//unsafe {
665-
// // Get the old function type
666-
// let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
667-
// dbg!(&old_fn_ty);
668-
// let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
669-
// dbg!(&old_param_count);
670-
671-
// // Get the old parameter types
672-
// let mut old_param_types = Vec::with_capacity(old_param_count as usize);
673-
// llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
674-
// old_param_types.set_len(old_param_count as usize);
675-
676-
// // Create the new parameter list, with ptr as the first argument
677-
// let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
678-
// let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
679-
// new_param_types.push(ptr_ty);
680-
// for old_param in old_param_types {
681-
// new_param_types.push(old_param);
682-
// }
683-
// dbg!(&new_param_types);
684-
685-
// // Create the new function type
686-
// let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
687-
// let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
688-
// dbg!(&new_fn_ty);
689-
690-
// // Create the new function
691-
// let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
692-
// //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
693-
// let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
694-
// let new_fn_cstr = CString::new(new_fn_name).unwrap();
695-
// let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
696-
// dbg!(&new_fn);
697-
// let a0 = llvm::LLVMGetParam(new_fn, 0);
698-
// llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
699-
// dbg!(&new_fn);
700-
701-
// // Move basic blocks
702-
// let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
703-
// //dbg!(&bb);
704-
// llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
705-
// //while !bb.is_null() {
706-
// // let next = llvm::LLVMGetNextBasicBlock(bb);
707-
// // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
708-
// // bb = next;
709-
// //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
710-
// let old_n = llvm::LLVMCountParams(old_fn);
711-
// for i in 0..old_n {
712-
// let old_arg = llvm::LLVMGetParam(old_fn, i);
713-
// let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
714-
// llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
715-
// }
716-
717-
// // Copy linkage and visibility
718-
// //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
719-
// //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
720-
721-
// // Replace all uses of old_fn with new_fn (RAUW)
722-
// llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
723-
724-
// // Optionally, remove the old function
725-
// llvm::LLVMDeleteFunction(old_fn);
726-
//}
661+
fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
662+
{
663+
let old_fn_ty = cx.get_type_of_global(old_fn);
664+
let old_param_types = cx.func_params_types(old_fn_ty);
665+
let old_param_count = old_param_types.len();
666+
if old_param_count == 0 {
667+
return;
668+
}
669+
670+
let first_param = llvm::get_param(old_fn, 0);
671+
let c_name = llvm::get_value_name(first_param);
672+
let first_arg_name = str::from_utf8(&c_name).unwrap();
673+
// We might call llvm_optimize (and thus this code) multiple times on the same IR,
674+
// but we shouldn't add this helper ptr multiple times.
675+
if first_arg_name == "dyn_ptr" {
676+
return;
677+
}
678+
679+
// Create the new parameter list, with ptr as the first argument
680+
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
681+
new_param_types.push(cx.type_ptr());
682+
for old_param in old_param_types {
683+
new_param_types.push(old_param);
684+
}
685+
686+
// Create the new function type
687+
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
688+
let new_fn_ty = cx.type_func(&new_param_types, ret_ty);
689+
690+
// Create the new function, with a temporary .offload name to avoid a name collision.
691+
let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
692+
let new_fn_name = format!("{}.offload", &old_fn_name);
693+
let new_fn = cx.add_func(&new_fn_name, new_fn_ty);
694+
let a0 = llvm::get_param(new_fn, 0);
695+
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());
696+
697+
// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
698+
// that we don't use the newly added `%dyn_ptr`.
699+
unsafe {
700+
llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn);
701+
}
702+
703+
llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
704+
llvm::set_visibility(new_fn, llvm::get_visibility(old_fn));
705+
706+
// Replace all uses of old_fn with new_fn (RAUW)
707+
unsafe {
708+
llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
709+
}
710+
let name = llvm::get_value_name(old_fn);
711+
unsafe {
712+
llvm::LLVMDeleteFunction(old_fn);
713+
}
714+
// Now we can re-use the old name, without name collision.
715+
llvm::set_value_name(new_fn, &name);
716+
}
727717
}
728718

729719
let consider_offload = config.offload.contains(&config::Offload::Enable);
730720
if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") {
721+
let cx =
722+
SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
731723
for num in 0..9 {
732724
let name = format!("kernel_{num}");
733-
let c_name = CString::new(name).unwrap();
734-
if let Some(kernel) =
735-
unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) }
736-
{
737-
handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel);
725+
if let Some(kernel) = cx.get_function(&name) {
726+
handle_offload(&cx, kernel);
738727
}
739728
}
740729
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,6 @@ unsafe extern "C" {
11401140

11411141
// Operations on basic blocks
11421142
pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
1143-
pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock);
11441143
pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
11451144
C: &'a Context,
11461145
Fn: &'a Value,
@@ -2013,7 +2012,7 @@ unsafe extern "C" {
20132012
) -> &Attribute;
20142013

20152014
// Operations on functions
2016-
pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value);
2015+
// Maps the arguments and body of `OldFn` into `NewFn`; parameter names mirror the C++
// definition of `LLVMRustOffloadMapper` in RustWrapper.cpp.
pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, OldFn: &'a Value, NewFn: &'a Value);
20172016
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
20182017
M: &'a Module,
20192018
Name: *const c_char,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
6868
unsafe { llvm::LLVMVectorType(ty, len as c_uint) }
6969
}
7070

71+
pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value {
72+
let name = SmallCStr::new(name);
73+
unsafe { llvm::LLVMAddFunction(self.llmod(), name.as_ptr(), ty) }
74+
}
75+
7176
pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
7277
unsafe {
7378
let n_args = llvm::LLVMCountParamTypes(ty) as usize;

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -144,31 +144,10 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
144144
llvm::PrintStatistics(OS);
145145
}
146146

147-
extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
147+
extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn, LLVMValueRef NewFn) {
148148
llvm::Module *module = llvm::unwrap(M);
149-
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);
150-
151-
if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
152-
return;
153-
}
154-
155-
// 1. Create new function type with the leading extra %dyn_ptr arg which llvm
156-
// offload requries.
157-
llvm::LLVMContext &ctx = module->getContext();
158-
llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0);
159-
std::vector<llvm::Type *> argTypes;
160-
argTypes.push_back(dynPtrType);
161-
162-
for (auto &arg : oldFn->args()) {
163-
argTypes.push_back(arg.getType());
164-
}
165-
166-
llvm::FunctionType *newFnType = llvm::FunctionType::get(
167-
oldFn->getReturnType(), argTypes, oldFn->isVarArg());
168-
169-
// use a temporary .offload appendix to avoid name clashes
170-
llvm::Function *newFn = llvm::Function::Create(
171-
newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module);
149+
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
150+
llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);
172151

173152
// Map old arguments to new arguments. We skip the first dyn_ptr argument,
174153
// since it can't be used directly by user code.
@@ -184,14 +163,6 @@ extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
184163
llvm::CloneFunctionInto(newFn, oldFn, vmap,
185164
llvm::CloneFunctionChangeType::LocalChangesOnly,
186165
returns);
187-
newFn->setLinkage(oldFn->getLinkage());
188-
newFn->setVisibility(oldFn->getVisibility());
189-
190-
// Replace uses, delete old function, and reset name to the original one.
191-
oldFn->replaceAllUsesWith(newFn);
192-
auto name = oldFn->getName();
193-
oldFn->eraseFromParent();
194-
newFn->setName(name);
195166
}
196167

197168
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,

0 commit comments

Comments
 (0)