@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
44use std:: sync:: Arc ;
55use std:: { fs, slice, str} ;
66
7- use libc:: { c_char, c_int, c_void, size_t} ;
7+ use libc:: { c_char, c_int, c_uint , c_void, size_t} ;
88use llvm:: {
99 LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
1010} ;
11+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
1112use rustc_codegen_ssa:: back:: link:: ensure_removed;
1213use rustc_codegen_ssa:: back:: versioned_llvm_target;
1314use rustc_codegen_ssa:: back:: write:: {
@@ -28,7 +29,7 @@ use rustc_session::config::{
2829use rustc_span:: InnerSpan ;
2930use rustc_span:: symbol:: sym;
3031use rustc_target:: spec:: { CodeModel , RelocModel , SanitizerSet , SplitDebuginfo , TlsModel } ;
31- use tracing:: debug;
32+ use tracing:: { debug, trace } ;
3233
3334use crate :: back:: lto:: ThinBuffer ;
3435use crate :: back:: owned_target_machine:: OwnedTargetMachine ;
@@ -41,7 +42,13 @@ use crate::errors::{
4142 WithLlvmError , WriteBytecode ,
4243} ;
4344use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
44- use crate :: llvm:: { self , DiagnosticInfo , PassManager } ;
45+ use crate :: llvm:: {
46+ self , AttributeKind , DiagnosticInfo , LLVMCreateStringAttribute , LLVMGetFirstFunction ,
47+ LLVMGetNextFunction , LLVMGetStringAttributeAtIndex , LLVMIsEnumAttribute , LLVMIsStringAttribute ,
48+ LLVMRemoveStringAttributeAtIndex , LLVMRustAddEnumAttributeAtIndex ,
49+ LLVMRustAddFunctionAttributes , LLVMRustGetEnumAttributeAtIndex ,
50+ LLVMRustRemoveEnumAttributeAtIndex , PassManager ,
51+ } ;
4552use crate :: type_:: Type ;
4653use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
4754
@@ -517,9 +524,34 @@ pub(crate) unsafe fn llvm_optimize(
517524 config : & ModuleConfig ,
518525 opt_level : config:: OptLevel ,
519526 opt_stage : llvm:: OptStage ,
527+ skip_size_increasing_opts : bool ,
520528) -> Result < ( ) , FatalError > {
521- let unroll_loops =
522- opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
529+ // Enzyme:
530+ // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
531+ // source code. However, benchmarks show that optimizations increasing the code size
532+ // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
533+ // and finally re-optimize the module, now with all optimizations available.
534+ // TODO: In a future update we could figure out how to only optimize functions getting
535+ // differentiated.
536+
537+ let unroll_loops;
538+ let vectorize_slp;
539+ let vectorize_loop;
540+
541+ if skip_size_increasing_opts {
542+ unroll_loops = false ;
543+ vectorize_slp = false ;
544+ vectorize_loop = false ;
545+ } else {
546+ unroll_loops =
547+ opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
548+ vectorize_slp = config. vectorize_slp ;
549+ vectorize_loop = config. vectorize_loop ;
550+ }
551+ trace ! (
552+ "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}" ,
553+ unroll_loops, vectorize_slp, vectorize_loop
554+ ) ;
523555 let using_thin_buffers = opt_stage == llvm:: OptStage :: PreLinkThinLTO || config. bitcode_needed ( ) ;
524556 let pgo_gen_path = get_pgo_gen_path ( config) ;
525557 let pgo_use_path = get_pgo_use_path ( config) ;
@@ -583,8 +615,8 @@ pub(crate) unsafe fn llvm_optimize(
583615 using_thin_buffers,
584616 config. merge_functions ,
585617 unroll_loops,
586- config . vectorize_slp ,
587- config . vectorize_loop ,
618+ vectorize_slp,
619+ vectorize_loop,
588620 config. no_builtins ,
589621 config. emit_lifetime_markers ,
590622 sanitizer_options. as_ref ( ) ,
@@ -606,6 +638,113 @@ pub(crate) unsafe fn llvm_optimize(
606638 result. into_result ( ) . map_err ( |( ) | llvm_err ( dcx, LlvmError :: RunLlvmPasses ) )
607639}
608640
641+ pub ( crate ) fn differentiate (
642+ module : & ModuleCodegen < ModuleLlvm > ,
643+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
644+ diff_items : Vec < AutoDiffItem > ,
645+ config : & ModuleConfig ,
646+ ) -> Result < ( ) , FatalError > {
647+ for item in & diff_items {
648+ trace ! ( "{}" , item) ;
649+ }
650+
651+ let llmod = module. module_llvm . llmod ( ) ;
652+ let llcx = & module. module_llvm . llcx ;
653+ let diag_handler = cgcx. create_dcx ( ) ;
654+
655+ // Before dumping the module, we want all the tt to become part of the module.
656+ for item in diff_items. iter ( ) {
657+ let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
658+ let fn_def: Option < & llvm:: Value > =
659+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) } ;
660+ let fn_def = match fn_def {
661+ Some ( x) => x,
662+ None => {
663+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
664+ src : item. source . clone ( ) ,
665+ target : item. target . clone ( ) ,
666+ error : "could not find source function" . to_owned ( ) ,
667+ } ) ) ;
668+ }
669+ } ;
670+ let tgt_name = CString :: new ( item. target . clone ( ) ) . unwrap ( ) ;
671+ dbg ! ( "Target name: {:?}" , & tgt_name) ;
672+ let fn_target: Option < & llvm:: Value > =
673+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, tgt_name. as_ptr ( ) ) } ;
674+ let fn_target = match fn_target {
675+ Some ( x) => x,
676+ None => {
677+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
678+ src : item. source . clone ( ) ,
679+ target : item. target . clone ( ) ,
680+ error : "could not find target function" . to_owned ( ) ,
681+ } ) ) ;
682+ }
683+ } ;
684+
685+ crate :: builder:: add_opt_dbg_helper2 ( llmod, llcx, fn_def, fn_target, item. attrs . clone ( ) ) ;
686+ }
687+
688+ // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
689+ // which Enzyme doesn't understand.
690+ unsafe {
691+ let mut f = LLVMGetFirstFunction ( llmod) ;
692+ loop {
693+ if let Some ( lf) = f {
694+ f = LLVMGetNextFunction ( lf) ;
695+ let myhwattr = "enzyme_hw" ;
696+ let attr = LLVMGetStringAttributeAtIndex (
697+ lf,
698+ c_uint:: MAX ,
699+ myhwattr. as_ptr ( ) as * const c_char ,
700+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
701+ ) ;
702+ if LLVMIsStringAttribute ( attr) {
703+ LLVMRemoveStringAttributeAtIndex (
704+ lf,
705+ c_uint:: MAX ,
706+ myhwattr. as_ptr ( ) as * const c_char ,
707+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
708+ ) ;
709+ } else {
710+ LLVMRustRemoveEnumAttributeAtIndex (
711+ lf,
712+ c_uint:: MAX ,
713+ AttributeKind :: SanitizeHWAddress ,
714+ ) ;
715+ }
716+ } else {
717+ break ;
718+ }
719+ }
720+ }
721+
722+ if let Some ( opt_level) = config. opt_level {
723+ let opt_stage = match cgcx. lto {
724+ Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
725+ Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
726+ _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
727+ _ => llvm:: OptStage :: PreLinkNoLTO ,
728+ } ;
729+ let skip_size_increasing_opts = false ;
730+ dbg ! ( "Running Module Optimization after differentiation" ) ;
731+ unsafe {
732+ llvm_optimize (
733+ cgcx,
734+ diag_handler. handle ( ) ,
735+ module,
736+ config,
737+ opt_level,
738+ opt_stage,
739+ skip_size_increasing_opts,
740+ ) ?
741+ } ;
742+ }
743+ dbg ! ( "Done with differentiate()" ) ;
744+
745+ Ok ( ( ) )
746+ }
747+
609748// Unsafe due to LLVM calls.
610749pub ( crate ) unsafe fn optimize (
611750 cgcx : & CodegenContext < LlvmCodegenBackend > ,
@@ -628,14 +767,68 @@ pub(crate) unsafe fn optimize(
628767 unsafe { llvm:: LLVMWriteBitcodeToFile ( llmod, out. as_ptr ( ) ) } ;
629768 }
630769
770+ // This code enables Enzyme to differentiate code containing Rust enums.
771+ // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing
772+ // away the enums and allows Enzyme to understand why a value can be of different types in
773+ // different code sections. We remove this attribute after Enzyme is done, to not affect the
774+ // rest of the compilation.
775+ #[ cfg( llvm_enzyme) ]
776+ unsafe {
777+ let mut f = LLVMGetFirstFunction ( llmod) ;
778+ loop {
779+ if let Some ( lf) = f {
780+ f = LLVMGetNextFunction ( lf) ;
781+ let myhwattr = "enzyme_hw" ;
782+ let myhwv = "" ;
783+ let prevattr = LLVMRustGetEnumAttributeAtIndex (
784+ lf,
785+ c_uint:: MAX ,
786+ AttributeKind :: SanitizeHWAddress ,
787+ ) ;
788+ if LLVMIsEnumAttribute ( prevattr) {
789+ let attr = LLVMCreateStringAttribute (
790+ llcx,
791+ myhwattr. as_ptr ( ) as * const c_char ,
792+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
793+ myhwv. as_ptr ( ) as * const c_char ,
794+ myhwv. as_bytes ( ) . len ( ) as c_uint ,
795+ ) ;
796+ LLVMRustAddFunctionAttributes ( lf, c_uint:: MAX , & attr, 1 ) ;
797+ } else {
798+ LLVMRustAddEnumAttributeAtIndex (
799+ llcx,
800+ lf,
801+ c_uint:: MAX ,
802+ AttributeKind :: SanitizeHWAddress ,
803+ ) ;
804+ }
805+ } else {
806+ break ;
807+ }
808+ }
809+ }
810+
631811 if let Some ( opt_level) = config. opt_level {
632812 let opt_stage = match cgcx. lto {
633813 Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
634814 Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
635815 _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
636816 _ => llvm:: OptStage :: PreLinkNoLTO ,
637817 } ;
638- return unsafe { llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage) } ;
818+
819+ // If we know that we will later run AD, then we disable vectorization and loop unrolling
820+ let skip_size_increasing_opts = cfg ! ( llvm_enzyme) ;
821+ return unsafe {
822+ llvm_optimize (
823+ cgcx,
824+ dcx,
825+ module,
826+ config,
827+ opt_level,
828+ opt_stage,
829+ skip_size_increasing_opts,
830+ )
831+ } ;
639832 }
640833 Ok ( ( ) )
641834}
0 commit comments