@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
44use std:: sync:: Arc ;
55use std:: { fs, slice, str} ;
66
7- use libc:: { c_char, c_int, c_void, size_t} ;
7+ use libc:: { c_char, c_int, c_uint , c_void, size_t} ;
88use llvm:: {
99 LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
1010} ;
11+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
1112use rustc_codegen_ssa:: back:: link:: ensure_removed;
1213use rustc_codegen_ssa:: back:: versioned_llvm_target;
1314use rustc_codegen_ssa:: back:: write:: {
@@ -28,7 +29,7 @@ use rustc_session::config::{
2829use rustc_span:: InnerSpan ;
2930use rustc_span:: symbol:: sym;
3031use rustc_target:: spec:: { CodeModel , RelocModel , SanitizerSet , SplitDebuginfo , TlsModel } ;
31- use tracing:: debug;
32+ use tracing:: { debug, trace } ;
3233
3334use crate :: back:: lto:: ThinBuffer ;
3435use crate :: back:: owned_target_machine:: OwnedTargetMachine ;
@@ -40,8 +41,14 @@ use crate::errors::{
4041 CopyBitcode , FromLlvmDiag , FromLlvmOptimizationDiag , LlvmError , UnknownCompression ,
4142 WithLlvmError , WriteBytecode ,
4243} ;
44+ use crate :: llvm:: AttributePlace :: Function ;
4345use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
44- use crate :: llvm:: { self , DiagnosticInfo , PassManager } ;
46+ use crate :: llvm:: {
47+ self , AttributeKind , DiagnosticInfo , LLVMGetFirstFunction ,
48+ LLVMGetNextFunction , LLVMGetStringAttributeAtIndex , LLVMIsEnumAttribute , LLVMIsStringAttribute ,
49+ LLVMRemoveStringAttributeAtIndex , LLVMRustGetEnumAttributeAtIndex ,
50+ LLVMRustRemoveEnumAttributeAtIndex , PassManager ,
51+ } ;
4552use crate :: type_:: Type ;
4653use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
4754
@@ -517,9 +524,38 @@ pub(crate) unsafe fn llvm_optimize(
517524 config : & ModuleConfig ,
518525 opt_level : config:: OptLevel ,
519526 opt_stage : llvm:: OptStage ,
527+ skip_size_increasing_opts : bool ,
520528) -> Result < ( ) , FatalError > {
521- let unroll_loops =
522- opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
529+ // Enzyme:
530+ // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
531+ // source code. However, benchmarks show that optimizations increasing the code size
532+ // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
533+ // and finally re-optimize the module, now with all optimizations available.
534+ // TODO: In a future update we could figure out how to only optimize individual functions getting
535+ // differentiated.
536+
537+ let unroll_loops;
538+ let vectorize_slp;
539+ let vectorize_loop;
540+
541+ // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
542+ // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
543+ // we should make this more granular, or at least check that the user has at least one autodiff
544+ // call in their code, to justify altering the compilation pipeline.
545+ if skip_size_increasing_opts && cfg ! ( llvm_enzyme) {
546+ unroll_loops = false ;
547+ vectorize_slp = false ;
548+ vectorize_loop = false ;
549+ } else {
550+ unroll_loops =
551+ opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
552+ vectorize_slp = config. vectorize_slp ;
553+ vectorize_loop = config. vectorize_loop ;
554+ }
555+ trace ! (
556+ "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}" ,
557+ unroll_loops, vectorize_slp, vectorize_loop
558+ ) ;
523559 let using_thin_buffers = opt_stage == llvm:: OptStage :: PreLinkThinLTO || config. bitcode_needed ( ) ;
524560 let pgo_gen_path = get_pgo_gen_path ( config) ;
525561 let pgo_use_path = get_pgo_use_path ( config) ;
@@ -583,8 +619,8 @@ pub(crate) unsafe fn llvm_optimize(
583619 using_thin_buffers,
584620 config. merge_functions ,
585621 unroll_loops,
586- config . vectorize_slp ,
587- config . vectorize_loop ,
622+ vectorize_slp,
623+ vectorize_loop,
588624 config. no_builtins ,
589625 config. emit_lifetime_markers ,
590626 sanitizer_options. as_ref ( ) ,
@@ -606,6 +642,115 @@ pub(crate) unsafe fn llvm_optimize(
606642 result. into_result ( ) . map_err ( |( ) | llvm_err ( dcx, LlvmError :: RunLlvmPasses ) )
607643}
608644
645+ pub ( crate ) fn differentiate (
646+ module : & ModuleCodegen < ModuleLlvm > ,
647+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
648+ diff_items : Vec < AutoDiffItem > ,
649+ config : & ModuleConfig ,
650+ ) -> Result < ( ) , FatalError > {
651+ for item in & diff_items {
652+ trace ! ( "{}" , item) ;
653+ }
654+
655+ let llmod = module. module_llvm . llmod ( ) ;
656+ let llcx = & module. module_llvm . llcx ;
657+ let diag_handler = cgcx. create_dcx ( ) ;
658+
659+ // Before dumping the module, we want all the tt to become part of the module.
660+ for item in diff_items. iter ( ) {
661+ let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
662+ let fn_def: Option < & llvm:: Value > =
663+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) } ;
664+ let fn_def = match fn_def {
665+ Some ( x) => x,
666+ None => {
667+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
668+ src : item. source . clone ( ) ,
669+ target : item. target . clone ( ) ,
670+ error : "could not find source function" . to_owned ( ) ,
671+ } ) ) ;
672+ }
673+ } ;
674+ let target_name = CString :: new ( item. target . clone ( ) ) . unwrap ( ) ;
675+ debug ! ( "target name: {:?}" , & target_name) ;
676+ let fn_target: Option < & llvm:: Value > =
677+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, target_name. as_ptr ( ) ) } ;
678+ let fn_target = match fn_target {
679+ Some ( x) => x,
680+ None => {
681+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
682+ src : item. source . clone ( ) ,
683+ target : item. target . clone ( ) ,
684+ error : "could not find target function" . to_owned ( ) ,
685+ } ) ) ;
686+ }
687+ } ;
688+
689+ crate :: builder:: generate_enzyme_call ( llmod, llcx, fn_def, fn_target, item. attrs . clone ( ) ) ;
690+ }
691+
692+ // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
693+ // which Enzyme doesn't understand.
694+ unsafe {
695+ let mut f = LLVMGetFirstFunction ( llmod) ;
696+ loop {
697+ if let Some ( lf) = f {
698+ f = LLVMGetNextFunction ( lf) ;
699+ let myhwattr = "enzyme_hw" ;
700+ let attr = LLVMGetStringAttributeAtIndex (
701+ lf,
702+ c_uint:: MAX ,
703+ myhwattr. as_ptr ( ) as * const c_char ,
704+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
705+ ) ;
706+ if LLVMIsStringAttribute ( attr) {
707+ LLVMRemoveStringAttributeAtIndex (
708+ lf,
709+ c_uint:: MAX ,
710+ myhwattr. as_ptr ( ) as * const c_char ,
711+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
712+ ) ;
713+ } else {
714+ LLVMRustRemoveEnumAttributeAtIndex (
715+ lf,
716+ c_uint:: MAX ,
717+ AttributeKind :: SanitizeHWAddress ,
718+ ) ;
719+ }
720+ } else {
721+ break ;
722+ }
723+ }
724+ }
725+
726+ if let Some ( opt_level) = config. opt_level {
727+ let opt_stage = match cgcx. lto {
728+ Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
729+ Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
730+ _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
731+ _ => llvm:: OptStage :: PreLinkNoLTO ,
732+ } ;
733+ // This is our second opt call, so now we run all opts,
734+ // to make sure we get the best performance.
735+ let skip_size_increasing_opts = false ;
736+ trace ! ( "running Module Optimization after differentiation" ) ;
737+ unsafe {
738+ llvm_optimize (
739+ cgcx,
740+ diag_handler. handle ( ) ,
741+ module,
742+ config,
743+ opt_level,
744+ opt_stage,
745+ skip_size_increasing_opts,
746+ ) ?
747+ } ;
748+ }
749+ trace ! ( "done with differentiate()" ) ;
750+
751+ Ok ( ( ) )
752+ }
753+
609754// Unsafe due to LLVM calls.
610755pub ( crate ) unsafe fn optimize (
611756 cgcx : & CodegenContext < LlvmCodegenBackend > ,
@@ -628,14 +773,57 @@ pub(crate) unsafe fn optimize(
628773 unsafe { llvm:: LLVMWriteBitcodeToFile ( llmod, out. as_ptr ( ) ) } ;
629774 }
630775
776+ // This code enables Enzyme to differentiate code containing Rust enums.
777+ // By adding the SanitizeHWAddress attribute we prevent LLVM from optimizing
778+ // away the enums, which allows Enzyme to understand why a value can be of different types in
779+ // different code sections. We remove this attribute after Enzyme is done, to not affect the
780+ // rest of the compilation.
781+ //#[cfg(llvm_enzyme)]
782+ unsafe {
783+ let mut f = LLVMGetFirstFunction ( llmod) ;
784+ loop {
785+ if let Some ( lf) = f {
786+ f = LLVMGetNextFunction ( lf) ;
787+ let myhwattr = "enzyme_hw" ;
788+ let prevattr = LLVMRustGetEnumAttributeAtIndex (
789+ lf,
790+ c_uint:: MAX ,
791+ AttributeKind :: SanitizeHWAddress ,
792+ ) ;
793+ if LLVMIsEnumAttribute ( prevattr) {
794+ let attr = llvm:: CreateAttrString ( llcx, myhwattr) ;
795+ crate :: attributes:: apply_to_llfn ( lf, Function , & [ attr] ) ;
796+ } else {
797+ let attr = AttributeKind :: SanitizeHWAddress . create_attr ( llcx) ;
798+ crate :: attributes:: apply_to_llfn ( lf, Function , & [ attr] ) ;
799+ }
800+ } else {
801+ break ;
802+ }
803+ }
804+ }
805+
631806 if let Some ( opt_level) = config. opt_level {
632807 let opt_stage = match cgcx. lto {
633808 Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
634809 Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
635810 _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
636811 _ => llvm:: OptStage :: PreLinkNoLTO ,
637812 } ;
638- return unsafe { llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage) } ;
813+
814+ // If we know that we will later run AD, then we disable vectorization and loop unrolling
815+ let skip_size_increasing_opts = cfg ! ( llvm_enzyme) ;
816+ return unsafe {
817+ llvm_optimize (
818+ cgcx,
819+ dcx,
820+ module,
821+ config,
822+ opt_level,
823+ opt_stage,
824+ skip_size_increasing_opts,
825+ )
826+ } ;
639827 }
640828 Ok ( ( ) )
641829}
0 commit comments