Skip to content

Commit 0827957

Browse files
committed
first definition of offload intrinsic (dirty code)
1 parent 0bbef55 commit 0827957

File tree

7 files changed

+210
-45
lines changed

7 files changed

+210
-45
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 93 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@ use llvm::Linkage::*;
44
use rustc_abi::Align;
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
78

89
use crate::builder::SBuilder;
9-
use crate::common::AsCCharPtr;
1010
use crate::llvm::AttributePlace::Function;
11-
use crate::llvm::{self, Linkage, Type, Value};
11+
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
1212
use crate::{LlvmCodegenBackend, SimpleCx, attributes};
1313

1414
pub(crate) fn handle_gpu_code<'ll>(
1515
_cgcx: &CodegenContext<LlvmCodegenBackend>,
16-
cx: &'ll SimpleCx<'_>,
16+
_cx: &'ll SimpleCx<'_>,
1717
) {
18+
/*
1819
// The offload memory transfer type for each kernel
1920
let mut memtransfer_types = vec![];
2021
let mut region_ids = vec![];
@@ -32,6 +33,7 @@ pub(crate) fn handle_gpu_code<'ll>(
3233
}
3334
3435
gen_call_handling(&cx, &memtransfer_types, &region_ids);
36+
*/
3537
}
3638

3739
// ; Function Attrs: nounwind
@@ -79,7 +81,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7981
at_one
8082
}
8183

82-
struct TgtOffloadEntry {
84+
pub(crate) struct TgtOffloadEntry {
8385
// uint64_t Reserved;
8486
// uint16_t Version;
8587
// uint16_t Kind;
@@ -256,11 +258,14 @@ pub(crate) fn add_global<'ll>(
256258
// This function returns a memtransfer value which encodes how arguments to this kernel shall be
257259
// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
258260
// concatenated into the list of region_ids.
259-
fn gen_define_handling<'ll>(
260-
cx: &'ll SimpleCx<'_>,
261+
pub(crate) fn gen_define_handling<'ll, 'tcx>(
262+
cx: &SimpleCx<'ll>,
263+
tcx: TyCtxt<'tcx>,
261264
kernel: &'ll llvm::Value,
262265
offload_entry_ty: &'ll llvm::Type,
263-
num: i64,
266+
// TODO(Sa4dUs): Define a typetree once I have a better idea of what exactly we need
267+
tt: Vec<Ty<'tcx>>,
268+
symbol: &str,
264269
) -> (&'ll llvm::Value, &'ll llvm::Value) {
265270
let types = cx.func_params_types(cx.get_type_of_global(kernel));
266271
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -270,37 +275,50 @@ fn gen_define_handling<'ll>(
270275
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
271276
.count();
272277

278+
// TODO(Sa4dUs): Add typetrees here
279+
let ptr_sizes = types
280+
.iter()
281+
.zip(tt)
282+
.filter_map(|(&x, ty)| match cx.type_kind(x) {
283+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)),
284+
_ => None,
285+
})
286+
.collect::<Vec<u64>>();
287+
273288
// At this (LLVM) level the pointee sizes are no longer known, so they are computed above
274289
// from the Rust argument types in `tt`, where we still have Rust types.
275290
// For example, `&[f32;256]` results in 4*256 bytes.
276291
// (Before `tt` was available, a 1024-byte placeholder was hardcoded here.)
277-
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
292+
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
278293
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
279294
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
280295
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
281296
// will be 2. For now, everything is 3, until we have our frontend set up.
282297
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
283298
let memtransfer_types = add_priv_unnamed_arr(
284299
&cx,
285-
&format!(".offload_maptypes.{num}"),
300+
&format!(".offload_maptypes.{symbol}"),
286301
&vec![1 + 2 + 32; num_ptr_types],
287302
);
303+
288304
// Next: For each function, generate these three entries. A weak constant,
289305
// the llvm.rodata entry name, and the llvm_offload_entries value
290306

291-
let name = format!(".kernel_{num}.region_id");
307+
let name = format!(".{symbol}.region_id");
292308
let initializer = cx.get_const_i8(0);
293309
let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
294310

295-
let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
311+
let c_entry_name = CString::new(symbol).unwrap();
296312
let c_val = c_entry_name.as_bytes_with_nul();
297-
let offload_entry_name = format!(".offloading.entry_name.{num}");
313+
let offload_entry_name = format!(".offloading.entry_name.{symbol}");
298314

299315
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
300316
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
301317
llvm::set_alignment(llglobal, Align::ONE);
302318
llvm::set_section(llglobal, c".llvm.rodata.offloading");
303-
let name = format!(".offloading.entry.kernel_{num}");
319+
320+
// Not actively used yet, for calling real kernels
321+
let name = format!(".offloading.entry.{symbol}");
304322

305323
// See the __tgt_offload_entry documentation above.
306324
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
@@ -317,7 +335,57 @@ fn gen_define_handling<'ll>(
317335
(memtransfer_types, region_id)
318336
}
319337

320-
pub(crate) fn declare_offload_fn<'ll>(
338+
// TODO(Sa4dUs): move this to a proper place
339+
/// Returns the size in bytes of the data that `ty` refers to, used as the
/// per-argument entry of the `.offload_sizes.*` array.
///
/// References are peeled recursively, so `&T` (and `&&T`, ...) yields the size of `T`.
/// Every other type takes the fallback arm and yields the size reported by
/// `tcx.layout_of` for the type itself.
///
/// # Panics
///
/// Panics if `layout_of` returns an error (the result is `unwrap`ped).
///
/// NOTE(review): raw pointers are not peeled (the `RawPtr` arm below is commented out),
/// so they take the fallback arm and yield the size of the pointer type itself —
/// confirm this is intended.
/// NOTE(review): behaviour for references to unsized pointees (`&[T]`, `&str`,
/// `&dyn Trait`) depends on what `layout_of` reports for the unsized type — TODO confirm.
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
340+
match ty.kind() {
341+
/*
342+
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
343+
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
344+
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
345+
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
346+
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
347+
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
348+
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
349+
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
350+
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
351+
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
352+
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
353+
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
354+
*/
355+
// A reference contributes the size of what it points to, not of the pointer.
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
356+
/*
357+
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
358+
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
359+
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
360+
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
361+
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
362+
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
363+
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
364+
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
365+
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
366+
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
367+
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
368+
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
369+
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
370+
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
371+
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
372+
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
373+
*/
374+
// Fallback for every non-reference type: ask the layout query for the size of `ty`
// itself, using a fully monomorphized typing environment.
_ => {
375+
tcx
376+
// TODO(Sa4dUs): Maybe `.as_query_input()`?
377+
.layout_of(PseudoCanonicalInput {
378+
typing_env: TypingEnv::fully_monomorphized(),
379+
value: ty,
380+
})
381+
.unwrap()
382+
.size
383+
.bytes()
384+
}
385+
}
386+
}
387+
388+
fn declare_offload_fn<'ll>(
321389
cx: &'ll SimpleCx<'_>,
322390
name: &str,
323391
ty: &'ll llvm::Type,
@@ -352,10 +420,13 @@ pub(crate) fn declare_offload_fn<'ll>(
352420
// 4. set insert point after kernel call.
353421
// 5. generate all the GEPS and stores, to be used in 6)
354422
// 6. generate __tgt_target_data_end calls to move data from the GPU
355-
fn gen_call_handling<'ll>(
356-
cx: &'ll SimpleCx<'_>,
423+
pub(crate) fn gen_call_handling<'ll>(
424+
cx: &SimpleCx<'ll>,
425+
bb: &BasicBlock,
426+
kernels: &[&'ll llvm::Value],
357427
memtransfer_types: &[&'ll llvm::Value],
358428
region_ids: &[&'ll llvm::Value],
429+
llfn: &'ll Value,
359430
) {
360431
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
361432
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
@@ -368,27 +439,14 @@ fn gen_call_handling<'ll>(
368439
let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
369440
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
370441

371-
let main_fn = cx.get_function("main");
372-
let Some(main_fn) = main_fn else { return };
373-
let kernel_name = "kernel_1";
374-
let call = unsafe {
375-
llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
376-
};
377-
let Some(kernel_call) = call else {
378-
return;
379-
};
380-
let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
381-
let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
382-
let mut builder = SBuilder::build(cx, kernel_call_bb);
383-
384-
let types = cx.func_params_types(cx.get_type_of_global(called));
442+
let mut builder = SBuilder::build(cx, bb);
443+
444+
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
385445
let num_args = types.len() as u64;
386446

387447
// Step 0)
388448
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
389449
// %6 = alloca %struct.__tgt_bin_desc, align 8
390-
unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
391-
392450
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
393451

394452
let ty = cx.type_array(cx.type_ptr(), num_args);
@@ -404,15 +462,14 @@ fn gen_call_handling<'ll>(
404462
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
405463

406464
// Step 1)
407-
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
408465
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
409466

410467
// Now we allocate once per function param, a copy to be passed to one of our maps.
411468
let mut vals = vec![];
412469
let mut geps = vec![];
413470
let i32_0 = cx.get_const_i32(0);
414-
for index in 0..types.len() {
415-
let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
471+
for index in 0..num_args {
472+
let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) };
416473
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
417474
vals.push(v);
418475
geps.push(gep);
@@ -504,13 +561,8 @@ fn gen_call_handling<'ll>(
504561
region_ids[0],
505562
a5,
506563
];
507-
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
564+
builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
508565
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
509-
unsafe {
510-
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
511-
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
512-
llvm::LLVMInstructionEraseFromParent(next);
513-
}
514566

515567
// Step 4)
516568
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
@@ -519,8 +571,4 @@ fn gen_call_handling<'ll>(
519571
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
520572

521573
drop(builder);
522-
// FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
523-
// and then delete the call to the CPU version. In the future, we should use an intrinsic which
524-
// directly resolves to a call to the GPU version.
525-
unsafe { llvm::LLVMDeleteFunction(called) };
526574
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use tracing::debug;
2323
use crate::abi::FnAbiLlvmExt;
2424
use crate::builder::Builder;
2525
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
26+
use crate::builder::gpu_offload::TgtOffloadEntry;
2627
use crate::context::CodegenCx;
2728
use crate::errors::AutoDiffWithoutEnable;
2829
use crate::llvm::{self, Metadata, Type, Value};
@@ -195,6 +196,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
195196
codegen_autodiff(self, tcx, instance, args, result);
196197
return Ok(());
197198
}
199+
sym::offload => {
200+
codegen_offload(self, tcx, instance, args, result);
201+
return Ok(());
202+
}
198203
sym::is_val_statically_known => {
199204
if let OperandValue::Immediate(imm) = args[0].val {
200205
self.call_intrinsic(
@@ -1227,6 +1232,72 @@ fn codegen_autodiff<'ll, 'tcx>(
12271232
);
12281233
}
12291234

1235+
/// Lowers a call to the `offload` intrinsic.
///
/// The first generic argument of the intrinsic `instance` must be a function item type
/// (`ty::FnDef`); it names the kernel to offload. The kernel is resolved to a concrete
/// instance, its symbol is looked up among the functions known to `cx`, the per-kernel
/// globals are emitted via `gen_define_handling`, and the launch sequence is emitted
/// into the builder's current basic block via `gen_call_handling`.
///
/// `_args` and `_result` are currently unused: nothing is read from the intrinsic's
/// operands and nothing is written to its result place.
///
/// Hits `bug!` if the first generic argument is not a `FnDef`, if the target cannot be
/// resolved to a specific instance, or if no function with the target's symbol name is
/// found. Returns without emitting anything if instance resolution itself reported an
/// error.
fn codegen_offload<'ll, 'tcx>(
1236+
bx: &mut Builder<'_, 'll, 'tcx>,
1237+
tcx: TyCtxt<'tcx>,
1238+
instance: ty::Instance<'tcx>,
1239+
_args: &[OperandRef<'tcx, &'ll Value>],
1240+
_result: PlaceRef<'tcx, &'ll Value>,
1241+
) {
1242+
let cx = bx.cx;
1243+
let fn_args = instance.args;
1244+
1245+
// The kernel is identified by the intrinsic's first type argument, not by a value operand.
let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() {
1246+
ty::FnDef(def_id, params) => (def_id, params),
1247+
_ => bug!("invalid offload intrinsic arg"),
1248+
};
1249+
1250+
let fn_target = match Instance::try_resolve(tcx, cx.typing_env(), *target_id, target_args) {
1251+
Ok(Some(instance)) => instance,
1252+
Ok(None) => bug!(
1253+
"could not resolve ({:?}, {:?}) to a specific offload instance",
1254+
target_id,
1255+
target_args
1256+
),
1257+
Err(_) => {
1258+
// An error has already been emitted
1259+
return;
1260+
}
1261+
};
1262+
1263+
// TODO(Sa4dUs): This will need typetrees
1264+
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
1265+
let Some(kernel) = cx.get_function(&target_symbol) else {
1266+
bug!("could not find target function")
1267+
};
1268+
1269+
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
1270+
1271+
// Build TypeTree (or something similar)
// NOTE(review): both binders are skipped without instantiating the signature with the
// resolved instance's generic args; presumably this only works for non-generic
// kernels — TODO confirm.
1272+
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
1273+
let inputs = sig.inputs();
1274+
1275+
// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
1276+
let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling(
1277+
cx,
1278+
tcx,
1279+
kernel,
1280+
offload_entry_ty,
1281+
inputs.to_vec(),
1282+
&target_symbol,
1283+
);
1284+
1285+
let kernels = &[kernel];
1286+
1287+
// The LLVM function currently being built, i.e. the one containing the intrinsic call.
let llfn = bx.llfn();
1288+
1289+
// TODO(Sa4dUs): temporary workaround that postpones fixing the lifetime issue
// NOTE(review): assumes `bx.llbuilder` is a live builder positioned inside a basic
// block, so the returned block is valid — TODO confirm and turn into a SAFETY comment.
1290+
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
1291+
crate::builder::gpu_offload::gen_call_handling(
1292+
cx,
1293+
bb,
1294+
kernels,
1295+
&[memtransfer_type],
1296+
&[region_id],
1297+
llfn,
1298+
);
1299+
}
1300+
12301301
fn get_args_from_tuple<'ll, 'tcx>(
12311302
bx: &mut Builder<'_, 'll, 'tcx>,
12321303
tuple_op: OperandRef<'tcx, &'ll Value>,

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
//!
55
//! This API is completely unstable and subject to change.
66
7+
// TODO(Sa4dUs): temporary, only here to silence warnings about unused LLVM wrappers; remove once the implementation is cleaned up
8+
#![allow(unused)]
79
// tidy-alphabetical-start
810
#![feature(assert_matches)]
911
#![feature(extern_types)]

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
163163
| sym::minnumf128
164164
| sym::mul_with_overflow
165165
| sym::needs_drop
166+
| sym::offload
166167
| sym::powf16
167168
| sym::powf32
168169
| sym::powf64
@@ -310,6 +311,7 @@ pub(crate) fn check_intrinsic_type(
310311
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
311312
(0, 0, vec![type_id, type_id], tcx.types.bool)
312313
}
314+
sym::offload => (2, 0, vec![param(0)], param(1)),
313315
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
314316
sym::arith_offset => (
315317
1,

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,7 @@ symbols! {
15781578
object_safe_for_dispatch,
15791579
of,
15801580
off,
1581+
offload,
15811582
offset,
15821583
offset_of,
15831584
offset_of_enum,

library/core/src/intrinsics/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3276,6 +3276,10 @@ pub const fn copysignf128(x: f128, y: f128) -> f128;
32763276
#[rustc_intrinsic]
32773277
pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
32783278

3279+
/// Experimental intrinsic for offloading the function `f`.
///
/// It takes a single argument of type `F` and returns a value of type `R`.
///
/// NOTE(review): the LLVM backend expects `F` to be a function item type and reports a
/// compiler bug otherwise; it does not currently write a result of type `R` — confirm
/// the intended contract before documenting it further.
#[rustc_nounwind]
3280+
#[rustc_intrinsic]
3281+
pub const fn offload<F, R>(f: F) -> R;
3282+
32793283
/// Inform Miri that a given pointer definitely has a certain alignment.
32803284
#[cfg(miri)]
32813285
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)