@@ -25,7 +25,6 @@ fn create_struct_ty<'ll>(
2525 }
2626}
2727
28- //weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
2928pub ( crate ) fn gen_asdf < ' ll > ( cgcx : & CodegenContext < LlvmCodegenBackend > , old_cx : & SimpleCx < ' ll > ) {
3029 let llcx = unsafe { llvm:: LLVMRustContextCreate ( false ) } ;
3130 let module_name = CString :: new ( "offload.wrapper.module" ) . unwrap ( ) ;
@@ -236,7 +235,7 @@ pub(crate) fn handle_gpu_code<'ll>(
236235 cx : & ' ll SimpleCx < ' _ > ,
237236) {
238237 if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
239- let ( offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals ( & cx) ;
238+ let ( offload_entry_ty, at_one, begin, update, end, tgt_bin_desc , fn_ty) = gen_globals ( & cx) ;
240239
241240 dbg ! ( "created struct" ) ;
242241 let mut o_types = vec ! [ ] ;
@@ -249,7 +248,7 @@ pub(crate) fn handle_gpu_code<'ll>(
249248 }
250249 }
251250 dbg ! ( "gen_call_handling" ) ;
252- gen_call_handling ( & cx, & kernels, at_one, begin, update, end, fn_ty, & o_types) ;
251+ gen_call_handling ( & cx, & kernels, at_one, begin, update, end, tgt_bin_desc , fn_ty, & o_types) ;
253252 gen_image_wrapper_module ( & cgcx, & cx) ;
254253 gen_asdf ( & cgcx, & cx) ;
255254 } else {
@@ -279,6 +278,7 @@ fn gen_globals<'ll>(
279278 & ' ll llvm:: Value ,
280279 & ' ll llvm:: Value ,
281280 & ' ll llvm:: Type ,
281+ & ' ll llvm:: Type ,
282282) {
283283 let offload_entry_ty = add_tgt_offload_entry ( & cx) ;
284284 let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
@@ -312,6 +312,11 @@ fn gen_globals<'ll>(
312312 let at_one = add_unnamed_global ( & cx, & "" , initializer, PrivateLinkage ) ;
313313 llvm:: set_alignment ( at_one, Align :: EIGHT ) ;
314314
315+ // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
316+ let tgt_bin_desc_ty = vec ! [ ti32, tptr, tptr, tptr] ;
317+ let tgt_bin_desc_name = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
318+ cx. set_struct_body ( tgt_bin_desc_name, & tgt_bin_desc_ty, false ) ;
319+
315320 // copied from LLVM
316321 // typedef struct {
317322 // uint64_t Reserved;
@@ -379,7 +384,7 @@ fn gen_globals<'ll>(
379384 attributes:: apply_to_llfn ( bar, Function , & [ nounwind] ) ;
380385 attributes:: apply_to_llfn ( baz, Function , & [ nounwind] ) ;
381386
382- ( offload_entry_ty, at_one, foo, bar, baz, mapper_fn_ty)
387+ ( offload_entry_ty, at_one, foo, bar, baz, tgt_bin_desc_name , mapper_fn_ty)
383388}
384389
385390fn add_priv_unnamed_arr < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , vals : & [ u64 ] ) -> & ' ll llvm:: Value {
@@ -561,6 +566,7 @@ fn gen_call_handling<'ll>(
561566 begin : & ' ll llvm:: Value ,
562567 update : & ' ll llvm:: Value ,
563568 end : & ' ll llvm:: Value ,
569+ tgt_bin_desc : & ' ll llvm:: Type ,
564570 fn_ty : & ' ll llvm:: Type ,
565571 o_types : & [ & ' ll llvm:: Value ] ,
566572) {
@@ -586,7 +592,18 @@ fn gen_call_handling<'ll>(
586592 let mut names: Vec < & llvm:: Value > = Vec :: with_capacity ( num_args as usize ) ;
587593
588594 // Step 0)
595+ // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
596+ // %6 = alloca %struct.__tgt_bin_desc, align 8
589597 unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
598+
599+ let tgt_bin_desc_alloca = builder. my_alloca2 ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
600+ //fill_byte: &'ll Value,
601+ //size: &'ll Value,
602+ //align: Align,
603+ //flags: MemFlags,
604+ // call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false)
605+ // mem
606+
590607 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
591608 // Baseptr are just the input pointer to the kernel, stored in a local alloca
592609 let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
@@ -616,6 +633,46 @@ fn gen_call_handling<'ll>(
616633
617634 // Step 1)
618635 unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
636+ builder. memset (
637+ tgt_bin_desc_alloca,
638+ cx. get_const_i8 ( 0 ) ,
639+ cx. get_const_i64 ( 32 ) ,
640+ Align :: from_bytes ( 8 ) . unwrap ( ) ,
641+ ) ;
642+
643+ let tptr = cx. type_ptr ( ) ;
644+ let mapper_fn_ty = cx. type_func ( & [ tptr] , cx. type_void ( ) ) ;
645+ let foo = crate :: declare:: declare_simple_fn (
646+ & cx,
647+ & "__tgt_register_lib" ,
648+ llvm:: CallConv :: CCallConv ,
649+ llvm:: UnnamedAddr :: No ,
650+ llvm:: Visibility :: Default ,
651+ mapper_fn_ty,
652+ ) ;
653+ let bar = crate :: declare:: declare_simple_fn (
654+ & cx,
655+ & "__tgt_unregister_lib" ,
656+ llvm:: CallConv :: CCallConv ,
657+ llvm:: UnnamedAddr :: No ,
658+ llvm:: Visibility :: Default ,
659+ mapper_fn_ty,
660+ ) ;
661+ let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
662+ let baz = crate :: declare:: declare_simple_fn (
663+ & cx,
664+ & "__tgt_init_all_rtls" ,
665+ llvm:: CallConv :: CCallConv ,
666+ llvm:: UnnamedAddr :: No ,
667+ llvm:: Visibility :: Default ,
668+ init_ty,
669+ ) ;
670+
671+ builder. call ( mapper_fn_ty, foo, & [ tgt_bin_desc_alloca] , None ) ;
672+ builder. call ( init_ty, baz, & [ ] , None ) ;
673+
674+ // call void @__tgt_register_lib(ptr noundef %6)
675+ // call void @__tgt_init_all_rtls()
619676 for i in 0 ..num_args {
620677 let idx = cx. get_const_i32 ( i) ;
621678 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, idx] ) ;
@@ -667,6 +724,7 @@ fn gen_call_handling<'ll>(
667724 nullptr,
668725 ] ;
669726 builder. call ( fn_ty, end, & args, None ) ;
727+ builder. call ( mapper_fn_ty, bar, & [ tgt_bin_desc_alloca] , None ) ;
670728
671729 // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
672730 // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
0 commit comments