|
35 | 35 | #include "llvm/Support/Signals.h" |
36 | 36 | #include "llvm/Support/Timer.h" |
37 | 37 | #include "llvm/Support/ToolOutputFile.h" |
| 38 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 39 | +#include "llvm/Transforms/Utils/ValueMapper.h" |
38 | 40 | #include <iostream> |
39 | 41 |
|
40 | 42 | // for raw `write` in the bad-alloc handler |
@@ -142,6 +144,56 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) { |
142 | 144 | llvm::PrintStatistics(OS); |
143 | 145 | } |
144 | 146 |
|
| 147 | +extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) { |
| 148 | + llvm::Module *module = llvm::unwrap(M); |
| 149 | + llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn); |
| 150 | + |
| 151 | + if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") { |
| 152 | + return; |
| 153 | + } |
| 154 | + |
| 155 | + // 1. Create new function type with the leading extra %dyn_ptr arg which llvm |
| 156 | + // offload requries. |
| 157 | + llvm::LLVMContext &ctx = module->getContext(); |
| 158 | + llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0); |
| 159 | + std::vector<llvm::Type *> argTypes; |
| 160 | + argTypes.push_back(dynPtrType); |
| 161 | + |
| 162 | + for (auto &arg : oldFn->args()) { |
| 163 | + argTypes.push_back(arg.getType()); |
| 164 | + } |
| 165 | + |
| 166 | + llvm::FunctionType *newFnType = llvm::FunctionType::get( |
| 167 | + oldFn->getReturnType(), argTypes, oldFn->isVarArg()); |
| 168 | + |
| 169 | + // use a temporary .offload appendix to avoid name clashes |
| 170 | + llvm::Function *newFn = llvm::Function::Create( |
| 171 | + newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module); |
| 172 | + |
| 173 | + // Map old arguments to new arguments. We skip the first dyn_ptr argument, |
| 174 | + // since it can't be used directly by user code. |
| 175 | + llvm::ValueToValueMapTy vmap; |
| 176 | + auto newArgIt = newFn->arg_begin(); |
| 177 | + newArgIt->setName("dyn_ptr"); |
| 178 | + ++newArgIt; // skip %dyn_ptr |
| 179 | + for (auto &oldArg : oldFn->args()) { |
| 180 | + vmap[&oldArg] = &*newArgIt++; |
| 181 | + } |
| 182 | + |
| 183 | + llvm::SmallVector<llvm::ReturnInst *, 8> returns; |
| 184 | + llvm::CloneFunctionInto(newFn, oldFn, vmap, |
| 185 | + llvm::CloneFunctionChangeType::LocalChangesOnly, |
| 186 | + returns); |
| 187 | + newFn->setLinkage(oldFn->getLinkage()); |
| 188 | + newFn->setVisibility(oldFn->getVisibility()); |
| 189 | + |
| 190 | + // Replace uses, delete old function, and reset name to the original one. |
| 191 | + oldFn->replaceAllUsesWith(newFn); |
| 192 | + auto name = oldFn->getName(); |
| 193 | + oldFn->eraseFromParent(); |
| 194 | + newFn->setName(name); |
| 195 | +} |
| 196 | + |
145 | 197 | extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name, |
146 | 198 | size_t NameLen) { |
147 | 199 | return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen))); |
|
0 commit comments