@@ -658,6 +658,87 @@ pub(crate) unsafe fn llvm_optimize(
658658 None
659659 } ;
660660
661+ fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
662+ unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
663+ //unsafe {llvm::LLVMDumpModule(m);}
664+ //unsafe {
665+ // // Get the old function type
666+ // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
667+ // dbg!(&old_fn_ty);
668+ // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
669+ // dbg!(&old_param_count);
670+
671+ // // Get the old parameter types
672+ // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
673+ // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
674+ // old_param_types.set_len(old_param_count as usize);
675+
676+ // // Create the new parameter list, with ptr as the first argument
677+ // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
678+ // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
679+ // new_param_types.push(ptr_ty);
680+ // for old_param in old_param_types {
681+ // new_param_types.push(old_param);
682+ // }
683+ // dbg!(&new_param_types);
684+
685+ // // Create the new function type
686+ // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
687+ // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
688+ // dbg!(&new_fn_ty);
689+
690+ // // Create the new function
691+ // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
692+ // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
693+ // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
694+ // let new_fn_cstr = CString::new(new_fn_name).unwrap();
695+ // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
696+ // dbg!(&new_fn);
697+ // let a0 = llvm::LLVMGetParam(new_fn, 0);
698+ // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
699+ // dbg!(&new_fn);
700+
701+ // // Move basic blocks
702+ // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
703+ // //dbg!(&bb);
704+ // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
705+ // //while !bb.is_null() {
706+ // // let next = llvm::LLVMGetNextBasicBlock(bb);
707+ // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
708+ // // bb = next;
709+ // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
710+ // let old_n = llvm::LLVMCountParams(old_fn);
711+ // for i in 0..old_n {
712+ // let old_arg = llvm::LLVMGetParam(old_fn, i);
713+ // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
714+ // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
715+ // }
716+
717+ // // Copy linkage and visibility
718+ // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
719+ // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
720+
721+ // // Replace all uses of old_fn with new_fn (RAUW)
722+ // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
723+
724+ // // Optionally, remove the old function
725+ // llvm::LLVMDeleteFunction(old_fn);
726+ //}
727+ }
728+
729+ let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
730+ if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
731+ for num in 0 ..9 {
732+ let name = format ! ( "kernel_{num}" ) ;
733+ let c_name = CString :: new ( name) . unwrap ( ) ;
734+ if let Some ( kernel) =
735+ unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
736+ {
737+ handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
738+ }
739+ }
740+ }
741+
661742 let mut llvm_profiler = cgcx
662743 . prof
663744 . llvm_recording_enabled ( )
0 commit comments