Skip to content

Commit 9f3ccf3

Browse files
committed
Add basic offload metadata
1 parent 0827957 commit 9f3ccf3

File tree

4 files changed

+84
-64
lines changed

4 files changed

+84
-64
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 7 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use llvm::Linkage::*;
44
use rustc_abi::Align;
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_middle::ty::offload_meta::OffloadMetadata;
78
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
89

910
use crate::builder::SBuilder;
@@ -263,8 +264,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
263264
tcx: TyCtxt<'tcx>,
264265
kernel: &'ll llvm::Value,
265266
offload_entry_ty: &'ll llvm::Type,
266-
// TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
267-
tt: Vec<Ty<'tcx>>,
267+
metadata: Vec<OffloadMetadata>,
268268
symbol: &str,
269269
) -> (&'ll llvm::Value, &'ll llvm::Value) {
270270
let types = cx.func_params_types(cx.get_type_of_global(kernel));
@@ -275,12 +275,11 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
275275
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
276276
.count();
277277

278-
// TODO(Sa4dUs): Add typetrees here
279278
let ptr_sizes = types
280279
.iter()
281-
.zip(tt)
282-
.filter_map(|(&x, ty)| match cx.type_kind(x) {
283-
rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)),
280+
.zip(metadata)
281+
.filter_map(|(&x, meta)| match cx.type_kind(x) {
282+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size),
284283
_ => None,
285284
})
286285
.collect::<Vec<u64>>();
@@ -335,56 +334,6 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
335334
(memtransfer_types, region_id)
336335
}
337336

338-
// TODO(Sa4dUs): move this to a proper place
339-
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
340-
match ty.kind() {
341-
/*
342-
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
343-
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
344-
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
345-
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
346-
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
347-
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
348-
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
349-
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
350-
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
351-
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
352-
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
353-
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
354-
*/
355-
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
356-
/*
357-
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
358-
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
359-
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
360-
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
361-
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
362-
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
363-
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
364-
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
365-
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
366-
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
367-
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
368-
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
369-
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
370-
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
371-
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
372-
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
373-
*/
374-
_ => {
375-
tcx
376-
// TODO(Sa4dUs): Maybe `.as_query_input()`?
377-
.layout_of(PseudoCanonicalInput {
378-
typing_env: TypingEnv::fully_monomorphized(),
379-
value: ty,
380-
})
381-
.unwrap()
382-
.size
383-
.bytes()
384-
}
385-
}
386-
}
387-
388337
fn declare_offload_fn<'ll>(
389338
cx: &'ll SimpleCx<'_>,
390339
name: &str,
@@ -423,7 +372,7 @@ fn declare_offload_fn<'ll>(
423372
pub(crate) fn gen_call_handling<'ll>(
424373
cx: &SimpleCx<'ll>,
425374
bb: &BasicBlock,
426-
kernels: &[&'ll llvm::Value],
375+
kernel: &'ll llvm::Value,
427376
memtransfer_types: &[&'ll llvm::Value],
428377
region_ids: &[&'ll llvm::Value],
429378
llfn: &'ll Value,
@@ -441,7 +390,7 @@ pub(crate) fn gen_call_handling<'ll>(
441390

442391
let mut builder = SBuilder::build(cx, bb);
443392

444-
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
393+
let types = cx.func_params_types(cx.get_type_of_global(kernel));
445394
let num_args = types.len() as u64;
446395

447396
// Step 0)

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
1313
use rustc_hir::{self as hir};
1414
use rustc_middle::mir::BinOp;
1515
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
16+
use rustc_middle::ty::offload_meta::OffloadMetadata;
1617
use rustc_middle::ty::{self, GenericArgsRef, Instance, SimdAlign, Ty, TyCtxt, TypingEnv};
1718
use rustc_middle::{bug, span_bug};
1819
use rustc_span::{Span, Symbol, sym};
@@ -1260,7 +1261,6 @@ fn codegen_offload<'ll, 'tcx>(
12601261
}
12611262
};
12621263

1263-
// TODO(Sa4dUs): Will need typetrees
12641264
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
12651265
let Some(kernel) = cx.get_function(&target_symbol) else {
12661266
bug!("could not find target function")
@@ -1272,26 +1272,26 @@ fn codegen_offload<'ll, 'tcx>(
12721272
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
12731273
let inputs = sig.inputs();
12741274

1275+
let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
1276+
12751277
// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
12761278
let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling(
12771279
cx,
12781280
tcx,
12791281
kernel,
12801282
offload_entry_ty,
1281-
inputs.to_vec(),
1283+
metadata,
12821284
&target_symbol,
12831285
);
12841286

1285-
let kernels = &[kernel];
1286-
12871287
let llfn = bx.llfn();
12881288

1289-
// TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix
1289+
// TODO(Sa4dUs): this is just to a void lifetime's issues
12901290
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
12911291
crate::builder::gpu_offload::gen_call_handling(
12921292
cx,
12931293
bb,
1294-
kernels,
1294+
kernel,
12951295
&[memtransfer_type],
12961296
&[region_id],
12971297
llfn,

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ pub mod fast_reject;
130130
pub mod inhabitedness;
131131
pub mod layout;
132132
pub mod normalize_erasing_regions;
133+
pub mod offload_meta;
133134
pub mod pattern;
134135
pub mod print;
135136
pub mod relate;
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
2+
3+
// TODO(Sa4dUs): it doesn't feel correct for me to place this on `rustc_ast::expand`, will look for a proper location
4+
pub struct OffloadMetadata {
5+
pub payload_size: u64,
6+
pub mode: TransferKind,
7+
}
8+
9+
pub enum TransferKind {
10+
FromGpu = 1,
11+
ToGpu = 2,
12+
Both = 3,
13+
}
14+
15+
impl OffloadMetadata {
16+
pub fn new(payload_size: u64, mode: TransferKind) -> Self {
17+
OffloadMetadata { payload_size, mode }
18+
}
19+
20+
pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
21+
OffloadMetadata { payload_size: get_payload_size(tcx, ty), mode: TransferKind::Both }
22+
}
23+
}
24+
25+
// TODO(Sa4dUs): WIP, rn we just have a naive logic for references
26+
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
27+
match ty.kind() {
28+
/*
29+
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
30+
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
31+
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
32+
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
33+
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
34+
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
35+
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
36+
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
37+
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
38+
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
39+
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
40+
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
41+
*/
42+
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
43+
/*
44+
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
45+
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
46+
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
47+
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
48+
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
49+
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
50+
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
51+
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
52+
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
53+
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
54+
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
55+
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
56+
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
57+
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
58+
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
59+
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
60+
*/
61+
_ => tcx
62+
.layout_of(PseudoCanonicalInput {
63+
typing_env: TypingEnv::fully_monomorphized(),
64+
value: ty,
65+
})
66+
.unwrap()
67+
.size
68+
.bytes(),
69+
}
70+
}

0 commit comments

Comments
 (0)