@@ -45,7 +45,7 @@ use crate::errors::{
4545use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
4646use crate :: llvm:: { self , DiagnosticInfo } ;
4747use crate :: type_:: llvm_type_ptr;
48- use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
48+ use crate :: { LlvmCodegenBackend , ModuleLlvm , SimpleCx , base, common, llvm_util} ;
4949
5050pub ( crate ) fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> ! {
5151 match llvm:: last_error ( ) {
@@ -658,19 +658,70 @@ pub(crate) unsafe fn llvm_optimize(
658658 None
659659 } ;
660660
661- fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
662- unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
661+ fn handle_offload < ' ll > ( cx : & ' ll SimpleCx < ' _ > , old_fn : & llvm:: Value ) {
662+ let old_fn_ty = cx. get_type_of_global ( old_fn) ;
663+ let old_param_types = cx. func_params_types ( old_fn_ty) ;
664+ let old_param_count = old_param_types. len ( ) ;
665+ if old_param_count == 0 {
666+ return ;
667+ }
668+
669+ let first_param = llvm:: get_param ( old_fn, 0 ) ;
670+ let c_name = llvm:: get_value_name ( first_param) ;
671+ let first_arg_name = str:: from_utf8 ( & c_name) . unwrap ( ) ;
672+ // We might call llvm_optimize (and thus this code) multiple times on the same IR,
673+ // but we shouldn't add this helper ptr multiple times.
674+ if first_arg_name == "dyn_ptr" {
675+ return ;
676+ }
677+
678+ // Create the new parameter list, with ptr as the first argument
679+ let mut new_param_types = Vec :: with_capacity ( old_param_count as usize + 1 ) ;
680+ new_param_types. push ( cx. type_ptr ( ) ) ;
681+ for old_param in old_param_types {
682+ new_param_types. push ( old_param) ;
683+ }
684+
685+ // Create the new function type
686+ let ret_ty = unsafe { llvm:: LLVMGetReturnType ( old_fn_ty) } ;
687+ let new_fn_ty = cx. type_func ( & new_param_types, ret_ty) ;
688+
689+ // Create the new function, with a temporary .offload name to avoid a name collision.
690+ let old_fn_name = String :: from_utf8 ( llvm:: get_value_name ( old_fn) ) . unwrap ( ) ;
691+ let new_fn_name = format ! ( "{}.offload" , & old_fn_name) ;
692+ let new_fn = cx. add_func ( & new_fn_name, new_fn_ty) ;
693+ let a0 = llvm:: get_param ( new_fn, 0 ) ;
694+ llvm:: set_value_name ( a0, CString :: new ( "dyn_ptr" ) . unwrap ( ) . as_bytes ( ) ) ;
695+
696+ // Here we map the old arguments to the new arguments, with an offset of 1 to make sure
697+ // that we don't use the newly added `%dyn_ptr`.
698+ unsafe {
699+ llvm:: LLVMRustOffloadMapper ( cx. llmod ( ) , old_fn, new_fn) ;
700+ }
701+
702+ llvm:: set_linkage ( new_fn, llvm:: get_linkage ( old_fn) ) ;
703+ llvm:: set_visibility ( new_fn, llvm:: get_visibility ( old_fn) ) ;
704+
705+ // Replace all uses of old_fn with new_fn (RAUW)
706+ unsafe {
707+ llvm:: LLVMReplaceAllUsesWith ( old_fn, new_fn) ;
708+ }
709+ let name = llvm:: get_value_name ( old_fn) ;
710+ unsafe {
711+ llvm:: LLVMDeleteFunction ( old_fn) ;
712+ }
713+ // Now we can re-use the old name, without name collision.
714+ llvm:: set_value_name ( new_fn, & name) ;
663715 }
664716
665717 let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
666718 if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
719+ let cx =
720+ SimpleCx :: new ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , cgcx. pointer_size ) ;
667721 for num in 0 ..9 {
668722 let name = format ! ( "kernel_{num}" ) ;
669- let c_name = CString :: new ( name) . unwrap ( ) ;
670- if let Some ( kernel) =
671- unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
672- {
673- handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
723+ if let Some ( kernel) = cx. get_function ( & name) {
724+ handle_offload ( & cx, kernel) ;
674725 }
675726 }
676727 }
0 commit comments