@@ -3,7 +3,8 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
44use rustc_codegen_ssa:: common:: TypeKind ;
55use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
6- use rustc_middle:: bug;
6+ use rustc_middle:: { bug, ty} ;
7+ use rustc_middle:: ty:: { PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
78use tracing:: debug;
89
910use crate :: builder:: { Builder , PlaceRef , UNNAMED } ;
@@ -14,6 +15,82 @@ use crate::llvm::{Metadata, True, Type};
1415use crate :: value:: Value ;
1516use crate :: { attributes, llvm} ;
1617
18+ pub ( crate ) fn adjust_activity_to_abi < ' tcx > (
19+ tcx : TyCtxt < ' tcx > ,
20+ fn_ty : Ty < ' tcx > ,
21+ da : & mut Vec < DiffActivity > ,
22+ ) {
23+ if !matches ! ( fn_ty. kind( ) , ty:: FnDef ( ..) ) {
24+ bug ! ( "expected fn def for autodiff, got {:?}" , fn_ty) ;
25+ }
26+
27+ // We don't actually pass the types back into the type system.
28+ // All we do is decide how to handle the arguments.
29+ let sig = fn_ty. fn_sig ( tcx) . skip_binder ( ) ;
30+
31+ let mut new_activities = vec ! [ ] ;
32+ let mut new_positions = vec ! [ ] ;
33+ for ( i, ty) in sig. inputs ( ) . iter ( ) . enumerate ( ) {
34+ if let Some ( inner_ty) = ty. builtin_deref ( true ) {
35+ if inner_ty. is_slice ( ) {
36+ // Now we need to figure out the size of each slice element in memory to allow
37+ // safety checks and usability improvements in the backend.
38+ let sty = match inner_ty. builtin_index ( ) {
39+ Some ( sty) => sty,
40+ None => {
41+ panic ! ( "slice element type unknown" ) ;
42+ }
43+ } ;
44+ let pci = PseudoCanonicalInput {
45+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
46+ value : sty,
47+ } ;
48+
49+ let layout = tcx. layout_of ( pci) ;
50+ let elem_size = match layout {
51+ Ok ( layout) => layout. size ,
52+ Err ( _) => {
53+ bug ! ( "autodiff failed to compute slice element size" ) ;
54+ }
55+ } ;
56+ let elem_size: u32 = elem_size. bytes ( ) as u32 ;
57+
58+ // We know that the length will be passed as extra arg.
59+ if !da. is_empty ( ) {
60+ // We are looking at a slice. The length of that slice will become an
61+ // extra integer on llvm level. Integers are always const.
62+ // However, if the slice get's duplicated, we want to know to later check the
63+ // size. So we mark the new size argument as FakeActivitySize.
64+ // There is one FakeActivitySize per slice, so for convenience we store the
65+ // slice element size in bytes in it. We will use the size in the backend.
66+ let activity = match da[ i] {
67+ DiffActivity :: DualOnly
68+ | DiffActivity :: Dual
69+ | DiffActivity :: Dualv
70+ | DiffActivity :: DuplicatedOnly
71+ | DiffActivity :: Duplicated => {
72+ DiffActivity :: FakeActivitySize ( Some ( elem_size) )
73+ }
74+ DiffActivity :: Const => DiffActivity :: Const ,
75+ _ => bug ! ( "unexpected activity for ptr/ref" ) ,
76+ } ;
77+ new_activities. push ( activity) ;
78+ new_positions. push ( i + 1 ) ;
79+ }
80+
81+ continue ;
82+ }
83+ }
84+ }
85+ // now add the extra activities coming from slices
86+ // Reverse order to not invalidate the indices
87+ for _ in 0 ..new_activities. len ( ) {
88+ let pos = new_positions. pop ( ) . unwrap ( ) ;
89+ let activity = new_activities. pop ( ) . unwrap ( ) ;
90+ da. insert ( pos, activity) ;
91+ }
92+ }
93+
1794// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
1895// original inputs, as well as metadata and the additional shadow arguments.
1996// This function matches the arguments from the outer function to the inner enzyme call.
0 commit comments