@@ -653,6 +653,87 @@ pub(crate) unsafe fn llvm_optimize(
653653 None
654654 } ;
655655
656+ fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
657+ unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
658+ //unsafe {llvm::LLVMDumpModule(m);}
659+ //unsafe {
660+ // // Get the old function type
661+ // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
662+ // dbg!(&old_fn_ty);
663+ // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
664+ // dbg!(&old_param_count);
665+
666+ // // Get the old parameter types
667+ // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
668+ // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
669+ // old_param_types.set_len(old_param_count as usize);
670+
671+ // // Create the new parameter list, with ptr as the first argument
672+ // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
673+ // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
674+ // new_param_types.push(ptr_ty);
675+ // for old_param in old_param_types {
676+ // new_param_types.push(old_param);
677+ // }
678+ // dbg!(&new_param_types);
679+
680+ // // Create the new function type
681+ // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
682+ // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
683+ // dbg!(&new_fn_ty);
684+
685+ // // Create the new function
686+ // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
687+ // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
688+ // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
689+ // let new_fn_cstr = CString::new(new_fn_name).unwrap();
690+ // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
691+ // dbg!(&new_fn);
692+ // let a0 = llvm::LLVMGetParam(new_fn, 0);
693+ // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
694+ // dbg!(&new_fn);
695+
696+ // // Move basic blocks
697+ // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
698+ // //dbg!(&bb);
699+ // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
700+ // //while !bb.is_null() {
701+ // // let next = llvm::LLVMGetNextBasicBlock(bb);
702+ // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
703+ // // bb = next;
704+ // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
705+ // let old_n = llvm::LLVMCountParams(old_fn);
706+ // for i in 0..old_n {
707+ // let old_arg = llvm::LLVMGetParam(old_fn, i);
708+ // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
709+ // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
710+ // }
711+
712+ // // Copy linkage and visibility
713+ // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
714+ // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
715+
716+ // // Replace all uses of old_fn with new_fn (RAUW)
717+ // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
718+
719+ // // Optionally, remove the old function
720+ // llvm::LLVMDeleteFunction(old_fn);
721+ //}
722+ }
723+
724+ let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
725+ if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
726+ for num in 0 ..9 {
727+ let name = format ! ( "kernel_{num}" ) ;
728+ let c_name = CString :: new ( name) . unwrap ( ) ;
729+ if let Some ( kernel) =
730+ unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
731+ {
732+ handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
733+ }
734+ }
735+ }
736+
656737 let mut llvm_profiler = cgcx
657738 . prof
658739 . llvm_recording_enabled ( )
0 commit comments