@@ -69,7 +69,7 @@ use rustc_data_structures::small_c_str::SmallCStr;
6969use rustc_errors:: { DiagCtxt , FatalError , Level } ;
7070use rustc_fs_util:: { link_or_copy, path_to_c_string} ;
7171use rustc_middle:: ty:: TyCtxt ;
72- use rustc_session:: config:: { self , Lto , OutputType , Passes , SplitDwarfKind , SwitchWithOptPath } ;
72+ use rustc_session:: config:: { self , AutoDiff , Lto , OutputType , Passes , SplitDwarfKind , SwitchWithOptPath } ;
7373use rustc_session:: Session ;
7474use rustc_span:: symbol:: sym;
7575use rustc_span:: InnerSpan ;
@@ -707,7 +707,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
707707
708708
709709unsafe fn create_call < ' a > ( tgt : & ' a Value , src : & ' a Value , rev_mode : bool ,
710- llmod : & ' a llvm:: Module , llcx : & llvm:: Context , size_positions : & [ usize ] ) {
710+ llmod : & ' a llvm:: Module , llcx : & llvm:: Context , size_positions : & [ usize ] , ad : & [ AutoDiff ] ) {
711711
712712 // first, remove all calls from fnc
713713 let bb = LLVMGetFirstBasicBlock ( tgt) ;
@@ -729,12 +729,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
729729 let last_inst = LLVMRustGetLastInstruction ( bb) . unwrap ( ) ;
730730 LLVMPositionBuilderAtEnd ( builder, bb) ;
731731
732- let safety_run_checks;
733- if std:: env:: var ( "ENZYME_NO_SAFETY_CHECKS" ) . is_ok ( ) {
734- safety_run_checks = false ;
735- } else {
736- safety_run_checks = true ;
737- }
732+ let safety_run_checks = !ad. contains ( & AutoDiff :: NoSafetyChecks ) ;
738733
739734 if inner_param_num == outer_param_num {
740735 call_args = outer_args;
@@ -951,6 +946,7 @@ pub(crate) unsafe fn enzyme_ad(
951946 diag_handler : & DiagCtxt ,
952947 item : AutoDiffItem ,
953948 logic_ref : EnzymeLogicRef ,
949+ ad : & [ AutoDiff ] ,
954950) -> Result < ( ) , FatalError > {
955951 let autodiff_mode = item. attrs . mode ;
956952 let rust_name = item. source ;
@@ -1010,16 +1006,16 @@ pub(crate) unsafe fn enzyme_ad(
10101006
10111007 llvm:: set_strict_aliasing ( false ) ;
10121008
1013- if std :: env :: var ( "ENZYME_PRINT_TA" ) . is_ok ( ) {
1009+ if ad . contains ( & AutoDiff :: PrintTA ) {
10141010 llvm:: set_print_type ( true ) ;
10151011 }
1016- if std :: env :: var ( "ENZYME_PRINT_AA" ) . is_ok ( ) {
1017- llvm:: set_print_activity ( true ) ;
1012+ if ad . contains ( & AutoDiff :: PrintTA ) {
1013+ llvm:: set_print_type ( true ) ;
10181014 }
1019- if std :: env :: var ( "ENZYME_PRINT_PERF" ) . is_ok ( ) {
1015+ if ad . contains ( & AutoDiff :: PrintPerf ) {
10201016 llvm:: set_print_perf ( true ) ;
10211017 }
1022- if std :: env :: var ( "ENZYME_PRINT" ) . is_ok ( ) {
1018+ if ad . contains ( & AutoDiff :: Print ) {
10231019 llvm:: set_print ( true ) ;
10241020 }
10251021
@@ -1062,7 +1058,7 @@ pub(crate) unsafe fn enzyme_ad(
10621058 let f_return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( res) ) ;
10631059
10641060 let rev_mode = item. attrs . mode == DiffMode :: Reverse ;
1065- create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions) ;
1061+ create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions, ad ) ;
10661062 // TODO: implement drop for wrapper type?
10671063 FreeTypeAnalysis ( type_analysis) ;
10681064
@@ -1087,7 +1083,9 @@ pub(crate) unsafe fn differentiate(
10871083
10881084 llvm:: set_strict_aliasing ( false ) ;
10891085
1090- if std:: env:: var ( "ENZYME_LOOSE_TYPES" ) . is_ok ( ) {
1086+ let ad = & config. autodiff ;
1087+
1088+ if ad. contains ( & AutoDiff :: LooseTypes ) {
10911089 dbg ! ( "Setting loose types to true" ) ;
10921090 llvm:: set_loose_types ( true ) ;
10931091 }
@@ -1110,41 +1108,42 @@ pub(crate) unsafe fn differentiate(
11101108 // trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary.
11111109 // This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in
11121110 // Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions?
1113- if std :: env :: var ( "ENZYME_OPT" ) . is_ok ( ) {
1111+ if ad . contains ( & AutoDiff :: OPT ) {
11141112 dbg ! ( "Enable extra debug helper to debug Enzyme through the opt plugin" ) ;
11151113 crate :: builder:: add_opt_dbg_helper ( llmod, llcx, fn_def, item. attrs . clone ( ) , i) ;
11161114 }
11171115 }
11181116
1119- if std :: env :: var ( "ENZYME_PRINT_MOD_BEFORE" ) . is_ok ( ) || std :: env :: var ( "ENZYME_OPT" ) . is_ok ( ) {
1117+ if ad . contains ( & AutoDiff :: PrintModBefore ) || ad . contains ( & AutoDiff :: OPT ) {
11201118 unsafe {
11211119 LLVMDumpModule ( llmod) ;
11221120 }
11231121 }
11241122
1125- if std :: env :: var ( "ENZYME_INLINE" ) . is_ok ( ) {
1123+ if ad . contains ( & AutoDiff :: Inline ) {
11261124 dbg ! ( "Setting inline to true" ) ;
11271125 llvm:: set_inline ( true ) ;
11281126 }
11291127
1130- if std:: env:: var ( "ENZYME_TT_DEPTH" ) . is_ok ( ) {
1131- let depth = std:: env:: var ( "ENZYME_TT_DEPTH" ) . unwrap ( ) ;
1132- let depth = depth. parse :: < u64 > ( ) . unwrap ( ) ;
1133- assert ! ( depth >= 1 ) ;
1134- llvm:: set_max_int_offset ( depth) ;
1135- }
1136- if std:: env:: var ( "ENZYME_TT_WIDTH" ) . is_ok ( ) {
1137- let width = std:: env:: var ( "ENZYME_TT_WIDTH" ) . unwrap ( ) ;
1138- let width = width. parse :: < u64 > ( ) . unwrap ( ) ;
1139- assert ! ( width >= 1 ) ;
1140- llvm:: set_max_type_offset ( width) ;
1141- }
1142-
1143- if std:: env:: var ( "ENZYME_RUNTIME_ACTIVITY" ) . is_ok ( ) {
1128+ if ad. contains ( & AutoDiff :: RuntimeActivity ) {
11441129 dbg ! ( "Setting runtime activity check to true" ) ;
11451130 llvm:: set_runtime_activity_check ( true ) ;
11461131 }
11471132
1133+ for val in ad {
1134+ match & val {
1135+ AutoDiff :: TTDepth ( depth) => {
1136+ assert ! ( * depth >= 1 ) ;
1137+ llvm:: set_max_int_offset ( * depth) ;
1138+ }
1139+ AutoDiff :: TTWidth ( width) => {
1140+ assert ! ( * width >= 1 ) ;
1141+ llvm:: set_max_type_offset ( * width) ;
1142+ }
1143+ _ => { } ,
1144+ }
1145+ } ;
1146+
11481147 let differentiate = !diff_items. is_empty ( ) ;
11491148 let mut first_order_items: Vec < AutoDiffItem > = vec ! [ ] ;
11501149 let mut higher_order_items: Vec < AutoDiffItem > = vec ! [ ] ;
@@ -1157,29 +1156,29 @@ pub(crate) unsafe fn differentiate(
11571156 }
11581157 }
11591158
1160- let mut fnc_opt = false ;
1161- if std:: env:: var ( "ENZYME_ENABLE_FNC_OPT" ) . is_ok ( ) {
1162- dbg ! ( "Enable extra optimizations for Enzyme" ) ;
1163- fnc_opt = true ;
1164- }
1159+
1160+ let fnc_opt = ad. contains ( & AutoDiff :: EnableFncOpt ) ;
11651161
11661162 // If a function is a base for some higher order ad, always optimize
11671163 let fnc_opt_base = true ;
11681164 let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic ( fnc_opt_base as u8 ) ;
11691165
11701166 for item in first_order_items {
1171- let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref_opt) ;
1167+ let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref_opt, ad ) ;
11721168 assert ! ( res. is_ok( ) ) ;
11731169 }
11741170
11751171 // For the rest, follow the user choice on debug vs release.
11761172 // Reuse the opt one if possible for better compile time (Enzyme internal caching).
11771173 let logic_ref = match fnc_opt {
1178- true => logic_ref_opt,
1174+ true => {
1175+ dbg ! ( "Enable extra optimizations for Enzyme" ) ;
1176+ logic_ref_opt
1177+ }
11791178 false => CreateEnzymeLogic ( fnc_opt as u8 ) ,
11801179 } ;
11811180 for item in higher_order_items {
1182- let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref) ;
1181+ let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref, ad ) ;
11831182 assert ! ( res. is_ok( ) ) ;
11841183 }
11851184
@@ -1212,14 +1211,14 @@ pub(crate) unsafe fn differentiate(
12121211 break ;
12131212 }
12141213 }
1215- if std :: env :: var ( "ENZYME_PRINT_MOD_AFTER_ENZYME" ) . is_ok ( ) {
1214+ if ad . contains ( & AutoDiff :: PrintModAfterEnzyme ) {
12161215 unsafe {
12171216 LLVMDumpModule ( llmod) ;
12181217 }
12191218 }
12201219
12211220
1222- if std :: env :: var ( "ENZYME_NO_MOD_OPT_AFTER" ) . is_ok ( ) || !differentiate {
1221+ if ad . contains ( & AutoDiff :: NoModOptAfter ) || !differentiate {
12231222 trace ! ( "Skipping module optimization after automatic differentiation" ) ;
12241223 } else {
12251224 if let Some ( opt_level) = config. opt_level {
@@ -1231,18 +1230,18 @@ pub(crate) unsafe fn differentiate(
12311230 } ;
12321231 let mut first_run = false ;
12331232 dbg ! ( "Running Module Optimization after differentiation" ) ;
1234- if std :: env :: var ( "ENZYME_NO_VEC_UNROLL" ) . is_ok ( ) {
1233+ if ad . contains ( & AutoDiff :: NoVecUnroll ) {
12351234 // disables vectorization and loop unrolling
12361235 first_run = true ;
12371236 }
1238- if std :: env :: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1237+ if ad . contains ( & AutoDiff :: AltPipeline ) {
12391238 dbg ! ( "Running first postAD optimization" ) ;
12401239 first_run = true ;
12411240 }
12421241 let noop = false ;
12431242 llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run, noop) ?;
12441243 }
1245- if std :: env :: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1244+ if ad . contains ( & AutoDiff :: AltPipeline ) {
12461245 dbg ! ( "Running Second postAD optimization" ) ;
12471246 if let Some ( opt_level) = config. opt_level {
12481247 let opt_stage = match cgcx. lto {
@@ -1253,7 +1252,7 @@ pub(crate) unsafe fn differentiate(
12531252 } ;
12541253 let mut first_run = false ;
12551254 dbg ! ( "Running Module Optimization after differentiation" ) ;
1256- if std :: env :: var ( "ENZYME_NO_VEC_UNROLL" ) . is_ok ( ) {
1255+ if ad . contains ( & AutoDiff :: NoVecUnroll ) {
12571256 // enables vectorization and loop unrolling
12581257 first_run = false ;
12591258 }
@@ -1263,7 +1262,7 @@ pub(crate) unsafe fn differentiate(
12631262 }
12641263 }
12651264
1266- if std :: env :: var ( "ENZYME_PRINT_MOD_AFTER_OPTS" ) . is_ok ( ) {
1265+ if ad . contains ( & AutoDiff :: PrintModAfterOpts ) {
12671266 unsafe {
12681267 LLVMDumpModule ( llmod) ;
12691268 }
@@ -1341,15 +1340,16 @@ pub(crate) unsafe fn optimize(
13411340 _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
13421341 _ => llvm:: OptStage :: PreLinkNoLTO ,
13431342 } ;
1343+
13441344 // Second run only relevant for AD
13451345 let first_run = true ;
1346- let noop;
1347- if std :: env :: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1348- noop = true ;
1349- dbg ! ( "Skipping PreAD optimization" ) ;
1350- } else {
1351- noop = false ;
1352- }
1346+ let noop = false ;
1347+ // if ad.contains(&AutoDiff::AltPipeline ) {
1348+ // noop = true;
1349+ // dbg!("Skipping PreAD optimization");
1350+ // } else {
1351+ // noop = false;
1352+ // }
13531353 return llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop) ;
13541354 }
13551355 Ok ( ( ) )
0 commit comments