Skip to content

Commit a31e7cd

Browse files
committed
fix device code generation
1 parent 4068baf commit a31e7cd

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ pub(crate) fn run_pass_manager(
613613
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage);
614614
}
615615

616-
if enable_gpu && !thin {
616+
if enable_gpu && !thin && !(cgcx.target_arch == "nvptx64" || cgcx.target_arch == "amdgpu") {
617617
let cx =
618618
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
619619
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,23 @@ pub(crate) unsafe fn llvm_optimize(
658658
None
659659
};
660660

661+
/// Hands the kernel `old_fn` of module `m` to `LLVMRustOffloadWrapper`.
///
/// `_llcx` is not used by this helper; the parameter is kept (underscore-prefixed to
/// silence `unused_variables`) so the existing call site keeps compiling unchanged.
fn handle_offload(m: &llvm::Module, _llcx: &llvm::Context, old_fn: &llvm::Value) {
    // SAFETY: `m` and `old_fn` are live references, so the handles passed across the FFI
    // boundary are valid for the duration of the call.
    // NOTE(review): this assumes `old_fn` is a function that belongs to `m` — confirm
    // against the caller.
    unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) };
}
664+
665+
let consider_offload = config.offload.contains(&config::Offload::Enable);
666+
if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") {
667+
for num in 0..9 {
668+
let name = format!("kernel_{num}");
669+
let c_name = CString::new(name).unwrap();
670+
if let Some(kernel) =
671+
unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) }
672+
{
673+
handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel);
674+
}
675+
}
676+
}
677+
661678
let mut llvm_profiler = cgcx
662679
.prof
663680
.llvm_recording_enabled()

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,11 @@ unsafe extern "C" {
11171117

11181118
// Operations on functions
11191119
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
1120+
pub(crate) fn LLVMAddFunction<'a>(
1121+
Mod: &'a Module,
1122+
Name: *const c_char,
1123+
FunctionTy: &'a Type,
1124+
) -> &'a Value;
11201125
pub(crate) fn LLVMDeleteFunction(Fn: &Value);
11211126

11221127
// Operations about llvm intrinsics
@@ -1135,6 +1140,7 @@ unsafe extern "C" {
11351140

11361141
// Operations on basic blocks
11371142
pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
1143+
pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock);
11381144
pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
11391145
C: &'a Context,
11401146
Fn: &'a Value,
@@ -2007,6 +2013,7 @@ unsafe extern "C" {
20072013
) -> &Attribute;
20082014

20092015
// Operations on functions
2016+
pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value);
20102017
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
20112018
M: &'a Module,
20122019
Name: *const c_char,

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "llvm/Support/Signals.h"
3636
#include "llvm/Support/Timer.h"
3737
#include "llvm/Support/ToolOutputFile.h"
38+
#include "llvm/Transforms/Utils/Cloning.h"
39+
#include "llvm/Transforms/Utils/ValueMapper.h"
3840
#include <iostream>
3941

4042
// for raw `write` in the bad-alloc handler
@@ -142,6 +144,56 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
142144
llvm::PrintStatistics(OS);
143145
}
144146

147+
// Rewrites the function `Fn` of module `M` so that its signature starts with an
// additional pointer argument named `dyn_ptr`, which llvm offload requires.
//
// The function is cloned into a new function with the extended signature, all
// uses of the old function are redirected to the clone, the old function is
// deleted, and the clone takes over the original name. If the first argument
// of `Fn` is already named `dyn_ptr`, the function is left untouched, so
// calling this more than once on the same kernel is harmless.
extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
  llvm::Module *module = llvm::unwrap(M);
  llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);

  // Already rewritten: nothing to do.
  if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
    return;
  }

  // 1. Create new function type with the leading extra %dyn_ptr arg which llvm
  // offload requires.
  llvm::LLVMContext &ctx = module->getContext();
  llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0);
  std::vector<llvm::Type *> argTypes;
  argTypes.reserve(oldFn->arg_size() + 1);
  argTypes.push_back(dynPtrType);

  for (auto &arg : oldFn->args()) {
    argTypes.push_back(arg.getType());
  }

  llvm::FunctionType *newFnType = llvm::FunctionType::get(
      oldFn->getReturnType(), argTypes, oldFn->isVarArg());

  // 2. Create the replacement function. Use a temporary .offload appendix to
  // avoid name clashes while both functions live in the module.
  llvm::Function *newFn = llvm::Function::Create(
      newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module);

  // 3. Map old arguments to new arguments. We skip the first dyn_ptr argument,
  // since it can't be used directly by user code.
  llvm::ValueToValueMapTy vmap;
  auto newArgIt = newFn->arg_begin();
  newArgIt->setName("dyn_ptr");
  ++newArgIt; // skip %dyn_ptr
  for (auto &oldArg : oldFn->args()) {
    vmap[&oldArg] = &*newArgIt++;
  }

  // 4. Clone the body of the old function into the new one.
  llvm::SmallVector<llvm::ReturnInst *, 8> returns;
  llvm::CloneFunctionInto(newFn, oldFn, vmap,
                          llvm::CloneFunctionChangeType::LocalChangesOnly,
                          returns);
  newFn->setLinkage(oldFn->getLinkage());
  newFn->setVisibility(oldFn->getVisibility());

  // 5. Replace uses, move the original name over to the new function, and
  // delete the old function. The name must be transferred *before* erasing:
  // the StringRef returned by getName() points into storage owned by oldFn,
  // so reading it after eraseFromParent() would be a use-after-free.
  oldFn->replaceAllUsesWith(newFn);
  newFn->takeName(oldFn);
  oldFn->eraseFromParent();
}
196+
145197
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
146198
size_t NameLen) {
147199
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));

compiler/rustc_target/src/callconv/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ impl RiscvInterruptKind {
577577
///
578578
/// The signature represented by this type may not match the MIR function signature.
579579
/// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`.
580+
/// The std::offload module also adds an additional dyn_ptr argument to the GpuKernel ABI.
580581
/// While this difference is rarely relevant, it should still be kept in mind.
581582
///
582583
/// I will do my best to describe this structure, but these

0 commit comments

Comments
 (0)