11use std:: str:: FromStr ;
22
3- use rustc_abi:: ExternAbi ;
3+ use rustc_abi:: { ExternAbi , Size } ;
44use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
55use rustc_ast:: { LitKind , MetaItem , MetaItemInner , attr} ;
66use rustc_attr_data_structures:: ReprAttr :: ReprAlign ;
@@ -16,7 +16,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
1616use rustc_middle:: mir:: mono:: Linkage ;
1717use rustc_middle:: query:: Providers ;
1818use rustc_middle:: span_bug;
19- use rustc_middle:: ty:: { self as ty, TyCtxt } ;
19+ use rustc_middle:: ty:: { self as ty, PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
2020use rustc_session:: parse:: feature_err;
2121use rustc_session:: { Session , lint} ;
2222use rustc_span:: { Ident , Span , sym} ;
@@ -138,6 +138,28 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
138138 sym:: rustc_allocator_zeroed => {
139139 codegen_fn_attrs. flags |= CodegenFnAttrFlags :: ALLOCATOR_ZEROED
140140 }
141+ sym:: rustc_autodiff => {
142+ let list = attr. meta_item_list ( ) . unwrap_or_default ( ) ;
143+ if list. is_empty ( ) {
144+ // Add the flag only to the primal function so LLVM can
145+ // optimize the derivative function.
146+ if let Some ( sig) = fn_sig ( ) {
147+ let sig = sig. skip_binder ( ) ;
148+
149+ let has_problematic_args = sig
150+ . skip_binder ( )
151+ . inputs ( )
152+ . iter ( )
153+ . any ( |ty| is_abi_opt_sensitive ( tcx, * ty) ) ;
154+
155+ if has_problematic_args {
156+ codegen_fn_attrs. flags |= CodegenFnAttrFlags :: RUSTC_AUTODIFF_NO_ABI_OPT ;
157+ }
158+ }
159+
160+ // TODO(Sa4dUs): Handle static variable passed as argument case.
161+ }
162+ }
141163 sym:: naked => codegen_fn_attrs. flags |= CodegenFnAttrFlags :: NAKED ,
142164 sym:: no_mangle => {
143165 no_mangle_span = Some ( attr. span ( ) ) ;
@@ -899,6 +921,44 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
899921 Some ( AutoDiffAttrs { mode, width, ret_activity, input_activity : arg_activities } )
900922}
901923
924+ fn is_abi_opt_sensitive < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> bool {
925+ match ty. kind ( ) {
926+ ty:: Ref ( _, inner, _) | ty:: RawPtr ( inner, _) => {
927+ match inner. kind ( ) {
928+ ty:: Slice ( _) => {
929+ // Since we cannot guarantee that the slice length is large enough
930+ // to avoid optimization, we assume it is ABI-opt sensitive.
931+ return true ;
932+ }
933+ ty:: Array ( elem_ty, len) => {
934+ let Some ( len_val) = len. try_to_target_usize ( tcx) else {
935+ return false ;
936+ } ;
937+
938+ let pci = PseudoCanonicalInput {
939+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
940+ value : * elem_ty,
941+ } ;
942+
943+ if elem_ty. is_scalar ( ) {
944+ let elem_size =
945+ tcx. layout_of ( pci) . ok ( ) . map ( |layout| layout. size ) . unwrap_or ( Size :: ZERO ) ;
946+
947+ if elem_size. bytes ( ) * len_val <= tcx. data_layout . pointer_size . bytes ( ) * 2 {
948+ return true ;
949+ }
950+ }
951+ }
952+ _ => { }
953+ }
954+
955+ false
956+ }
957+ ty:: FnPtr ( _, _) => true ,
958+ _ => false ,
959+ }
960+ }
961+
902962pub ( crate ) fn provide ( providers : & mut Providers ) {
903963 * providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..* providers } ;
904964}
0 commit comments