1+ use rustc_abi:: HasDataLayout ;
12use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffItem , DiffActivity } ;
23use rustc_hir:: def_id:: LOCAL_CRATE ;
34use rustc_middle:: bug;
@@ -16,6 +17,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
1617 // We don't actually pass the types back into the type system.
1718 // All we do is decide how to handle the arguments.
1819 let sig = fn_ty. fn_sig ( tcx) . skip_binder ( ) ;
20+ let pointer_size = tcx. data_layout ( ) . pointer_size ;
1921
2022 let mut new_activities = vec ! [ ] ;
2123 let mut new_positions = vec ! [ ] ;
@@ -70,6 +72,25 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
7072 continue ;
7173 }
7274 }
75+
76+ let pci = PseudoCanonicalInput { typing_env : TypingEnv :: fully_monomorphized ( ) , value : * ty } ;
77+
78+ let layout = match tcx. layout_of ( pci) {
79+ Ok ( layout) => layout. layout ,
80+ Err ( _) => {
81+ bug ! ( "failed to compute layout for type {:?}" , ty) ;
82+ }
83+ } ;
84+
85+ let is_product = |t : Ty < ' tcx > | matches ! ( t. kind( ) , ty:: Tuple ( _) | ty:: Adt ( _, _) ) ;
86+
87+ if layout. size ( ) <= pointer_size * 2 && is_product ( * ty) {
88+ let n_scalars = count_scalar_fields ( tcx, * ty) ;
89+ for _ in 0 ..n_scalars. saturating_sub ( 1 ) {
90+ new_activities. push ( da[ i] . clone ( ) ) ;
91+ new_positions. push ( i + 1 ) ;
92+ }
93+ }
7394 }
7495 // now add the extra activities coming from slices
7596 // Reverse order to not invalidate the indices
@@ -80,6 +101,20 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
80101 }
81102}
82103
104+ fn count_scalar_fields < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> usize {
105+ match ty. kind ( ) {
106+ ty:: Float ( _) | ty:: Int ( _) | ty:: Uint ( _) => 1 ,
107+ ty:: Adt ( def, substs) if def. is_struct ( ) => def
108+ . non_enum_variant ( )
109+ . fields
110+ . iter ( )
111+ . map ( |f| count_scalar_fields ( tcx, f. ty ( tcx, substs) ) )
112+ . sum ( ) ,
113+ ty:: Tuple ( substs) => substs. iter ( ) . map ( |t| count_scalar_fields ( tcx, t) ) . sum ( ) ,
114+ _ => 0 ,
115+ }
116+ }
117+
83118pub ( crate ) fn find_autodiff_source_functions < ' tcx > (
84119 tcx : TyCtxt < ' tcx > ,
85120 usage_map : & UsageMap < ' tcx > ,
0 commit comments