@@ -45,7 +45,7 @@ use crate::errors::{
4545use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
4646use crate :: llvm:: { self , DiagnosticInfo } ;
4747use crate :: type_:: llvm_type_ptr;
48- use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
48+ use crate :: { LlvmCodegenBackend , ModuleLlvm , SimpleCx , base, common, llvm_util} ;
4949
5050pub ( crate ) fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> ! {
5151 match llvm:: last_error ( ) {
@@ -658,83 +658,72 @@ pub(crate) unsafe fn llvm_optimize(
658658 None
659659 } ;
660660
661- fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
662- unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
663- //unsafe {llvm::LLVMDumpModule(m);}
664- //unsafe {
665- // // Get the old function type
666- // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
667- // dbg!(&old_fn_ty);
668- // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
669- // dbg!(&old_param_count);
670-
671- // // Get the old parameter types
672- // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
673- // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
674- // old_param_types.set_len(old_param_count as usize);
675-
676- // // Create the new parameter list, with ptr as the first argument
677- // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
678- // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
679- // new_param_types.push(ptr_ty);
680- // for old_param in old_param_types {
681- // new_param_types.push(old_param);
682- // }
683- // dbg!(&new_param_types);
684-
685- // // Create the new function type
686- // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
687- // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
688- // dbg!(&new_fn_ty);
689-
690- // // Create the new function
691- // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
692- // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
693- // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
694- // let new_fn_cstr = CString::new(new_fn_name).unwrap();
695- // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
696- // dbg!(&new_fn);
697- // let a0 = llvm::LLVMGetParam(new_fn, 0);
698- // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
699- // dbg!(&new_fn);
700-
701- // // Move basic blocks
702- // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
703- // //dbg!(&bb);
704- // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
705- // //while !bb.is_null() {
706- // // let next = llvm::LLVMGetNextBasicBlock(bb);
707- // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
708- // // bb = next;
709- // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
710- // let old_n = llvm::LLVMCountParams(old_fn);
711- // for i in 0..old_n {
712- // let old_arg = llvm::LLVMGetParam(old_fn, i);
713- // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
714- // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
715- // }
716-
717- // // Copy linkage and visibility
718- // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
719- // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
720-
721- // // Replace all uses of old_fn with new_fn (RAUW)
722- // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
723-
724- // // Optionally, remove the old function
725- // llvm::LLVMDeleteFunction(old_fn);
726- //}
661+ fn handle_offload < ' ll > ( cx : & ' ll SimpleCx < ' _ > , old_fn : & llvm:: Value ) {
662+ {
663+ let old_fn_ty = cx. get_type_of_global ( old_fn) ;
664+ let old_param_types = cx. func_params_types ( old_fn_ty) ;
665+ let old_param_count = old_param_types. len ( ) ;
666+ if old_param_count == 0 {
667+ return ;
668+ }
669+
670+ let first_param = llvm:: get_param ( old_fn, 0 ) ;
671+ let c_name = llvm:: get_value_name ( first_param) ;
672+ let first_arg_name = str:: from_utf8 ( & c_name) . unwrap ( ) ;
673+ // We might call llvm_optimize (and thus this code) multiple times on the same IR,
674+ // but we shouldn't add this helper ptr multiple times.
675+ if first_arg_name == "dyn_ptr" {
676+ return ;
677+ }
678+
679+ // Create the new parameter list, with ptr as the first argument
680+ let mut new_param_types = Vec :: with_capacity ( old_param_count as usize + 1 ) ;
681+ new_param_types. push ( cx. type_ptr ( ) ) ;
682+ for old_param in old_param_types {
683+ new_param_types. push ( old_param) ;
684+ }
685+
686+ // Create the new function type
687+ let ret_ty = unsafe { llvm:: LLVMGetReturnType ( old_fn_ty) } ;
688+ let new_fn_ty = cx. type_func ( & new_param_types, ret_ty) ;
689+
690+ // Create the new function, with a temporary .offload name to avoid a name collision.
691+ let old_fn_name = String :: from_utf8 ( llvm:: get_value_name ( old_fn) ) . unwrap ( ) ;
692+ let new_fn_name = format ! ( "{}.offload" , & old_fn_name) ;
693+ let new_fn = cx. add_func ( & new_fn_name, new_fn_ty) ;
694+ let a0 = llvm:: get_param ( new_fn, 0 ) ;
695+ llvm:: set_value_name ( a0, CString :: new ( "dyn_ptr" ) . unwrap ( ) . as_bytes ( ) ) ;
696+
697+ // Here we map the old arguments to the new arguments, with an offset of 1 to make sure
698+ // that we don't use the newly added `%dyn_ptr`.
699+ unsafe {
700+ llvm:: LLVMRustOffloadMapper ( cx. llmod ( ) , old_fn, new_fn) ;
701+ }
702+
703+ llvm:: set_linkage ( new_fn, llvm:: get_linkage ( old_fn) ) ;
704+ llvm:: set_visibility ( new_fn, llvm:: get_visibility ( old_fn) ) ;
705+
706+ // Replace all uses of old_fn with new_fn (RAUW)
707+ unsafe {
708+ llvm:: LLVMReplaceAllUsesWith ( old_fn, new_fn) ;
709+ }
710+ let name = llvm:: get_value_name ( old_fn) ;
711+ unsafe {
712+ llvm:: LLVMDeleteFunction ( old_fn) ;
713+ }
714+ // Now we can re-use the old name, without name collision.
715+ llvm:: set_value_name ( new_fn, & name) ;
716+ }
727717 }
728718
729719 let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
730720 if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
721+ let cx =
722+ SimpleCx :: new ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , cgcx. pointer_size ) ;
731723 for num in 0 ..9 {
732724 let name = format ! ( "kernel_{num}" ) ;
733- let c_name = CString :: new ( name) . unwrap ( ) ;
734- if let Some ( kernel) =
735- unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
736- {
737- handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
725+ if let Some ( kernel) = cx. get_function ( & name) {
726+ handle_offload ( & cx, kernel) ;
738727 }
739728 }
740729 }
0 commit comments