@@ -4,17 +4,18 @@ use llvm::Linkage::*;
44use rustc_abi:: Align ;
55use rustc_codegen_ssa:: back:: write:: CodegenContext ;
66use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7+ use rustc_middle:: ty:: { self , PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
78
89use crate :: builder:: SBuilder ;
9- use crate :: common:: AsCCharPtr ;
1010use crate :: llvm:: AttributePlace :: Function ;
11- use crate :: llvm:: { self , Linkage , Type , Value } ;
11+ use crate :: llvm:: { self , BasicBlock , Linkage , Type , Value } ;
1212use crate :: { LlvmCodegenBackend , SimpleCx , attributes} ;
1313
1414pub ( crate ) fn handle_gpu_code < ' ll > (
1515 _cgcx : & CodegenContext < LlvmCodegenBackend > ,
16- cx : & ' ll SimpleCx < ' _ > ,
16+ _cx : & ' ll SimpleCx < ' _ > ,
1717) {
18+ /*
1819 // The offload memory transfer type for each kernel
1920 let mut o_types = vec![];
2021 let mut kernels = vec![];
@@ -28,6 +29,7 @@ pub(crate) fn handle_gpu_code<'ll>(
2829 }
2930
3031 gen_call_handling(&cx, &kernels, &o_types);
32+ */
3133}
3234
3335// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
@@ -83,7 +85,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
8385 offload_entry_ty
8486}
8587
86- fn gen_tgt_kernel_global < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) {
88+ pub ( crate ) fn gen_tgt_kernel_global < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) {
8789 let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
8890 let tptr = cx. type_ptr ( ) ;
8991 let ti64 = cx. type_i64 ( ) ;
@@ -182,11 +184,14 @@ pub(crate) fn add_global<'ll>(
182184 llglobal
183185}
184186
185- fn gen_define_handling < ' ll > (
186- cx : & ' ll SimpleCx < ' _ > ,
187+ pub ( crate ) fn gen_define_handling < ' ll , ' tcx > (
188+ cx : & SimpleCx < ' ll > ,
189+ tcx : TyCtxt < ' tcx > ,
187190 kernel : & ' ll llvm:: Value ,
188191 offload_entry_ty : & ' ll llvm:: Type ,
189- num : i64 ,
192+ // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
193+ tt : Vec < Ty < ' tcx > > ,
194+ symbol : & str ,
190195) -> & ' ll llvm:: Value {
191196 let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
192197 // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -196,35 +201,47 @@ fn gen_define_handling<'ll>(
196201 . filter ( |& x| matches ! ( cx. type_kind( x) , rustc_codegen_ssa:: common:: TypeKind :: Pointer ) )
197202 . count ( ) ;
198203
204+ // TODO(Sa4dUs): Add typetrees here
205+ let ptr_sizes = types
206+ . iter ( )
207+ . zip ( tt)
208+ . filter_map ( |( & x, ty) | match cx. type_kind ( x) {
209+ rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( get_payload_size ( tcx, ty) ) ,
210+ _ => None ,
211+ } )
212+ . collect :: < Vec < u64 > > ( ) ;
213+
199214 // We do not know their size anymore at this level, so hardcode a placeholder.
200215 // A follow-up pr will track these from the frontend, where we still have Rust types.
201216 // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes.
202217 // I decided that 1024 bytes is a great placeholder value for now.
203- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num }" ) , & vec ! [ 1024 ; num_ptr_types ] ) ;
218+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol }" ) , & ptr_sizes ) ;
204219 // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
205220 // or both to and from the gpu (=3). Other values shouldn't affect us for now.
206221 // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
207222 // will be 2. For now, everything is 3, until we have our frontend set up.
223+
224+ // TODO(Sa4dUs): Check the way to figure out this
208225 let o_types =
209- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num }" ) , & vec ! [ 3 ; num_ptr_types] ) ;
226+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol }" ) , & vec ! [ 3 ; num_ptr_types] ) ;
210227 // Next: For each function, generate these three entries. A weak constant,
211228 // the llvm.rodata entry name, and the omp_offloading_entries value
212229
213- let name = format ! ( ".kernel_{num }.region_id" ) ;
230+ let name = format ! ( ".{symbol }.region_id" ) ;
214231 let initializer = cx. get_const_i8 ( 0 ) ;
215232 let region_id = add_unnamed_global ( & cx, & name, initializer, WeakAnyLinkage ) ;
216233
217- let c_entry_name = CString :: new ( format ! ( "kernel_{num}" ) ) . unwrap ( ) ;
234+ let c_entry_name = CString :: new ( symbol ) . unwrap ( ) ;
218235 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
219- let offload_entry_name = format ! ( ".offloading.entry_name.{num }" ) ;
236+ let offload_entry_name = format ! ( ".offloading.entry_name.{symbol }" ) ;
220237
221238 let initializer = crate :: common:: bytes_in_context ( cx. llcx , c_val) ;
222239 let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
223240 llvm:: set_alignment ( llglobal, Align :: ONE ) ;
224241 llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
225242
226243 // Not actively used yet, for calling real kernels
227- let name = format ! ( ".offloading.entry.kernel_{num }" ) ;
244+ let name = format ! ( ".offloading.entry.{symbol }" ) ;
228245
229246 // See the __tgt_offload_entry documentation above.
230247 let reserved = cx. get_const_i64 ( 0 ) ;
@@ -248,6 +265,56 @@ fn gen_define_handling<'ll>(
248265 o_types
249266}
250267
268+ // TODO(Sa4dUs): move this to a proper place
269+ fn get_payload_size < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> u64 {
270+ match ty. kind ( ) {
271+ /*
272+ rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
273+ rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
274+ rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
275+ rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
276+ rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
277+ rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
278+ rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
279+ rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
280+ rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
281+ rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
282+ rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
283+ rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
284+ */
285+ ty:: Ref ( _, inner, _) => get_payload_size ( tcx, * inner) ,
286+ /*
287+ rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
288+ rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
289+ rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
290+ rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
291+ rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
292+ rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
293+ rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
294+ rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
295+ rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
296+ rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
297+ rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
298+ rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
299+ rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
300+ rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
301+ rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
302+ rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
303+ */
304+ _ => {
305+ tcx
306+ // TODO(Sa4dUs): Maybe `.as_query_input()`?
307+ . layout_of ( PseudoCanonicalInput {
308+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
309+ value : ty,
310+ } )
311+ . unwrap ( )
312+ . size
313+ . bytes ( )
314+ }
315+ }
316+ }
317+
251318fn declare_offload_fn < ' ll > (
252319 cx : & ' ll SimpleCx < ' _ > ,
253320 name : & str ,
@@ -283,10 +350,13 @@ fn declare_offload_fn<'ll>(
283350// 4. set insert point after kernel call.
284351// 5. generate all the GEPS and stores, to be used in 6)
285352// 6. generate __tgt_target_data_end calls to move data from the GPU
286- fn gen_call_handling < ' ll > (
287- cx : & ' ll SimpleCx < ' _ > ,
288- _kernels : & [ & ' ll llvm:: Value ] ,
353+ pub ( crate ) fn gen_call_handling < ' ll > (
354+ cx : & SimpleCx < ' ll > ,
355+ bb : & BasicBlock ,
356+ kernels : & [ & ' ll llvm:: Value ] ,
289357 o_types : & [ & ' ll llvm:: Value ] ,
358+ llty : & ' ll Type ,
359+ llfn : & ' ll Value ,
290360) {
291361 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
292362 let tptr = cx. type_ptr ( ) ;
@@ -298,27 +368,14 @@ fn gen_call_handling<'ll>(
298368 gen_tgt_kernel_global ( & cx) ;
299369 let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
300370
301- let main_fn = cx. get_function ( "main" ) ;
302- let Some ( main_fn) = main_fn else { return } ;
303- let kernel_name = "kernel_1" ;
304- let call = unsafe {
305- llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) )
306- } ;
307- let Some ( kernel_call) = call else {
308- return ;
309- } ;
310- let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
311- let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) . unwrap ( ) } ;
312- let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
313-
314- let types = cx. func_params_types ( cx. get_type_of_global ( called) ) ;
371+ let mut builder = SBuilder :: build ( cx, bb) ;
372+
373+ let types = cx. func_params_types ( cx. get_type_of_global ( kernels[ 0 ] ) ) ;
315374 let num_args = types. len ( ) as u64 ;
316375
317376 // Step 0)
318377 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
319378 // %6 = alloca %struct.__tgt_bin_desc, align 8
320- unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
321-
322379 let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
323380
324381 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
@@ -335,7 +392,7 @@ fn gen_call_handling<'ll>(
335392 let i32_0 = cx. get_const_i32 ( 0 ) ;
336393 for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
337394 // get function arg, store it into the alloca, and read it.
338- let p = llvm:: get_param ( called , index as u32 ) ;
395+ let p = llvm:: get_param ( kernels [ 0 ] , index as u32 ) ;
339396 let name = llvm:: get_value_name ( p) ;
340397 let name = str:: from_utf8 ( & name) . unwrap ( ) ;
341398 let arg_name = format ! ( "{name}.addr" ) ;
@@ -349,7 +406,6 @@ fn gen_call_handling<'ll>(
349406 }
350407
351408 // Step 1)
352- unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
353409 builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
354410
355411 let mapper_fn_ty = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
@@ -422,10 +478,16 @@ fn gen_call_handling<'ll>(
422478 // Step 3)
423479 // Here we will add code for the actual kernel launches in a follow-up PR.
424480 // FIXME(offload): launch kernels
481+ let nparams = llvm:: LLVMCountParams ( llfn) ;
482+ let mut args = Vec :: with_capacity ( nparams as usize ) ;
483+ for i in 0 ..nparams {
484+ let param = unsafe { llvm:: LLVMGetParam ( llfn, i) } ;
485+ args. push ( param) ;
486+ }
425487
426- // Step 4)
427- unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
488+ builder. call ( llty, kernels[ 0 ] , & args, None ) ;
428489
490+ // Step 4)
429491 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
430492 generate_mapper_call ( & mut builder, & cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t) ;
431493
0 commit comments