@@ -698,6 +698,7 @@ pub(crate) unsafe fn extract_return_type<'a>(
698698pub ( crate ) unsafe fn enzyme_ad (
699699 llmod : & llvm:: Module ,
700700 llcx : & llvm:: Context ,
701+ diag_handler : & rustc_errors:: Handler ,
701702 item : AutoDiffItem ,
702703) -> Result < ( ) , FatalError > {
703704 let autodiff_mode = item. attrs . mode ;
@@ -710,8 +711,28 @@ pub(crate) unsafe fn enzyme_ad(
710711 // get target and source function
711712 let name = CString :: new ( rust_name. to_owned ( ) ) . unwrap ( ) ;
712713 let name2 = CString :: new ( rust_name2. clone ( ) ) . unwrap ( ) ;
713- let src_fnc = llvm:: LLVMGetNamedFunction ( llmod, name. as_c_str ( ) . as_ptr ( ) ) . unwrap ( ) ;
714- let target_fnc = llvm:: LLVMGetNamedFunction ( llmod, name2. as_ptr ( ) ) . unwrap ( ) ;
714+ let src_fnc_opt = llvm:: LLVMGetNamedFunction ( llmod, name. as_c_str ( ) . as_ptr ( ) ) ;
715+ let src_fnc = match src_fnc_opt {
716+ Some ( x) => x,
717+ None => {
718+ return Err ( llvm_err ( diag_handler, LlvmError :: PrepareAutoDiff {
719+ src : rust_name. to_owned ( ) ,
720+ target : rust_name2. to_owned ( ) ,
721+ error : "could not find src function" . to_owned ( ) ,
722+ } ) ) ;
723+ }
724+ } ;
725+ let target_fnc_opt = llvm:: LLVMGetNamedFunction ( llmod, name2. as_ptr ( ) ) ;
726+ let target_fnc = match target_fnc_opt {
727+ Some ( x) => x,
728+ None => {
729+ return Err ( llvm_err ( diag_handler, LlvmError :: PrepareAutoDiff {
730+ src : rust_name. to_owned ( ) ,
731+ target : rust_name2. to_owned ( ) ,
732+ error : "could not find target function" . to_owned ( ) ,
733+ } ) ) ;
734+ }
735+ } ;
715736 let src_num_args = llvm:: LLVMCountParams ( src_fnc) ;
716737 let target_num_args = llvm:: LLVMCountParams ( target_fnc) ;
717738 assert ! ( src_num_args <= target_num_args) ;
@@ -791,13 +812,14 @@ pub(crate) unsafe fn enzyme_ad(
791812
792813pub ( crate ) unsafe fn differentiate (
793814 module : & ModuleCodegen < ModuleLlvm > ,
794- _cgcx : & CodegenContext < LlvmCodegenBackend > ,
815+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
795816 diff_items : Vec < AutoDiffItem > ,
796817 _typetrees : FxHashMap < String , DiffTypeTree > ,
797818 _config : & ModuleConfig ,
798819) -> Result < ( ) , FatalError > {
799820 let llmod = module. module_llvm . llmod ( ) ;
800821 let llcx = & module. module_llvm . llcx ;
822+ let diag_handler = cgcx. create_diag_handler ( ) ;
801823
802824 llvm:: EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm:: EnzymeStrictAliasing ) , 0 ) ;
803825
@@ -818,7 +840,7 @@ pub(crate) unsafe fn differentiate(
818840 }
819841
820842 for item in diff_items {
821- let res = enzyme_ad ( llmod, llcx, item) ;
843+ let res = enzyme_ad ( llmod, llcx, & diag_handler , item) ;
822844 assert ! ( res. is_ok( ) ) ;
823845 }
824846
0 commit comments