@@ -19,7 +19,7 @@ pub(crate) fn handle_gpu_code<'ll>(
     let mut o_types = vec![];
     let mut kernels = vec![];
     let mut region_ids = vec![];
-    let offload_entry_ty = add_tgt_offload_entry(&cx);
+    let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
     for num in 0..9 {
         let kernel = cx.get_function(&format!("kernel_{num}"));
         if let Some(kernel) = kernel {
@@ -54,7 +54,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
 // FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
 // offloaded was defined.
 fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
-    // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
     let unknown_txt = ";unknown;unknown;0;0;;";
     let c_entry_name = CString::new(unknown_txt).unwrap();
     let c_val = c_entry_name.as_bytes_with_nul();
@@ -79,15 +78,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
     at_one
 }
 
-pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
-    let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
-    let tptr = cx.type_ptr();
-    let ti64 = cx.type_i64();
-    let ti32 = cx.type_i32();
-    let ti16 = cx.type_i16();
-    // For each kernel to run on the gpu, we will later generate one entry of this type.
-    // copied from LLVM
-    // typedef struct {
+struct TgtOffloadEntry {
     // uint64_t Reserved;
     // uint16_t Version;
     // uint16_t Kind;
@@ -97,21 +88,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
     // uint64_t Size; Size of the entry info (0 if it is a function)
     // uint64_t Data;
     // void *AuxAddr;
-    // } __tgt_offload_entry;
-    let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
-    cx.set_struct_body(offload_entry_ty, &entry_elements, false);
-    offload_entry_ty
 }
 
-fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
-    let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
-    let tptr = cx.type_ptr();
-    let ti64 = cx.type_i64();
-    let ti32 = cx.type_i32();
-    let tarr = cx.type_array(ti32, 3);
+impl TgtOffloadEntry {
+    pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
+        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
+        let tptr = cx.type_ptr();
+        let ti64 = cx.type_i64();
+        let ti32 = cx.type_i32();
+        let ti16 = cx.type_i16();
+        // For each kernel to run on the gpu, we will later generate one entry of this type.
+        // copied from LLVM
+        let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
+        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
+        offload_entry_ty
+    }
+
+    fn new<'ll>(
+        cx: &'ll SimpleCx<'_>,
+        region_id: &'ll Value,
+        llglobal: &'ll Value,
+    ) -> Vec<&'ll Value> {
+        let reserved = cx.get_const_i64(0);
+        let version = cx.get_const_i16(1);
+        let kind = cx.get_const_i16(1);
+        let flags = cx.get_const_i32(0);
+        let size = cx.get_const_i64(0);
+        let data = cx.get_const_i64(0);
+        let aux_addr = cx.const_null(cx.type_ptr());
+        vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
+    }
+}
 
-    // Taken from the LLVM APITypes.h declaration:
-    // struct KernelArgsTy {
+// Taken from the LLVM APITypes.h declaration:
+struct KernelArgsTy {
     // uint32_t Version = 0; // Version of this struct for ABI compatibility.
     // uint32_t NumArgs = 0; // Number of arguments in each input pointer.
     // void **ArgBasePtrs =
@@ -122,8 +132,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     // void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
     // void **ArgMappers = nullptr; // User-defined mappers, possibly null.
     // uint64_t Tripcount =
-    //     0; // Tripcount for the teams / distribute loop, 0 otherwise.
-    // struct {
+    //     0; // Tripcount for the teams / distribute loop, 0 otherwise.
+    // struct {
     //     uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
     //     uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
     //     uint64_t Unused : 62;
@@ -133,12 +143,53 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     // // The number of threads (for x,y,z dimension).
     // uint32_t ThreadLimit[3] = {0, 0, 0};
     // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
-    //};
-    let kernel_elements =
-        vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
+}
+
+impl KernelArgsTy {
+    const OFFLOAD_VERSION: u64 = 3;
+    const FLAGS: u64 = 0;
+    const TRIPCOUNT: u64 = 0;
+    fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
+        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
+        let tptr = cx.type_ptr();
+        let ti64 = cx.type_i64();
+        let ti32 = cx.type_i32();
+        let tarr = cx.type_array(ti32, 3);
+
+        let kernel_elements =
+            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
+
+        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
+        kernel_arguments_ty
+    }
 
-    cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
-    kernel_arguments_ty
+    fn new<'ll>(
+        cx: &'ll SimpleCx<'_>,
+        num_args: u64,
+        o_types: &[&'ll Value],
+        geps: [&'ll Value; 3],
+    ) -> [(Align, &'ll Value); 13] {
+        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
+        let eight = Align::EIGHT;
+        let mut values = vec![];
+        values.push((four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)));
+        values.push((four, cx.get_const_i32(num_args)));
+        values.push((eight, geps[0]));
+        values.push((eight, geps[1]));
+        values.push((eight, geps[2]));
+        values.push((eight, o_types[0]));
+        // The next two are debug infos. FIXME(offload): set them
+        values.push((eight, cx.const_null(cx.type_ptr())));
+        values.push((eight, cx.const_null(cx.type_ptr())));
+        values.push((eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)));
+        values.push((eight, cx.get_const_i64(KernelArgsTy::FLAGS)));
+        let ti32 = cx.type_i32();
+        let ci32_0 = cx.get_const_i32(0);
+        values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
+        values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
+        values.push((four, cx.get_const_i32(0)));
+        values.try_into().expect("tgt_kernel_arguments construction failed")
+    }
 }
 
 fn gen_tgt_data_mappers<'ll>(
@@ -244,19 +295,10 @@ fn gen_define_handling<'ll>(
     let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
     llvm::set_alignment(llglobal, Align::ONE);
     llvm::set_section(llglobal, c".llvm.rodata.offloading");
-
-    // Not actively used yet, for calling real kernels
     let name = format!(".offloading.entry.kernel_{num}");
 
     // See the __tgt_offload_entry documentation above.
-    let reserved = cx.get_const_i64(0);
-    let version = cx.get_const_i16(1);
-    let kind = cx.get_const_i16(1);
-    let flags = cx.get_const_i32(0);
-    let size = cx.get_const_i64(0);
-    let data = cx.get_const_i64(0);
-    let aux_addr = cx.const_null(cx.type_ptr());
-    let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
+    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
 
     let initializer = crate::common::named_struct(offload_entry_ty, &elems);
     let c_name = CString::new(name).unwrap();
@@ -319,7 +361,7 @@ fn gen_call_handling<'ll>(
     let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
     cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
 
-    let tgt_kernel_decl = gen_tgt_kernel_global(&cx);
+    let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
     let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
 
     let main_fn = cx.get_function("main");
@@ -407,19 +449,19 @@ fn gen_call_handling<'ll>(
         a1: &'ll Value,
         a2: &'ll Value,
         a4: &'ll Value,
-    ) -> (&'ll Value, &'ll Value, &'ll Value) {
+    ) -> [&'ll Value; 3] {
         let i32_0 = cx.get_const_i32(0);
 
         let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
         let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
         let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-        (gep1, gep2, gep3)
+        [gep1, gep2, gep3]
     }
 
     fn generate_mapper_call<'a, 'll>(
         builder: &mut SBuilder<'a, 'll>,
         cx: &'ll SimpleCx<'ll>,
-        geps: (&'ll Value, &'ll Value, &'ll Value),
+        geps: [&'ll Value; 3],
         o_type: &'ll Value,
         fn_to_call: &'ll Value,
         fn_ty: &'ll Type,
@@ -430,7 +472,7 @@ fn gen_call_handling<'ll>(
         let i64_max = cx.get_const_i64(u64::MAX);
         let num_args = cx.get_const_i32(num_args);
         let args =
-            vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
+            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
         builder.call(fn_ty, fn_to_call, &args, None);
     }
 
@@ -439,36 +481,20 @@ fn gen_call_handling<'ll>(
     let o = o_types[0];
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
+    let values = KernelArgsTy::new(&cx, num_args, o_types, geps);
 
     // Step 3)
-    let mut values = vec![];
-    let offload_version = cx.get_const_i32(3);
-    values.push((4, offload_version));
-    values.push((4, cx.get_const_i32(num_args)));
-    values.push((8, geps.0));
-    values.push((8, geps.1));
-    values.push((8, geps.2));
-    values.push((8, o_types[0]));
-    // The next two are debug infos. FIXME(offload) set them
-    values.push((8, cx.const_null(cx.type_ptr())));
-    values.push((8, cx.const_null(cx.type_ptr())));
-    values.push((8, cx.get_const_i64(0)));
-    values.push((8, cx.get_const_i64(0)));
-    let ti32 = cx.type_i32();
-    let ci32_0 = cx.get_const_i32(0);
-    values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
-    values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
-    values.push((4, cx.get_const_i32(0)));
-
+    // Here we fill the KernelArgsTy, see the documentation above
     for (i, value) in values.iter().enumerate() {
         let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
-        builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
+        builder.store(value.1, ptr, value.0);
     }
 
     let args = vec![
         s_ident_t,
-        // MAX == -1
-        cx.get_const_i64(u64::MAX),
+        // FIXME(offload) give users a way to select which GPU to use.
+        cx.get_const_i64(u64::MAX), // MAX == -1.
+        // FIXME(offload): Don't hardcode the numbers of threads in the future.
         cx.get_const_i32(2097152),
         cx.get_const_i32(256),
         region_ids[0],
@@ -483,19 +509,14 @@ fn gen_call_handling<'ll>(
     }
 
     // Step 4)
-    //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
-
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
 
     builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
 
     drop(builder);
+    // FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
+    // and then delete the call to the CPU version. In the future, we should use an intrinsic which
+    // directly resolves to a call to the GPU version.
     unsafe { llvm::LLVMDeleteFunction(called) };
-
-    // With this we generated the following begin and end mappers. We could easily generate the
-    // update mapper in an update.
-    // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-    // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-    // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
 }