Skip to content

Commit 9118683

Browse files
committed
Add basic offload metadata
1 parent 4496fb8 commit 9118683

File tree

4 files changed

+90
-71
lines changed

4 files changed

+90
-71
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 11 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use llvm::Linkage::*;
44
use rustc_abi::Align;
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_middle::ty::offload_meta::OffloadMetadata;
78
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
89

910
use crate::builder::SBuilder;
@@ -189,8 +190,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
189190
tcx: TyCtxt<'tcx>,
190191
kernel: &'ll llvm::Value,
191192
offload_entry_ty: &'ll llvm::Type,
192-
// TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
193-
tt: Vec<Ty<'tcx>>,
193+
metadata: Vec<OffloadMetadata>,
194194
symbol: &str,
195195
) -> &'ll llvm::Value {
196196
let types = cx.func_params_types(cx.get_type_of_global(kernel));
@@ -201,12 +201,11 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
201201
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
202202
.count();
203203

204-
// TODO(Sa4dUs): Add typetrees here
205204
let ptr_sizes = types
206205
.iter()
207-
.zip(tt)
208-
.filter_map(|(&x, ty)| match cx.type_kind(x) {
209-
rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)),
206+
.zip(metadata)
207+
.filter_map(|(&x, meta)| match cx.type_kind(x) {
208+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size),
210209
_ => None,
211210
})
212211
.collect::<Vec<u64>>();
@@ -265,56 +264,6 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
265264
o_types
266265
}
267266

268-
// TODO(Sa4dUs): move this to a proper place
269-
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
270-
match ty.kind() {
271-
/*
272-
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
273-
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
274-
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
275-
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
276-
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
277-
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
278-
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
279-
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
280-
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
281-
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
282-
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
283-
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
284-
*/
285-
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
286-
/*
287-
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
288-
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
289-
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
290-
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
291-
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
292-
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
293-
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
294-
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
295-
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
296-
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
297-
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
298-
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
299-
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
300-
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
301-
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
302-
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
303-
*/
304-
_ => {
305-
tcx
306-
// TODO(Sa4dUs): Maybe `.as_query_input()`?
307-
.layout_of(PseudoCanonicalInput {
308-
typing_env: TypingEnv::fully_monomorphized(),
309-
value: ty,
310-
})
311-
.unwrap()
312-
.size
313-
.bytes()
314-
}
315-
}
316-
}
317-
318267
fn declare_offload_fn<'ll>(
319268
cx: &'ll SimpleCx<'_>,
320269
name: &str,
@@ -353,8 +302,8 @@ fn declare_offload_fn<'ll>(
353302
pub(crate) fn gen_call_handling<'ll>(
354303
cx: &SimpleCx<'ll>,
355304
bb: &BasicBlock,
356-
kernels: &[&'ll llvm::Value],
357-
o_types: &[&'ll llvm::Value],
305+
kernel: &'ll llvm::Value,
306+
o_type: &'ll llvm::Value,
358307
llty: &'ll Type,
359308
llfn: &'ll Value,
360309
) {
@@ -370,7 +319,7 @@ pub(crate) fn gen_call_handling<'ll>(
370319

371320
let mut builder = SBuilder::build(cx, bb);
372321

373-
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
322+
let types = cx.func_params_types(cx.get_type_of_global(kernel));
374323
let num_args = types.len() as u64;
375324

376325
// Step 0)
@@ -392,7 +341,7 @@ pub(crate) fn gen_call_handling<'ll>(
392341
let i32_0 = cx.get_const_i32(0);
393342
for (index, in_ty) in types.iter().enumerate() {
394343
// get function arg, store it into the alloca, and read it.
395-
let p = llvm::get_param(kernels[0], index as u32);
344+
let p = llvm::get_param(kernel, index as u32);
396345
let name = llvm::get_value_name(p);
397346
let name = str::from_utf8(&name).unwrap();
398347
let arg_name = format!("{name}.addr");
@@ -471,7 +420,7 @@ pub(crate) fn gen_call_handling<'ll>(
471420

472421
// Step 2)
473422
let s_ident_t = generate_at_one(&cx);
474-
let o = o_types[0];
423+
let o = o_type;
475424
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
476425
generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
477426

@@ -485,7 +434,7 @@ pub(crate) fn gen_call_handling<'ll>(
485434
args.push(param);
486435
}
487436

488-
builder.call(llty, kernels[0], &args, None);
437+
builder.call(llty, kernel, &args, None);
489438

490439
// Step 4)
491440
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
1313
use rustc_hir::{self as hir};
1414
use rustc_middle::mir::BinOp;
1515
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
16+
use rustc_middle::ty::offload_meta::OffloadMetadata;
1617
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
1718
use rustc_middle::{bug, span_bug};
1819
use rustc_span::{Span, Symbol, sym};
@@ -1259,7 +1260,6 @@ fn codegen_offload<'ll, 'tcx>(
12591260
}
12601261
};
12611262

1262-
// TODO(Sa4dUs): Will need typetrees
12631263
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
12641264
let Some(kernel) = cx.get_function(&target_symbol) else {
12651265
bug!("could not find target function")
@@ -1271,31 +1271,30 @@ fn codegen_offload<'ll, 'tcx>(
12711271
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
12721272
let inputs = sig.inputs();
12731273

1274-
// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
1274+
let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
1275+
1276+
// TODO(Sa4dUs): separate globals from call-independent headers
12751277
let o = crate::builder::gpu_offload::gen_define_handling(
12761278
cx,
12771279
tcx,
12781280
kernel,
12791281
offload_entry_ty,
1280-
inputs.to_vec(),
1282+
metadata,
12811283
&target_symbol,
12821284
);
12831285

1284-
let kernels = &[kernel];
1285-
let o_types = &[o];
1286-
12871286
let llvm_args = inputs.iter().map(|ty| bx.layout_of(*ty).llvm_type(cx)).collect::<Vec<_>>();
12881287
let ret_ty = match sig.output().kind() {
1289-
// TODO(Sa4dUs): dunno if there's a better way of doing this
1288+
// TODO(Sa4dUs): there's probs better way of doing this
12901289
ty::Tuple(tys) if tys.is_empty() => bx.type_void(),
12911290
_ => bx.layout_of(sig.output()).llvm_type(cx),
12921291
};
12931292
let llty = bx.type_func(&llvm_args, ret_ty);
12941293
let llfn = bx.llfn();
12951294

1296-
// TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix
1295+
// TODO(Sa4dUs): this is just to a void lifetime's issues
12971296
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
1298-
crate::builder::gpu_offload::gen_call_handling(cx, bb, kernels, o_types, llty, llfn);
1297+
crate::builder::gpu_offload::gen_call_handling(cx, bb, kernel, o, llty, llfn);
12991298
}
13001299

13011300
fn get_args_from_tuple<'ll, 'tcx>(

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ pub mod print;
137137
pub mod relate;
138138
pub mod significant_drop_order;
139139
pub mod trait_def;
140+
pub mod offload_meta;
140141
pub mod util;
141142
pub mod vtable;
142143

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
2+
3+
// TODO(Sa4dUs): it doesn't feel correct for me to place this on `rustc_ast::expand`, will look for a proper location
4+
pub struct OffloadMetadata {
5+
pub payload_size: u64,
6+
pub mode: TransferKind,
7+
}
8+
9+
pub enum TransferKind {
10+
FromGpu = 1,
11+
ToGpu = 2,
12+
Both = 3,
13+
}
14+
15+
impl OffloadMetadata {
16+
pub fn new(payload_size: u64, mode: TransferKind) -> Self {
17+
OffloadMetadata { payload_size, mode }
18+
}
19+
20+
pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
21+
OffloadMetadata { payload_size: get_payload_size(tcx, ty), mode: TransferKind::Both }
22+
}
23+
}
24+
25+
// TODO(Sa4dUs): WIP, rn we just have a naive logic for references
26+
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
27+
match ty.kind() {
28+
/*
29+
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
30+
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
31+
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
32+
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
33+
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
34+
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
35+
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
36+
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
37+
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
38+
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
39+
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
40+
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
41+
*/
42+
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
43+
/*
44+
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
45+
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
46+
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
47+
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
48+
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
49+
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
50+
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
51+
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
52+
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
53+
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
54+
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
55+
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
56+
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
57+
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
58+
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
59+
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
60+
*/
61+
_ => tcx
62+
.layout_of(PseudoCanonicalInput {
63+
typing_env: TypingEnv::fully_monomorphized(),
64+
value: ty,
65+
})
66+
.unwrap()
67+
.size
68+
.bytes(),
69+
}
70+
}

0 commit comments

Comments
 (0)