@@ -43,7 +43,7 @@ use crate::errors::{
4343use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
4444use crate :: llvm:: { self , DiagnosticInfo } ;
4545use crate :: type_:: llvm_type_ptr;
46- use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
46+ use crate :: { LlvmCodegenBackend , ModuleLlvm , SimpleCx , base, common, llvm_util} ;
4747
4848pub ( crate ) fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> ! {
4949 match llvm:: last_error ( ) {
@@ -645,19 +645,70 @@ pub(crate) unsafe fn llvm_optimize(
645645 None
646646 } ;
647647
648- fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
649- unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
648+ fn handle_offload < ' ll > ( cx : & ' ll SimpleCx < ' _ > , old_fn : & llvm:: Value ) {
649+ let old_fn_ty = cx. get_type_of_global ( old_fn) ;
650+ let old_param_types = cx. func_params_types ( old_fn_ty) ;
651+ let old_param_count = old_param_types. len ( ) ;
652+ if old_param_count == 0 {
653+ return ;
654+ }
655+
656+ let first_param = llvm:: get_param ( old_fn, 0 ) ;
657+ let c_name = llvm:: get_value_name ( first_param) ;
658+ let first_arg_name = str:: from_utf8 ( & c_name) . unwrap ( ) ;
659+ // We might call llvm_optimize (and thus this code) multiple times on the same IR,
660+ // but we shouldn't add this helper ptr multiple times.
661+ if first_arg_name == "dyn_ptr" {
662+ return ;
663+ }
664+
665+ // Create the new parameter list, with ptr as the first argument
666+ let mut new_param_types = Vec :: with_capacity ( old_param_count as usize + 1 ) ;
667+ new_param_types. push ( cx. type_ptr ( ) ) ;
668+ for old_param in old_param_types {
669+ new_param_types. push ( old_param) ;
670+ }
671+
672+ // Create the new function type
673+ let ret_ty = unsafe { llvm:: LLVMGetReturnType ( old_fn_ty) } ;
674+ let new_fn_ty = cx. type_func ( & new_param_types, ret_ty) ;
675+
676+ // Create the new function, with a temporary .offload name to avoid a name collision.
677+ let old_fn_name = String :: from_utf8 ( llvm:: get_value_name ( old_fn) ) . unwrap ( ) ;
678+ let new_fn_name = format ! ( "{}.offload" , & old_fn_name) ;
679+ let new_fn = cx. add_func ( & new_fn_name, new_fn_ty) ;
680+ let a0 = llvm:: get_param ( new_fn, 0 ) ;
681+ llvm:: set_value_name ( a0, CString :: new ( "dyn_ptr" ) . unwrap ( ) . as_bytes ( ) ) ;
682+
683+ // Here we map the old arguments to the new arguments, with an offset of 1 to make sure
684+ // that we don't use the newly added `%dyn_ptr`.
685+ unsafe {
686+ llvm:: LLVMRustOffloadMapper ( cx. llmod ( ) , old_fn, new_fn) ;
687+ }
688+
689+ llvm:: set_linkage ( new_fn, llvm:: get_linkage ( old_fn) ) ;
690+ llvm:: set_visibility ( new_fn, llvm:: get_visibility ( old_fn) ) ;
691+
692+ // Replace all uses of old_fn with new_fn (RAUW)
693+ unsafe {
694+ llvm:: LLVMReplaceAllUsesWith ( old_fn, new_fn) ;
695+ }
696+ let name = llvm:: get_value_name ( old_fn) ;
697+ unsafe {
698+ llvm:: LLVMDeleteFunction ( old_fn) ;
699+ }
700+ // Now we can re-use the old name, without name collision.
701+ llvm:: set_value_name ( new_fn, & name) ;
650702 }
651703
652704 let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
653705 if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
706+ let cx =
707+ SimpleCx :: new ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , cgcx. pointer_size ) ;
654708 for num in 0 ..9 {
655709 let name = format ! ( "kernel_{num}" ) ;
656- let c_name = CString :: new ( name) . unwrap ( ) ;
657- if let Some ( kernel) =
658- unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
659- {
660- handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
710+ if let Some ( kernel) = cx. get_function ( & name) {
711+ handle_offload ( & cx, kernel) ;
661712 }
662713 }
663714 }
0 commit comments