@@ -653,6 +653,84 @@ pub(crate) unsafe fn llvm_optimize(
653653 None
654654 } ;
655655
656+ fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
657+ dbg ! ( & old_fn) ;
658+ unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
659+ //unsafe {llvm::LLVMDumpModule(m);}
660+ //unsafe {
661+ // // Get the old function type
662+ // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
663+ // dbg!(&old_fn_ty);
664+ // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
665+ // dbg!(&old_param_count);
666+
667+ // // Get the old parameter types
668+ // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
669+ // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
670+ // old_param_types.set_len(old_param_count as usize);
671+
672+ // // Create the new parameter list, with ptr as the first argument
673+ // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
674+ // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
675+ // new_param_types.push(ptr_ty);
676+ // for old_param in old_param_types {
677+ // new_param_types.push(old_param);
678+ // }
679+ // dbg!(&new_param_types);
680+
681+ // // Create the new function type
682+ // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
683+ // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
684+ // dbg!(&new_fn_ty);
685+
686+ // // Create the new function
687+ // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
688+ // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
689+ // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
690+ // let new_fn_cstr = CString::new(new_fn_name).unwrap();
691+ // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
692+ // dbg!(&new_fn);
693+ // let a0 = llvm::LLVMGetParam(new_fn, 0);
694+ // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
695+ // dbg!(&new_fn);
696+
697+ // // Move basic blocks
698+ // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
699+ // //dbg!(&bb);
700+ // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
701+ // //while !bb.is_null() {
702+ // // let next = llvm::LLVMGetNextBasicBlock(bb);
703+ // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
704+ // // bb = next;
705+ // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
706+ // let old_n = llvm::LLVMCountParams(old_fn);
707+ // for i in 0..old_n {
708+ // let old_arg = llvm::LLVMGetParam(old_fn, i);
709+ // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
710+ // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
711+ // }
712+
713+ // // Copy linkage and visibility
714+ // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
715+ // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
716+
717+ // // Replace all uses of old_fn with new_fn (RAUW)
718+ // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
719+
720+ // // Optionally, remove the old function
721+ // llvm::LLVMDeleteFunction(old_fn);
722+ //}
723+ }
724+
725+ for num in 0 ..9 {
726+ let name = format ! ( "kernel_{num}" ) ;
727+ let c_name = CString :: new ( name) . unwrap ( ) ;
728+ if let Some ( kernel) = unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) } {
729+ dbg ! ( "found offload kernel asfd" ) ;
730+ handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
731+ }
732+ }
733+
656734 let mut llvm_profiler = cgcx
657735 . prof
658736 . llvm_recording_enabled ( )
0 commit comments