@@ -18,7 +18,7 @@ pub(crate) fn handle_gpu_code<'ll>(
1818 // The offload memory transfer type for each kernel
1919 let mut memtransfer_types = vec ! [ ] ;
2020 let mut region_ids = vec ! [ ] ;
21- let offload_entry_ty = add_tgt_offload_entry ( & cx) ;
21+ let offload_entry_ty = TgtOffloadEntry :: new_decl ( & cx) ;
2222 for num in 0 ..9 {
2323 let kernel = cx. get_function ( & format ! ( "kernel_{num}" ) ) ;
2424 if let Some ( kernel) = kernel {
@@ -52,7 +52,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
5252// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
5353// offloaded was defined.
5454fn generate_at_one < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Value {
55- // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
5655 let unknown_txt = ";unknown;unknown;0;0;;" ;
5756 let c_entry_name = CString :: new ( unknown_txt) . unwrap ( ) ;
5857 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
@@ -77,15 +76,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7776 at_one
7877}
7978
80- pub ( crate ) fn add_tgt_offload_entry < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
81- let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
82- let tptr = cx. type_ptr ( ) ;
83- let ti64 = cx. type_i64 ( ) ;
84- let ti32 = cx. type_i32 ( ) ;
85- let ti16 = cx. type_i16 ( ) ;
86- // For each kernel to run on the gpu, we will later generate one entry of this type.
87- // copied from LLVM
88- // typedef struct {
79+ struct TgtOffloadEntry {
8980 // uint64_t Reserved;
9081 // uint16_t Version;
9182 // uint16_t Kind;
@@ -95,21 +86,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
9586 // uint64_t Size; Size of the entry info (0 if it is a function)
9687 // uint64_t Data;
9788 // void *AuxAddr;
98- // } __tgt_offload_entry;
99- let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
100- cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
101- offload_entry_ty
10289}
10390
104- fn gen_tgt_kernel_global < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
105- let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
106- let tptr = cx. type_ptr ( ) ;
107- let ti64 = cx. type_i64 ( ) ;
108- let ti32 = cx. type_i32 ( ) ;
109- let tarr = cx. type_array ( ti32, 3 ) ;
91+ impl TgtOffloadEntry {
92+ pub ( crate ) fn new_decl < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
93+ let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
94+ let tptr = cx. type_ptr ( ) ;
95+ let ti64 = cx. type_i64 ( ) ;
96+ let ti32 = cx. type_i32 ( ) ;
97+ let ti16 = cx. type_i16 ( ) ;
98+ // For each kernel to run on the gpu, we will later generate one entry of this type.
99+ // copied from LLVM
100+ let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
101+ cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
102+ offload_entry_ty
103+ }
104+
105+ fn new < ' ll > (
106+ cx : & ' ll SimpleCx < ' _ > ,
107+ region_id : & ' ll Value ,
108+ llglobal : & ' ll Value ,
109+ ) -> Vec < & ' ll Value > {
110+ let reserved = cx. get_const_i64 ( 0 ) ;
111+ let version = cx. get_const_i16 ( 1 ) ;
112+ let kind = cx. get_const_i16 ( 1 ) ;
113+ let flags = cx. get_const_i32 ( 0 ) ;
114+ let size = cx. get_const_i64 ( 0 ) ;
115+ let data = cx. get_const_i64 ( 0 ) ;
116+ let aux_addr = cx. const_null ( cx. type_ptr ( ) ) ;
117+ vec ! [ reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
118+ }
119+ }
110120
111- // Taken from the LLVM APITypes.h declaration:
112- // struct KernelArgsTy {
121+ // Taken from the LLVM APITypes.h declaration:
122+ struct KernelArgsTy {
113123 // uint32_t Version = 0; // Version of this struct for ABI compatibility.
114124 // uint32_t NumArgs = 0; // Number of arguments in each input pointer.
115125 // void **ArgBasePtrs =
@@ -120,8 +130,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
120130 // void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
121131 // void **ArgMappers = nullptr; // User-defined mappers, possibly null.
122132 // uint64_t Tripcount =
123- // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
124- // struct {
133+ // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
134+ // struct {
125135 // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
126136 // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
127137 // uint64_t Unused : 62;
@@ -131,12 +141,53 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
131141 // // The number of threads (for x,y,z dimension).
132142 // uint32_t ThreadLimit[3] = {0, 0, 0};
133143 // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
134- //};
135- let kernel_elements =
136- vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
144+ }
145+
146+ impl KernelArgsTy {
147+ const OFFLOAD_VERSION : u64 = 3 ;
148+ const FLAGS : u64 = 0 ;
149+ const TRIPCOUNT : u64 = 0 ;
150+ fn new_decl < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll Type {
151+ let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
152+ let tptr = cx. type_ptr ( ) ;
153+ let ti64 = cx. type_i64 ( ) ;
154+ let ti32 = cx. type_i32 ( ) ;
155+ let tarr = cx. type_array ( ti32, 3 ) ;
156+
157+ let kernel_elements =
158+ vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
159+
160+ cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
161+ kernel_arguments_ty
162+ }
137163
138- cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
139- kernel_arguments_ty
164+ fn new < ' ll > (
165+ cx : & ' ll SimpleCx < ' _ > ,
166+ num_args : u64 ,
167+ memtransfer_types : & [ & ' ll Value ] ,
168+ geps : [ & ' ll Value ; 3 ] ,
169+ ) -> [ ( Align , & ' ll Value ) ; 13 ] {
170+ let four = Align :: from_bytes ( 4 ) . expect ( "4 Byte alignment should work" ) ;
171+ let eight = Align :: EIGHT ;
172+ let mut values = vec ! [ ] ;
173+ values. push ( ( four, cx. get_const_i32 ( KernelArgsTy :: OFFLOAD_VERSION ) ) ) ;
174+ values. push ( ( four, cx. get_const_i32 ( num_args) ) ) ;
175+ values. push ( ( eight, geps[ 0 ] ) ) ;
176+ values. push ( ( eight, geps[ 1 ] ) ) ;
177+ values. push ( ( eight, geps[ 2 ] ) ) ;
178+ values. push ( ( eight, memtransfer_types[ 0 ] ) ) ;
179+ // The next two are debug infos. FIXME(offload): set them
180+ values. push ( ( eight, cx. const_null ( cx. type_ptr ( ) ) ) ) ;
181+ values. push ( ( eight, cx. const_null ( cx. type_ptr ( ) ) ) ) ;
182+ values. push ( ( eight, cx. get_const_i64 ( KernelArgsTy :: TRIPCOUNT ) ) ) ;
183+ values. push ( ( eight, cx. get_const_i64 ( KernelArgsTy :: FLAGS ) ) ) ;
184+ let ti32 = cx. type_i32 ( ) ;
185+ let ci32_0 = cx. get_const_i32 ( 0 ) ;
186+ values. push ( ( four, cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 2097152 ) , ci32_0, ci32_0] ) ) ) ;
187+ values. push ( ( four, cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 256 ) , ci32_0, ci32_0] ) ) ) ;
188+ values. push ( ( four, cx. get_const_i32 ( 0 ) ) ) ;
189+ values. try_into ( ) . expect ( "tgt_kernel_arguments construction failed" )
190+ }
140191}
141192
142193fn gen_tgt_data_mappers < ' ll > (
@@ -242,19 +293,10 @@ fn gen_define_handling<'ll>(
242293 let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
243294 llvm:: set_alignment ( llglobal, Align :: ONE ) ;
244295 llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
245-
246- // Not actively used yet, for calling real kernels
247296 let name = format ! ( ".offloading.entry.kernel_{num}" ) ;
248297
249298 // See the __tgt_offload_entry documentation above.
250- let reserved = cx. get_const_i64 ( 0 ) ;
251- let version = cx. get_const_i16 ( 1 ) ;
252- let kind = cx. get_const_i16 ( 1 ) ;
253- let flags = cx. get_const_i32 ( 0 ) ;
254- let size = cx. get_const_i64 ( 0 ) ;
255- let data = cx. get_const_i64 ( 0 ) ;
256- let aux_addr = cx. const_null ( cx. type_ptr ( ) ) ;
257- let elems = vec ! [ reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr] ;
299+ let elems = TgtOffloadEntry :: new ( & cx, region_id, llglobal) ;
258300
259301 let initializer = crate :: common:: named_struct ( offload_entry_ty, & elems) ;
260302 let c_name = CString :: new ( name) . unwrap ( ) ;
@@ -316,7 +358,7 @@ fn gen_call_handling<'ll>(
316358 let tgt_bin_desc = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
317359 cx. set_struct_body ( tgt_bin_desc, & tgt_bin_desc_ty, false ) ;
318360
319- let tgt_kernel_decl = gen_tgt_kernel_global ( & cx) ;
361+ let tgt_kernel_decl = KernelArgsTy :: new_decl ( & cx) ;
320362 let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
321363
322364 let main_fn = cx. get_function ( "main" ) ;
@@ -404,19 +446,19 @@ fn gen_call_handling<'ll>(
404446 a1 : & ' ll Value ,
405447 a2 : & ' ll Value ,
406448 a4 : & ' ll Value ,
407- ) -> ( & ' ll Value , & ' ll Value , & ' ll Value ) {
449+ ) -> [ & ' ll Value ; 3 ] {
408450 let i32_0 = cx. get_const_i32 ( 0 ) ;
409451
410452 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
411453 let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
412454 let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
413- ( gep1, gep2, gep3)
455+ [ gep1, gep2, gep3]
414456 }
415457
416458 fn generate_mapper_call < ' a , ' ll > (
417459 builder : & mut SBuilder < ' a , ' ll > ,
418460 cx : & ' ll SimpleCx < ' ll > ,
419- geps : ( & ' ll Value , & ' ll Value , & ' ll Value ) ,
461+ geps : [ & ' ll Value ; 3 ] ,
420462 o_type : & ' ll Value ,
421463 fn_to_call : & ' ll Value ,
422464 fn_ty : & ' ll Type ,
@@ -427,7 +469,7 @@ fn gen_call_handling<'ll>(
427469 let i64_max = cx. get_const_i64 ( u64:: MAX ) ;
428470 let num_args = cx. get_const_i32 ( num_args) ;
429471 let args =
430- vec ! [ s_ident_t, i64_max, num_args, geps. 0 , geps. 1 , geps. 2 , o_type, nullptr, nullptr] ;
472+ vec ! [ s_ident_t, i64_max, num_args, geps[ 0 ] , geps[ 1 ] , geps[ 2 ] , o_type, nullptr, nullptr] ;
431473 builder. call ( fn_ty, fn_to_call, & args, None ) ;
432474 }
433475
@@ -436,36 +478,20 @@ fn gen_call_handling<'ll>(
436478 let o = memtransfer_types[ 0 ] ;
437479 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
438480 generate_mapper_call ( & mut builder, & cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t) ;
481+ let values = KernelArgsTy :: new ( & cx, num_args, memtransfer_types, geps) ;
439482
440483 // Step 3)
441- let mut values = vec ! [ ] ;
442- let offload_version = cx. get_const_i32 ( 3 ) ;
443- values. push ( ( 4 , offload_version) ) ;
444- values. push ( ( 4 , cx. get_const_i32 ( num_args) ) ) ;
445- values. push ( ( 8 , geps. 0 ) ) ;
446- values. push ( ( 8 , geps. 1 ) ) ;
447- values. push ( ( 8 , geps. 2 ) ) ;
448- values. push ( ( 8 , memtransfer_types[ 0 ] ) ) ;
449- // The next two are debug infos. FIXME(offload) set them
450- values. push ( ( 8 , cx. const_null ( cx. type_ptr ( ) ) ) ) ;
451- values. push ( ( 8 , cx. const_null ( cx. type_ptr ( ) ) ) ) ;
452- values. push ( ( 8 , cx. get_const_i64 ( 0 ) ) ) ;
453- values. push ( ( 8 , cx. get_const_i64 ( 0 ) ) ) ;
454- let ti32 = cx. type_i32 ( ) ;
455- let ci32_0 = cx. get_const_i32 ( 0 ) ;
456- values. push ( ( 4 , cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 2097152 ) , ci32_0, ci32_0] ) ) ) ;
457- values. push ( ( 4 , cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 256 ) , ci32_0, ci32_0] ) ) ) ;
458- values. push ( ( 4 , cx. get_const_i32 ( 0 ) ) ) ;
459-
484+ // Here we fill the KernelArgsTy, see the documentation above
460485 for ( i, value) in values. iter ( ) . enumerate ( ) {
461486 let ptr = builder. inbounds_gep ( tgt_kernel_decl, a5, & [ i32_0, cx. get_const_i32 ( i as u64 ) ] ) ;
462- builder. store ( value. 1 , ptr, Align :: from_bytes ( value. 0 ) . unwrap ( ) ) ;
487+ builder. store ( value. 1 , ptr, value. 0 ) ;
463488 }
464489
465490 let args = vec ! [
466491 s_ident_t,
467- // MAX == -1
468- cx. get_const_i64( u64 :: MAX ) ,
492+ // FIXME(offload) give users a way to select which GPU to use.
493+ cx. get_const_i64( u64 :: MAX ) , // MAX == -1.
494+ // FIXME(offload): Don't hardcode the numbers of threads in the future.
469495 cx. get_const_i32( 2097152 ) ,
470496 cx. get_const_i32( 256 ) ,
471497 region_ids[ 0 ] ,
@@ -480,19 +506,14 @@ fn gen_call_handling<'ll>(
480506 }
481507
482508 // Step 4)
483- //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
484-
485509 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
486510 generate_mapper_call ( & mut builder, & cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t) ;
487511
488512 builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
489513
490514 drop ( builder) ;
515+ // FIXME(offload): We currently add a call to the GPU version of the function and then delete
516+ // the call to the CPU version. In the future, we should use an intrinsic which resolves
517+ // directly to a call to the GPU version.
491518 unsafe { llvm:: LLVMDeleteFunction ( called) } ;
492-
493- // With this we generated the following begin and end mappers. We could easily generate the
494- // update mapper in an update.
495- // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
496- // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
497- // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
498519}
0 commit comments