@@ -44,7 +44,7 @@ use crate::errors::{
4444use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
4545use crate :: llvm:: { self , DiagnosticInfo } ;
4646use crate :: type_:: Type ;
47- use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
47+ use crate :: { LlvmCodegenBackend , ModuleLlvm , SimpleCx , base, common, llvm_util} ;
4848
4949pub ( crate ) fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> FatalError {
5050 match llvm:: last_error ( ) {
@@ -653,83 +653,72 @@ pub(crate) unsafe fn llvm_optimize(
653653 None
654654 } ;
655655
656- fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
657- unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
658- //unsafe {llvm::LLVMDumpModule(m);}
659- //unsafe {
660- // // Get the old function type
661- // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
662- // dbg!(&old_fn_ty);
663- // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
664- // dbg!(&old_param_count);
665-
666- // // Get the old parameter types
667- // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
668- // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
669- // old_param_types.set_len(old_param_count as usize);
670-
671- // // Create the new parameter list, with ptr as the first argument
672- // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
673- // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
674- // new_param_types.push(ptr_ty);
675- // for old_param in old_param_types {
676- // new_param_types.push(old_param);
677- // }
678- // dbg!(&new_param_types);
679-
680- // // Create the new function type
681- // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
682- // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
683- // dbg!(&new_fn_ty);
684-
685- // // Create the new function
686- // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
687- // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
688- // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
689- // let new_fn_cstr = CString::new(new_fn_name).unwrap();
690- // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
691- // dbg!(&new_fn);
692- // let a0 = llvm::LLVMGetParam(new_fn, 0);
693- // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
694- // dbg!(&new_fn);
695-
696- // // Move basic blocks
697- // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
698- // //dbg!(&bb);
699- // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
700- // //while !bb.is_null() {
701- // // let next = llvm::LLVMGetNextBasicBlock(bb);
702- // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
703- // // bb = next;
704- // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
705- // let old_n = llvm::LLVMCountParams(old_fn);
706- // for i in 0..old_n {
707- // let old_arg = llvm::LLVMGetParam(old_fn, i);
708- // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
709- // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
710- // }
711-
712- // // Copy linkage and visibility
713- // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
714- // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
715-
716- // // Replace all uses of old_fn with new_fn (RAUW)
717- // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
718-
719- // // Optionally, remove the old function
720- // llvm::LLVMDeleteFunction(old_fn);
721- //}
656+ fn handle_offload < ' ll > ( cx : & ' ll SimpleCx < ' _ > , old_fn : & llvm:: Value ) {
657+ {
658+ let old_fn_ty = cx. get_type_of_global ( old_fn) ;
659+ let old_param_types = cx. func_params_types ( old_fn_ty) ;
660+ let old_param_count = old_param_types. len ( ) ;
661+ if old_param_count == 0 {
662+ return ;
663+ }
664+
665+ let first_param = llvm:: get_param ( old_fn, 0 ) ;
666+ let c_name = llvm:: get_value_name ( first_param) ;
667+ let first_arg_name = str:: from_utf8 ( & c_name) . unwrap ( ) ;
668+ // We might call llvm_optimize (and thus this code) multiple times on the same IR,
669+ // but we shouldn't add this helper ptr multiple times.
670+ if first_arg_name == "dyn_ptr" {
671+ return ;
672+ }
673+
674+ // Create the new parameter list, with ptr as the first argument
675+ let mut new_param_types = Vec :: with_capacity ( old_param_count as usize + 1 ) ;
676+ new_param_types. push ( cx. type_ptr ( ) ) ;
677+ for old_param in old_param_types {
678+ new_param_types. push ( old_param) ;
679+ }
680+
681+ // Create the new function type
682+ let ret_ty = unsafe { llvm:: LLVMGetReturnType ( old_fn_ty) } ;
683+ let new_fn_ty = cx. type_func ( & new_param_types, ret_ty) ;
684+
685+ // Create the new function, with a temporary .offload name to avoid a name collision.
686+ let old_fn_name = String :: from_utf8 ( llvm:: get_value_name ( old_fn) ) . unwrap ( ) ;
687+ let new_fn_name = format ! ( "{}.offload" , & old_fn_name) ;
688+ let new_fn = cx. add_func ( & new_fn_name, new_fn_ty) ;
689+ let a0 = llvm:: get_param ( new_fn, 0 ) ;
690+ llvm:: set_value_name ( a0, CString :: new ( "dyn_ptr" ) . unwrap ( ) . as_bytes ( ) ) ;
691+
692+ // Here we map the old arguments to the new arguments, with an offset of 1 to make sure
693+ // that we don't use the newly added `%dyn_ptr`.
694+ unsafe {
695+ llvm:: LLVMRustOffloadMapper ( cx. llmod ( ) , old_fn, new_fn) ;
696+ }
697+
698+ llvm:: set_linkage ( new_fn, llvm:: get_linkage ( old_fn) ) ;
699+ llvm:: set_visibility ( new_fn, llvm:: get_visibility ( old_fn) ) ;
700+
701+ // Replace all uses of old_fn with new_fn (RAUW)
702+ unsafe {
703+ llvm:: LLVMReplaceAllUsesWith ( old_fn, new_fn) ;
704+ }
705+ let name = llvm:: get_value_name ( old_fn) ;
706+ unsafe {
707+ llvm:: LLVMDeleteFunction ( old_fn) ;
708+ }
709+ // Now we can re-use the old name, without name collision.
710+ llvm:: set_value_name ( new_fn, & name) ;
711+ }
722712 }
723713
724714 let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
725715 if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
716+ let cx =
717+ SimpleCx :: new ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , cgcx. pointer_size ) ;
726718 for num in 0 ..9 {
727719 let name = format ! ( "kernel_{num}" ) ;
728- let c_name = CString :: new ( name) . unwrap ( ) ;
729- if let Some ( kernel) =
730- unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
731- {
732- handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
720+ if let Some ( kernel) = cx. get_function ( & name) {
721+ handle_offload ( & cx, kernel) ;
733722 }
734723 }
735724 }
0 commit comments