@@ -851,10 +851,10 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
851851 fnc : & Value ,
852852 input_diffactivity : Vec < DiffActivity > ,
853853 ret_diffactivity : DiffActivity ,
854- mut ret_primary_ret : bool ,
855854 input_tts : Vec < TypeTree > ,
856855 output_tt : TypeTree ,
857856) -> & Value {
857+ let mut ret_primary_ret = false ;
858858 let ret_activity = cdiffe_from ( ret_diffactivity) ;
859859 assert ! ( ret_activity != CDIFFE_TYPE :: DFT_OUT_DIFF ) ;
860860 let mut input_activity: Vec < CDIFFE_TYPE > = vec ! [ ] ;
@@ -925,29 +925,22 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
925925 fnc : & Value ,
926926 input_activity : Vec < DiffActivity > ,
927927 ret_activity : DiffActivity ,
928- mut ret_primary_ret : bool ,
929- diff_primary_ret : bool ,
930928 input_tts : Vec < TypeTree > ,
931929 output_tt : TypeTree ,
932930) -> & Value {
933- let ret_activity = cdiffe_from ( ret_activity) ;
934- assert ! ( ret_activity == CDIFFE_TYPE :: DFT_CONSTANT || ret_activity == CDIFFE_TYPE :: DFT_OUT_DIFF ) ;
931+ let ( primary_ret, diff_ret, ret_activity) = match ret_activity {
932+ DiffActivity :: Const => ( true , false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
933+ DiffActivity :: Active => ( true , true , CDIFFE_TYPE :: DFT_DUP_ARG ) ,
934+ DiffActivity :: ActiveOnly => ( false , true , CDIFFE_TYPE :: DFT_DUP_NONEED ) ,
935+ DiffActivity :: None => ( false , false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
936+ _ => panic ! ( "Invalid return activity" ) ,
937+ } ;
938+
939+ //assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF);
935940 let input_activity: Vec < CDIFFE_TYPE > = input_activity. iter ( ) . map ( |& x| cdiffe_from ( x) ) . collect ( ) ;
936941
937942 dbg ! ( & fnc) ;
938943
939- if ret_activity == CDIFFE_TYPE :: DFT_DUP_ARG {
940- if ret_primary_ret != true {
941- dbg ! ( "overwriting ret_primary_ret!" ) ;
942- }
943- ret_primary_ret = true ;
944- } else if ret_activity == CDIFFE_TYPE :: DFT_DUP_NONEED {
945- if ret_primary_ret != false {
946- dbg ! ( "overwriting ret_primary_ret!" ) ;
947- }
948- ret_primary_ret = false ;
949- }
950-
951944 let mut args_tree = input_tts. iter ( ) . map ( |x| x. inner ) . collect :: < Vec < _ > > ( ) ;
952945
953946 // We don't support volatile / extern / (global?) values.
@@ -977,8 +970,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
977970 input_activity. as_ptr ( ) ,
978971 input_activity. len ( ) , // constant arguments
979972 type_analysis, // type analysis struct
980- ret_primary_ret as u8 ,
981- diff_primary_ret as u8 , //0
973+ primary_ret as u8 ,
974+ diff_ret as u8 , //0
982975 CDerivativeMode :: DEM_ReverseModeCombined , // return value, dret_used, top_level which was 1
983976 1 , // vector mode width
984977 1 , // free memory
@@ -2704,12 +2697,13 @@ pub enum CDIFFE_TYPE {
27042697fn cdiffe_from ( act : DiffActivity ) -> CDIFFE_TYPE {
27052698 return match act {
27062699 DiffActivity :: None => CDIFFE_TYPE :: DFT_CONSTANT ,
2707- DiffActivity :: Active => CDIFFE_TYPE :: DFT_OUT_DIFF ,
27082700 DiffActivity :: Const => CDIFFE_TYPE :: DFT_CONSTANT ,
2701+ DiffActivity :: Active => CDIFFE_TYPE :: DFT_OUT_DIFF ,
2702+ DiffActivity :: ActiveOnly => CDIFFE_TYPE :: DFT_OUT_DIFF ,
27092703 DiffActivity :: Dual => CDIFFE_TYPE :: DFT_DUP_ARG ,
2710- DiffActivity :: DualNoNeed => CDIFFE_TYPE :: DFT_DUP_NONEED ,
2704+ DiffActivity :: DualOnly => CDIFFE_TYPE :: DFT_DUP_NONEED ,
27112705 DiffActivity :: Duplicated => CDIFFE_TYPE :: DFT_DUP_ARG ,
2712- DiffActivity :: DuplicatedNoNeed => CDIFFE_TYPE :: DFT_DUP_NONEED ,
2706+ DiffActivity :: DuplicatedOnly => CDIFFE_TYPE :: DFT_DUP_NONEED ,
27132707 } ;
27142708}
27152709
0 commit comments