Skip to content

Commit 97a8e96

Browse files
committed
Don't depend on outer fn and some cleanup
1 parent 834347a commit 97a8e96

File tree

6 files changed

+53
-90
lines changed

6 files changed

+53
-90
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -262,23 +262,19 @@ pub(crate) fn add_global<'ll>(
262262
pub(crate) fn gen_define_handling<'ll>(
263263
cx: &SimpleCx<'ll>,
264264
offload_entry_ty: &'ll llvm::Type,
265-
metadata: &Vec<OffloadMetadata>,
266-
types: &Vec<&Type>,
265+
metadata: &[OffloadMetadata],
266+
types: &[&Type],
267267
symbol: &str,
268268
) -> (&'ll llvm::Value, &'ll llvm::Value) {
269269
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
270270
// reference) types.
271-
let ptr_meta = types
272-
.iter()
273-
.zip(metadata)
274-
.filter_map(|(&x, meta)| match cx.type_kind(x) {
275-
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
276-
_ => None,
277-
})
278-
.collect::<Vec<_>>();
279-
280-
let ptr_sizes = ptr_meta.iter().map(|m| m.payload_size).collect::<Vec<_>>();
281-
let ptr_transfer = ptr_meta.iter().map(|m| m.mode as u64 | 0x20).collect::<Vec<_>>();
271+
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
272+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
273+
_ => None,
274+
});
275+
276+
let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
277+
ptr_meta.map(|m| (m.payload_size, m.mode as u64 | 0x20)).unzip();
282278

283279
// We do not know their size anymore at this level, so hardcode a placeholder.
284280
// A follow-up PR will track these from the frontend, where we still have Rust types.
@@ -369,9 +365,9 @@ pub(crate) fn gen_call_handling<'ll>(
369365
bb: &BasicBlock,
370366
memtransfer_types: &[&'ll llvm::Value],
371367
region_ids: &[&'ll llvm::Value],
372-
llfn: &'ll Value,
373-
types: &Vec<&Type>,
374-
metadata: &Vec<OffloadMetadata>,
368+
args: &[&'ll Value],
369+
types: &[&Type],
370+
metadata: &[OffloadMetadata],
375371
) {
376372
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
377373
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
@@ -413,7 +409,7 @@ pub(crate) fn gen_call_handling<'ll>(
413409
let mut geps = vec![];
414410
let i32_0 = cx.get_const_i32(0);
415411
for index in 0..num_args {
416-
let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) };
412+
let v = args[index as usize];
417413
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
418414
vals.push(v);
419415
geps.push(gep);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
198198
return Ok(());
199199
}
200200
sym::offload => {
201-
codegen_offload(self, tcx, instance, args, result);
201+
codegen_offload(self, tcx, instance, args);
202202
return Ok(());
203203
}
204204
sym::is_val_statically_known => {
@@ -1237,8 +1237,7 @@ fn codegen_offload<'ll, 'tcx>(
12371237
bx: &mut Builder<'_, 'll, 'tcx>,
12381238
tcx: TyCtxt<'tcx>,
12391239
instance: ty::Instance<'tcx>,
1240-
_args: &[OperandRef<'tcx, &'ll Value>],
1241-
_result: PlaceRef<'tcx, &'ll Value>,
1240+
args: &[OperandRef<'tcx, &'ll Value>],
12421241
) {
12431242
let cx = bx.cx;
12441243
let fn_args = instance.args;
@@ -1261,7 +1260,8 @@ fn codegen_offload<'ll, 'tcx>(
12611260
}
12621261
};
12631262

1264-
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
1263+
let args = get_args_from_tuple(bx, args[1], fn_target);
1264+
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);
12651265

12661266
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
12671267

@@ -1270,7 +1270,6 @@ fn codegen_offload<'ll, 'tcx>(
12701270
let inputs = sig.inputs();
12711271

12721272
let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
1273-
let llfn = bx.llfn();
12741273

12751274
let types = inputs.iter().map(|ty| cx.layout_of(*ty).llvm_type(cx)).collect::<Vec<_>>();
12761275

@@ -1290,7 +1289,7 @@ fn codegen_offload<'ll, 'tcx>(
12901289
bb,
12911290
&[memtransfer_type],
12921291
&[region_id],
1293-
llfn,
1292+
&args,
12941293
&types,
12951294
&metadata,
12961295
);

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ pub(crate) fn check_intrinsic_type(
311311
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
312312
(0, 0, vec![type_id, type_id], tcx.types.bool)
313313
}
314-
sym::offload => (2, 0, vec![param(0)], param(1)),
314+
sym::offload => (3, 0, vec![param(0), param(1)], param(2)),
315315
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
316316
sym::arith_offset => (
317317
1,

compiler/rustc_middle/src/ty/offload_meta.rs

Lines changed: 32 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -31,39 +31,7 @@ impl OffloadMetadata {
3131
// TODO(Sa4dUs): WIP; right now we only have naive logic for references
3232
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
3333
match ty.kind() {
34-
/*
35-
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
36-
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
37-
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
38-
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
39-
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
40-
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
41-
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
42-
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
43-
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
44-
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
45-
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
46-
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
47-
*/
48-
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
49-
/*
50-
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
51-
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
52-
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
53-
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
54-
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
55-
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
56-
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
57-
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
58-
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
59-
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
60-
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
61-
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
62-
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
63-
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
64-
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
65-
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
66-
*/
34+
ty::RawPtr(inner, _) | ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
6735
_ => tcx
6836
.layout_of(PseudoCanonicalInput {
6937
typing_env: TypingEnv::fully_monomorphized(),
@@ -79,44 +47,44 @@ impl TransferKind {
7947
pub fn from_ty<'tcx>(_tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
8048
// TODO(Sa4dUs): this logic is probably not fully correct, but it works for now
8149
match ty.kind() {
82-
rustc_type_ir::TyKind::Bool
83-
| rustc_type_ir::TyKind::Char
84-
| rustc_type_ir::TyKind::Int(_)
85-
| rustc_type_ir::TyKind::Uint(_)
86-
| rustc_type_ir::TyKind::Float(_) => TransferKind::ToGpu,
50+
ty::Bool
51+
| ty::Char
52+
| ty::Int(_)
53+
| ty::Uint(_)
54+
| ty::Float(_) => TransferKind::ToGpu,
8755

88-
rustc_type_ir::TyKind::Adt(_, _)
89-
| rustc_type_ir::TyKind::Tuple(_)
90-
| rustc_type_ir::TyKind::Array(_, _) => TransferKind::ToGpu,
56+
ty::Adt(_, _)
57+
| ty::Tuple(_)
58+
| ty::Array(_, _) => TransferKind::ToGpu,
9159

92-
rustc_type_ir::TyKind::RawPtr(_, rustc_ast::Mutability::Not)
93-
| rustc_type_ir::TyKind::Ref(_, _, rustc_ast::Mutability::Not) => TransferKind::ToGpu,
60+
ty::RawPtr(_, rustc_ast::Mutability::Not)
61+
| ty::Ref(_, _, rustc_ast::Mutability::Not) => TransferKind::ToGpu,
9462

95-
rustc_type_ir::TyKind::RawPtr(_, rustc_ast::Mutability::Mut)
96-
| rustc_type_ir::TyKind::Ref(_, _, rustc_ast::Mutability::Mut) => TransferKind::Both,
63+
ty::RawPtr(_, rustc_ast::Mutability::Mut)
64+
| ty::Ref(_, _, rustc_ast::Mutability::Mut) => TransferKind::Both,
9765

98-
rustc_type_ir::TyKind::Slice(_)
99-
| rustc_type_ir::TyKind::Str
100-
| rustc_type_ir::TyKind::Dynamic(_, _) => TransferKind::Both,
66+
ty::Slice(_)
67+
| ty::Str
68+
| ty::Dynamic(_, _) => TransferKind::Both,
10169

102-
rustc_type_ir::TyKind::FnDef(_, _)
103-
| rustc_type_ir::TyKind::FnPtr(_, _)
104-
| rustc_type_ir::TyKind::Closure(_, _)
105-
| rustc_type_ir::TyKind::CoroutineClosure(_, _)
106-
| rustc_type_ir::TyKind::Coroutine(_, _)
107-
| rustc_type_ir::TyKind::CoroutineWitness(_, _) => TransferKind::ToGpu,
70+
ty::FnDef(_, _)
71+
| ty::FnPtr(_, _)
72+
| ty::Closure(_, _)
73+
| ty::CoroutineClosure(_, _)
74+
| ty::Coroutine(_, _)
75+
| ty::CoroutineWitness(_, _) => TransferKind::ToGpu,
10876

109-
rustc_type_ir::TyKind::Alias(_, _)
110-
| rustc_type_ir::TyKind::Param(_)
111-
| rustc_type_ir::TyKind::Bound(_, _)
112-
| rustc_type_ir::TyKind::Placeholder(_)
113-
| rustc_type_ir::TyKind::Infer(_)
114-
| rustc_type_ir::TyKind::Error(_) => TransferKind::ToGpu,
77+
ty::Alias(_, _)
78+
| ty::Param(_)
79+
| ty::Bound(_, _)
80+
| ty::Placeholder(_)
81+
| ty::Infer(_)
82+
| ty::Error(_) => TransferKind::ToGpu,
11583

116-
rustc_type_ir::TyKind::Never => TransferKind::ToGpu,
117-
rustc_type_ir::TyKind::Foreign(_) => TransferKind::Both,
118-
rustc_type_ir::TyKind::Pat(_, _) => TransferKind::Both,
119-
rustc_type_ir::TyKind::UnsafeBinder(_) => TransferKind::Both,
84+
ty::Never => TransferKind::ToGpu,
85+
ty::Foreign(_) => TransferKind::Both,
86+
ty::Pat(_, _) => TransferKind::Both,
87+
ty::UnsafeBinder(_) => TransferKind::Both,
12088
}
12189
}
12290
}

library/core/src/intrinsics/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3278,7 +3278,7 @@ pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) ->
32783278

32793279
#[rustc_nounwind]
32803280
#[rustc_intrinsic]
3281-
pub const fn offload<F, R>(f: F) -> R;
3281+
pub const fn offload<F, T: crate::marker::Tuple, R>(f: F, args: T) -> R;
32823282

32833283
/// Inform Miri that a given pointer definitely has a certain alignment.
32843284
#[cfg(miri)]

tests/codegen-llvm/gpu_offload/gpu_host.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ fn main() {
102102
#[unsafe(no_mangle)]
103103
#[inline(never)]
104104
pub fn kernel(x: &mut [f32; 256]) {
105-
core::intrinsics::offload(_kernel)
105+
core::intrinsics::offload(_kernel, (x,))
106106
}
107107

108108
#[unsafe(no_mangle)]

0 commit comments

Comments
 (0)