@@ -58,19 +58,25 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
     at_one
 }
 
-// The meaning of the __tgt_offload_entry (as per llvm docs) is
-// Type, Identifier, Description
-// void*, addr, Address of global symbol within device image (function or global)
-// char*, name, Name of the symbol
-// size_t, size, Size of the entry info (0 if it is a function)
-// int32_t, flags, Flags associated with the entry (see Target Region Entry Flags)
-// int32_t, reserved, Reserved, to be used by the runtime library.
 pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
     let tptr = cx.type_ptr();
     let ti64 = cx.type_i64();
     let ti32 = cx.type_i32();
     let ti16 = cx.type_i16();
+    // For each kernel to run on the gpu, we will later generate one entry of this type.
+    // copied from LLVM
+    // typedef struct {
+    //  uint64_t Reserved;
+    //  uint16_t Version;
+    //  uint16_t Kind;
+    //  uint32_t Flags;     // Flags associated with the entry (see Target Region Entry Flags)
+    //  void *Address;      // Address of global symbol within device image (function or global)
+    //  char *SymbolName;
+    //  uint64_t Size;      // Size of the entry info (0 if it is a function)
+    //  uint64_t Data;
+    //  void *AuxAddr;
+    // } __tgt_offload_entry;
     let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
     cx.set_struct_body(offload_entry_ty, &entry_elements, false);
     offload_entry_ty
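(Sketch of the expected result, not part of the patch itself: with the element list above, the named
struct type defined by add_tgt_offload_entry should print as the following LLVM IR.)

%struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }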
@@ -83,19 +89,30 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
     let ti32 = cx.type_i32();
     let tarr = cx.type_array(ti32, 3);
 
-    // For each kernel to run on the gpu, we will later generate one entry of this type.
-    // coppied from LLVM
-    // typedef struct {
-    //  uint64_t Reserved;
-    //  uint16_t Version;
-    //  uint16_t Kind;
-    //  uint32_t Flags;
-    //  void *Address;
-    //  char *SymbolName;
-    //  uint64_t Size;
-    //  uint64_t Data;
-    //  void *AuxAddr;
-    // } __tgt_offload_entry;
+    // Taken from the LLVM APITypes.h declaration:
+    // struct KernelArgsTy {
+    //   uint32_t Version = 0;  // Version of this struct for ABI compatibility.
+    //   uint32_t NumArgs = 0;  // Number of arguments in each input pointer.
+    //   void **ArgBasePtrs =
+    //       nullptr;           // Base pointer of each argument (e.g. a struct).
+    //   void **ArgPtrs = nullptr;     // Pointer to the argument data.
+    //   int64_t *ArgSizes = nullptr;  // Size of the argument data in bytes.
+    //   int64_t *ArgTypes = nullptr;  // Type of the data (e.g. to / from).
+    //   void **ArgNames = nullptr;    // Name of the data for debugging, possibly null.
+    //   void **ArgMappers = nullptr;  // User-defined mappers, possibly null.
+    //   uint64_t Tripcount =
+    //       0;                 // Tripcount for the teams / distribute loop, 0 otherwise.
+    //   struct {
+    //     uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
+    //     uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
+    //     uint64_t Unused : 62;
+    //   } Flags = {0, 0, 0};
+    //   // The number of teams (for x,y,z dimension).
+    //   uint32_t NumTeams[3] = {0, 0, 0};
+    //   // The number of threads (for x,y,z dimension).
+    //   uint32_t ThreadLimit[3] = {0, 0, 0};
+    //   uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
+    // };
     let kernel_elements =
         vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
 
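(Again a sketch rather than patch content: lowering kernel_elements above should yield a struct type
along the lines below. The name __tgt_kernel_arguments is the one clang uses for this layout and is
assumed here, since the type_named_struct call is outside this hunk.)

%struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }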
@@ -180,7 +197,7 @@ fn gen_define_handling<'ll>(
 
     // We do not know their size anymore at this level, so hardcode a placeholder.
     // A follow-up PR will track these from the frontend, where we still have Rust types.
-    // Then, we will be able to figure out that e.g. `&[f32;1024]` will result in 32*1024 bytes.
+    // Then, we will be able to figure out that e.g. `&[f32; 256]` will result in 4*256 bytes.
     // I decided that 1024 bytes is a great placeholder value for now.
     add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
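(Sketch of the global this call emits, assuming a kernel with three pointer arguments and num = 0;
the element count, the suffix and the exact attributes are assumptions, only the 1024 placeholder
comes from the call above.)

@.offload_sizes.0 = private unnamed_addr constant [3 x i64] [i64 1024, i64 1024, i64 1024]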
     // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
@@ -285,135 +302,139 @@ fn gen_call_handling<'ll>(
     let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
 
     let main_fn = cx.get_function("main");
-    if let Some(main_fn) = main_fn {
-        let kernel_name = "kernel_1";
-        let call = unsafe {
-            llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
-        };
-        let kernel_call = if call.is_some() {
-            call.unwrap()
-        } else {
-            return;
-        };
-        let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
-        let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
-        let mut builder = SBuilder::build(cx, kernel_call_bb);
-
-        let types = cx.func_params_types(cx.get_type_of_global(called));
-        let num_args = types.len() as u64;
-
-        // Step 0)
-        // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
-        // %6 = alloca %struct.__tgt_bin_desc, align 8
-        unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
-
-        let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
-
-        let ty = cx.type_array(cx.type_ptr(), num_args);
-        // Baseptr are just the input pointer to the kernel, stored in a local alloca
-        let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
-        // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
-        let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
-        // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
-        let ty2 = cx.type_array(cx.type_i64(), num_args);
-        let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
-        // Now we allocate once per function param, a copy to be passed to one of our maps.
-        let mut vals = vec![];
-        let mut geps = vec![];
-        let i32_0 = cx.get_const_i32(0);
-        for (index, in_ty) in types.iter().enumerate() {
-            // get function arg, store it into the alloca, and read it.
-            let p = llvm::get_param(called, index as u32);
-            let name = llvm::get_value_name(p);
-            let name = str::from_utf8(name).unwrap();
-            let arg_name = CString::new(format!("{name}.addr")).unwrap();
-            let alloca =
-                unsafe { llvm::LLVMBuildAlloca(builder.llbuilder, in_ty, arg_name.as_ptr()) };
-            builder.store(p, alloca, Align::EIGHT);
-            let val = builder.load(in_ty, alloca, Align::EIGHT);
-            let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
-            vals.push(val);
-            geps.push(gep);
-        }
-
-        // Step 1)
-        unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
-        builder.memset(
-            tgt_bin_desc_alloca,
-            cx.get_const_i8(0),
-            cx.get_const_i64(32),
-            Align::from_bytes(8).unwrap(),
-        );
-
-        let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
-        let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
-        let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
-        let init_ty = cx.type_func(&[], cx.type_void());
-        let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
-
-        // call void @__tgt_register_lib(ptr noundef %6)
-        builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
-        // call void @__tgt_init_all_rtls()
-        builder.call(init_ty, init_rtls_decl, &[], None);
-
-        for i in 0..num_args {
-            let idx = cx.get_const_i32(i);
-            let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
-            builder.store(vals[i as usize], gep1, Align::EIGHT);
-            let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
-            builder.store(geps[i as usize], gep2, Align::EIGHT);
-            let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
-            builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
-        }
+    let Some(main_fn) = main_fn else { return };
+    let kernel_name = "kernel_1";
+    let call = unsafe {
+        llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
+    };
+    let Some(kernel_call) = call else {
+        return;
+    };
+    let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
+    let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
+    let mut builder = SBuilder::build(cx, kernel_call_bb);
+
+    let types = cx.func_params_types(cx.get_type_of_global(called));
+    let num_args = types.len() as u64;
+
+    // Step 0)
+    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
+    // %6 = alloca %struct.__tgt_bin_desc, align 8
+    unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
+
+    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
+
+    let ty = cx.type_array(cx.type_ptr(), num_args);
+    // Baseptr are just the input pointer to the kernel, stored in a local alloca
+    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
+    // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
+    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
+    // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
+    let ty2 = cx.type_array(cx.type_i64(), num_args);
+    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
+    // Now we allocate once per function param, a copy to be passed to one of our maps.
+    let mut vals = vec![];
+    let mut geps = vec![];
+    let i32_0 = cx.get_const_i32(0);
+    for (index, in_ty) in types.iter().enumerate() {
+        // get function arg, store it into the alloca, and read it.
+        let p = llvm::get_param(called, index as u32);
+        let name = llvm::get_value_name(p);
+        let name = str::from_utf8(name).unwrap();
+        let arg_name = format!("{name}.addr");
+        let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
+
+        builder.store(p, alloca, Align::EIGHT);
+        let val = builder.load(in_ty, alloca, Align::EIGHT);
+        let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
+        vals.push(val);
+        geps.push(gep);
+    }
 
-        // Step 2)
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let s_ident_t = generate_at_one(&cx);
-        let args = vec![
-            s_ident_t,
-            cx.get_const_i64(u64::MAX),
-            cx.get_const_i32(num_args),
-            gep1,
-            gep2,
-            gep3,
-            o_type,
-            nullptr,
-            nullptr,
-        ];
-        builder.call(fn_ty, begin_mapper_decl, &args, None);
-
-        // Step 4)
-        unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
-
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let args = vec![
-            s_ident_t,
-            cx.get_const_i64(u64::MAX),
-            cx.get_const_i32(num_args),
-            gep1,
-            gep2,
-            gep3,
-            o_type,
-            nullptr,
-            nullptr,
-        ];
-        builder.call(fn_ty, end_mapper_decl, &args, None);
-        builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
-
-        // With this we generated the following begin and end mappers. We could easily generate the
-        // update mapper in an update.
-        // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-        // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-        // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
+    // Step 1)
+    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
+    builder.memset(
+        tgt_bin_desc_alloca,
+        cx.get_const_i8(0),
+        cx.get_const_i64(32),
+        Align::from_bytes(8).unwrap(),
+    );
+
+    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
+    let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
+    let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
+    let init_ty = cx.type_func(&[], cx.type_void());
+    let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
+
+    // call void @__tgt_register_lib(ptr noundef %6)
+    builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
+    // call void @__tgt_init_all_rtls()
+    builder.call(init_ty, init_rtls_decl, &[], None);
+
+    for i in 0..num_args {
+        let idx = cx.get_const_i32(i);
+        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
+        builder.store(vals[i as usize], gep1, Align::EIGHT);
+        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
+        builder.store(geps[i as usize], gep2, Align::EIGHT);
+        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
+        // As mentioned above, we don't use Rust type information yet. So for now we just
+        // assume 1024 bytes per argument, i.e. 256 f32 values.
+        // FIXME(offload): write an offload frontend and handle arbitrary types.
+        builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
     }
+
+    // Step 2)
+    let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
+    let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
+    let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
+
+    let nullptr = cx.const_null(cx.type_ptr());
+    let o_type = o_types[0];
+    let s_ident_t = generate_at_one(&cx);
+    let args = vec![
+        s_ident_t,
+        cx.get_const_i64(u64::MAX),
+        cx.get_const_i32(num_args),
+        gep1,
+        gep2,
+        gep3,
+        o_type,
+        nullptr,
+        nullptr,
+    ];
+    builder.call(fn_ty, begin_mapper_decl, &args, None);
+
+    // Step 3)
+    // Here we will add code for the actual kernel launches in a follow-up PR.
+    // FIXME(offload): launch kernels
+
+    // Step 4)
+    unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
+
+    let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
+    let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
+    let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
+
+    let nullptr = cx.const_null(cx.type_ptr());
+    let o_type = o_types[0];
+    let args = vec![
+        s_ident_t,
+        cx.get_const_i64(u64::MAX),
+        cx.get_const_i32(num_args),
+        gep1,
+        gep2,
+        gep3,
+        o_type,
+        nullptr,
+        nullptr,
+    ];
+    builder.call(fn_ty, end_mapper_decl, &args, None);
+    builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
+
+    // With this we generated the begin and end mapper calls shown below. We could generate the
+    // update mapper call in the same way when needed.
+    // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
+    // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
+    // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
 }