@@ -4,17 +4,18 @@ use llvm::Linkage::*;
44use rustc_abi:: Align ;
55use rustc_codegen_ssa:: back:: write:: CodegenContext ;
66use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7+ use rustc_middle:: ty:: { self , PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
78
89use crate :: builder:: SBuilder ;
9- use crate :: common:: AsCCharPtr ;
1010use crate :: llvm:: AttributePlace :: Function ;
11- use crate :: llvm:: { self , Linkage , Type , Value } ;
11+ use crate :: llvm:: { self , BasicBlock , Linkage , Type , Value } ;
1212use crate :: { LlvmCodegenBackend , SimpleCx , attributes} ;
1313
1414pub ( crate ) fn handle_gpu_code < ' ll > (
1515 _cgcx : & CodegenContext < LlvmCodegenBackend > ,
16- cx : & ' ll SimpleCx < ' _ > ,
16+ _cx : & ' ll SimpleCx < ' _ > ,
1717) {
18+ /*
1819 // The offload memory transfer type for each kernel
1920 let mut memtransfer_types = vec![];
2021 let mut region_ids = vec![];
@@ -32,6 +33,7 @@ pub(crate) fn handle_gpu_code<'ll>(
3233 }
3334
3435 gen_call_handling(&cx, &memtransfer_types, ®ion_ids);
36+ */
3537}
3638
3739// ; Function Attrs: nounwind
@@ -79,7 +81,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7981 at_one
8082}
8183
82- struct TgtOffloadEntry {
84+ pub ( crate ) struct TgtOffloadEntry {
8385 // uint64_t Reserved;
8486 // uint16_t Version;
8587 // uint16_t Kind;
@@ -256,11 +258,14 @@ pub(crate) fn add_global<'ll>(
256258// This function returns a memtransfer value which encodes how arguments to this kernel shall be
257259// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
258260// concatenated into the list of region_ids.
259- fn gen_define_handling < ' ll > (
260- cx : & ' ll SimpleCx < ' _ > ,
261+ pub ( crate ) fn gen_define_handling < ' ll , ' tcx > (
262+ cx : & SimpleCx < ' ll > ,
263+ tcx : TyCtxt < ' tcx > ,
261264 kernel : & ' ll llvm:: Value ,
262265 offload_entry_ty : & ' ll llvm:: Type ,
263- num : i64 ,
266+ // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
267+ tt : Vec < Ty < ' tcx > > ,
268+ symbol : & str ,
264269) -> ( & ' ll llvm:: Value , & ' ll llvm:: Value ) {
265270 let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
266271 // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -270,37 +275,50 @@ fn gen_define_handling<'ll>(
270275 . filter ( |& x| matches ! ( cx. type_kind( x) , rustc_codegen_ssa:: common:: TypeKind :: Pointer ) )
271276 . count ( ) ;
272277
278+ // TODO(Sa4dUs): Add typetrees here
279+ let ptr_sizes = types
280+ . iter ( )
281+ . zip ( tt)
282+ . filter_map ( |( & x, ty) | match cx. type_kind ( x) {
283+ rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( get_payload_size ( tcx, ty) ) ,
284+ _ => None ,
285+ } )
286+ . collect :: < Vec < u64 > > ( ) ;
287+
273288 // We do not know their size anymore at this level, so hardcode a placeholder.
274289 // A follow-up pr will track these from the frontend, where we still have Rust types.
275290 // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes.
276291 // I decided that 1024 bytes is a great placeholder value for now.
277- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num }" ) , & vec ! [ 1024 ; num_ptr_types ] ) ;
292+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol }" ) , & ptr_sizes ) ;
278293 // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
279294 // or both to and from the gpu (=3). Other values shouldn't affect us for now.
280295 // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
281296 // will be 2. For now, everything is 3, until we have our frontend set up.
282297 // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
283298 let memtransfer_types = add_priv_unnamed_arr (
284299 & cx,
285- & format ! ( ".offload_maptypes.{num }" ) ,
300+ & format ! ( ".offload_maptypes.{symbol }" ) ,
286301 & vec ! [ 1 + 2 + 32 ; num_ptr_types] ,
287302 ) ;
303+
288304 // Next: For each function, generate these three entries. A weak constant,
289305 // the llvm.rodata entry name, and the llvm_offload_entries value
290306
291- let name = format ! ( ".kernel_{num }.region_id" ) ;
307+ let name = format ! ( ".{symbol }.region_id" ) ;
292308 let initializer = cx. get_const_i8 ( 0 ) ;
293309 let region_id = add_unnamed_global ( & cx, & name, initializer, WeakAnyLinkage ) ;
294310
295- let c_entry_name = CString :: new ( format ! ( "kernel_{num}" ) ) . unwrap ( ) ;
311+ let c_entry_name = CString :: new ( symbol ) . unwrap ( ) ;
296312 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
297- let offload_entry_name = format ! ( ".offloading.entry_name.{num }" ) ;
313+ let offload_entry_name = format ! ( ".offloading.entry_name.{symbol }" ) ;
298314
299315 let initializer = crate :: common:: bytes_in_context ( cx. llcx , c_val) ;
300316 let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
301317 llvm:: set_alignment ( llglobal, Align :: ONE ) ;
302318 llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
303- let name = format ! ( ".offloading.entry.kernel_{num}" ) ;
319+
320+ // Not actively used yet, for calling real kernels
321+ let name = format ! ( ".offloading.entry.{symbol}" ) ;
304322
305323 // See the __tgt_offload_entry documentation above.
306324 let elems = TgtOffloadEntry :: new ( & cx, region_id, llglobal) ;
@@ -317,7 +335,57 @@ fn gen_define_handling<'ll>(
317335 ( memtransfer_types, region_id)
318336}
319337
320- pub ( crate ) fn declare_offload_fn < ' ll > (
338+ // TODO(Sa4dUs): move this to a proper place
339+ fn get_payload_size < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> u64 {
340+ match ty. kind ( ) {
341+ /*
342+ rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
343+ rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
344+ rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
345+ rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
346+ rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
347+ rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
348+ rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
349+ rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
350+ rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
351+ rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
352+ rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
353+ rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
354+ */
355+ ty:: Ref ( _, inner, _) => get_payload_size ( tcx, * inner) ,
356+ /*
357+ rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
358+ rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
359+ rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
360+ rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
361+ rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
362+ rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
363+ rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
364+ rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
365+ rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
366+ rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
367+ rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
368+ rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
369+ rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
370+ rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
371+ rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
372+ rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
373+ */
374+ _ => {
375+ tcx
376+ // TODO(Sa4dUs): Maybe `.as_query_input()`?
377+ . layout_of ( PseudoCanonicalInput {
378+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
379+ value : ty,
380+ } )
381+ . unwrap ( )
382+ . size
383+ . bytes ( )
384+ }
385+ }
386+ }
387+
388+ fn declare_offload_fn < ' ll > (
321389 cx : & ' ll SimpleCx < ' _ > ,
322390 name : & str ,
323391 ty : & ' ll llvm:: Type ,
@@ -352,10 +420,13 @@ pub(crate) fn declare_offload_fn<'ll>(
352420// 4. set insert point after kernel call.
353421// 5. generate all the GEPS and stores, to be used in 6)
354422// 6. generate __tgt_target_data_end calls to move data from the GPU
355- fn gen_call_handling < ' ll > (
356- cx : & ' ll SimpleCx < ' _ > ,
423+ pub ( crate ) fn gen_call_handling < ' ll > (
424+ cx : & SimpleCx < ' ll > ,
425+ bb : & BasicBlock ,
426+ kernels : & [ & ' ll llvm:: Value ] ,
357427 memtransfer_types : & [ & ' ll llvm:: Value ] ,
358428 region_ids : & [ & ' ll llvm:: Value ] ,
429+ llfn : & ' ll Value ,
359430) {
360431 let ( tgt_decl, tgt_target_kernel_ty) = generate_launcher ( & cx) ;
361432 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
@@ -368,27 +439,14 @@ fn gen_call_handling<'ll>(
368439 let tgt_kernel_decl = KernelArgsTy :: new_decl ( & cx) ;
369440 let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
370441
371- let main_fn = cx. get_function ( "main" ) ;
372- let Some ( main_fn) = main_fn else { return } ;
373- let kernel_name = "kernel_1" ;
374- let call = unsafe {
375- llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) )
376- } ;
377- let Some ( kernel_call) = call else {
378- return ;
379- } ;
380- let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
381- let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) . unwrap ( ) } ;
382- let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
383-
384- let types = cx. func_params_types ( cx. get_type_of_global ( called) ) ;
442+ let mut builder = SBuilder :: build ( cx, bb) ;
443+
444+ let types = cx. func_params_types ( cx. get_type_of_global ( kernels[ 0 ] ) ) ;
385445 let num_args = types. len ( ) as u64 ;
386446
387447 // Step 0)
388448 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
389449 // %6 = alloca %struct.__tgt_bin_desc, align 8
390- unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
391-
392450 let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
393451
394452 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
@@ -404,15 +462,14 @@ fn gen_call_handling<'ll>(
404462 let a5 = builder. direct_alloca ( tgt_kernel_decl, Align :: EIGHT , "kernel_args" ) ;
405463
406464 // Step 1)
407- unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
408465 builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
409466
410467 // Now we allocate once per function param, a copy to be passed to one of our maps.
411468 let mut vals = vec ! [ ] ;
412469 let mut geps = vec ! [ ] ;
413470 let i32_0 = cx. get_const_i32 ( 0 ) ;
414- for index in 0 ..types . len ( ) {
415- let v = unsafe { llvm:: LLVMGetOperand ( kernel_call , index as u32 ) . unwrap ( ) } ;
471+ for index in 0 ..num_args {
472+ let v = unsafe { llvm:: LLVMGetParam ( llfn , index as u32 ) } ;
416473 let gep = builder. inbounds_gep ( cx. type_f32 ( ) , v, & [ i32_0] ) ;
417474 vals. push ( v) ;
418475 geps. push ( gep) ;
@@ -504,13 +561,8 @@ fn gen_call_handling<'ll>(
504561 region_ids[ 0 ] ,
505562 a5,
506563 ] ;
507- let offload_success = builder. call ( tgt_target_kernel_ty, tgt_decl, & args, None ) ;
564+ builder. call ( tgt_target_kernel_ty, tgt_decl, & args, None ) ;
508565 // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
509- unsafe {
510- let next = llvm:: LLVMGetNextInstruction ( offload_success) . unwrap ( ) ;
511- llvm:: LLVMRustPositionAfter ( builder. llbuilder , next) ;
512- llvm:: LLVMInstructionEraseFromParent ( next) ;
513- }
514566
515567 // Step 4)
516568 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
@@ -519,8 +571,4 @@ fn gen_call_handling<'ll>(
519571 builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
520572
521573 drop ( builder) ;
522- // FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
523- // and then delete the call to the CPU version. In the future, we should use an intrinsic which
524- // directly resolves to a call to the GPU version.
525- unsafe { llvm:: LLVMDeleteFunction ( called) } ;
526574}
0 commit comments