1+ /// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2+ /// we create an `AutoDiffItem` which contains the source and target function names. The source
3+ /// is the function to which the autodiff attribute is applied, and the target is the function
4+ /// getting generated by us (with a name given by the user as the first autodiff arg).
15use std:: fmt:: { self , Display , Formatter } ;
26use std:: str:: FromStr ;
37
@@ -6,27 +10,91 @@ use crate::expand::{Decodable, Encodable, HashStable_Generic};
610use crate :: ptr:: P ;
711use crate :: { Ty , TyKind } ;
812
9- #[ allow( dead_code) ]
1013#[ derive( Clone , Copy , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
1114pub enum DiffMode {
15+ /// No autodiff is applied (usually used during error handling).
1216 Inactive ,
17+ /// The primal function which we will differentiate.
1318 Source ,
19+ /// The target function, to be created using forward mode AD.
1420 Forward ,
21+ /// The target function, to be created using reverse mode AD.
1522 Reverse ,
23+ /// The target function, to be created using forward mode AD.
24+ /// This target function will also be used as a source for higher order derivatives,
25+ /// so compute it before all Forward/Reverse targets and optimize it through llvm.
1626 ForwardFirst ,
27+ /// The target function, to be created using reverse mode AD.
28+ /// This target function will also be used as a source for higher order derivatives,
29+ /// so compute it before all Forward/Reverse targets and optimize it through llvm.
1730 ReverseFirst ,
1831}
1932
20- pub fn is_rev ( mode : DiffMode ) -> bool {
21- match mode {
22- DiffMode :: Reverse | DiffMode :: ReverseFirst => true ,
23- _ => false ,
24- }
33+ /// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
34+ /// However, under forward mode we overwrite the previous shadow value, while for reverse mode
35+ /// we add to the previous shadow value. To not surprise users, we picked different names.
36+ /// Dual numbers is also a quite well known name for forward mode AD types.
37+ #[ derive( Clone , Copy , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
38+ pub enum DiffActivity {
39+ /// Implicit or Explicit () return type, so a special case of Const.
40+ None ,
41+ /// Don't compute derivatives with respect to this input/output.
42+ Const ,
43+ /// Reverse Mode, Compute derivatives for this scalar input/output.
44+ Active ,
45+ /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
46+ /// the original return value.
47+ ActiveOnly ,
48+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
49+ /// with it.
50+ Dual ,
51+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
52+ /// with it. Drop the code which updates the original input/output for maximum performance.
53+ DualOnly ,
54+ /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
55+ Duplicated ,
56+ /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
57+ /// Drop the code which updates the original input for maximum performance.
58+ DuplicatedOnly ,
59+ /// All Integers must be Const, but these are used to mark the integer which represents the
60+ /// length of a slice/vec. This is used for safety checks on slices.
61+ FakeActivitySize ,
2562}
26- pub fn is_fwd ( mode : DiffMode ) -> bool {
27- match mode {
28- DiffMode :: Forward | DiffMode :: ForwardFirst => true ,
29- _ => false ,
63+ /// We generate one of these structs for each `#[autodiff(...)]` attribute.
64+ #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
65+ pub struct AutoDiffItem {
66+ /// The name of the function getting differentiated
67+ pub source : String ,
68+ /// The name of the function being generated
69+ pub target : String ,
70+ pub attrs : AutoDiffAttrs ,
71+ /// Despribe the memory layout of input types
72+ pub inputs : Vec < TypeTree > ,
73+ /// Despribe the memory layout of the output type
74+ pub output : TypeTree ,
75+ }
76+ #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
77+ pub struct AutoDiffAttrs {
78+ /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
79+ /// e.g. in the [JAX
80+ /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
81+ pub mode : DiffMode ,
82+ pub ret_activity : DiffActivity ,
83+ pub input_activity : Vec < DiffActivity > ,
84+ }
85+
86+ impl DiffMode {
87+ pub fn is_rev ( & self ) -> bool {
88+ match self {
89+ DiffMode :: Reverse | DiffMode :: ReverseFirst => true ,
90+ _ => false ,
91+ }
92+ }
93+ pub fn is_fwd ( & self ) -> bool {
94+ match self {
95+ DiffMode :: Forward | DiffMode :: ForwardFirst => true ,
96+ _ => false ,
97+ }
3098 }
3199}
32100
@@ -63,30 +131,20 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
63131 }
64132 }
65133}
66- fn is_ptr_or_ref ( ty : & Ty ) -> bool {
67- match ty. kind {
68- TyKind :: Ptr ( _) | TyKind :: Ref ( _, _) => true ,
69- _ => false ,
70- }
71- }
72- // TODO We should make this more robust to also
134+
135+ // FIXME(ZuseZ4) We should make this more robust to also
73136// accept aliases of f32 and f64
74- //fn is_float(ty: &Ty) -> bool {
75- // false
76- //}
77137pub fn valid_ty_for_activity ( ty : & P < Ty > , activity : DiffActivity ) -> bool {
78- if is_ptr_or_ref ( ty) {
79- return activity == DiffActivity :: Dual
80- || activity == DiffActivity :: DualOnly
81- || activity == DiffActivity :: Duplicated
82- || activity == DiffActivity :: DuplicatedOnly
83- || activity == DiffActivity :: Const ;
138+ match ty. kind {
139+ TyKind :: Ptr ( _) | TyKind :: Ref ( ..) => {
140+ return activity == DiffActivity :: Dual
141+ || activity == DiffActivity :: DualOnly
142+ || activity == DiffActivity :: Duplicated
143+ || activity == DiffActivity :: DuplicatedOnly
144+ || activity == DiffActivity :: Const ;
145+ }
146+ _ => false ,
84147 }
85- true
86- //if is_scalar_ty(&ty) {
87- // return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
88- // activity == DiffActivity::Const;
89- //}
90148}
91149pub fn valid_input_activity ( mode : DiffMode , activity : DiffActivity ) -> bool {
92150 return match mode {
@@ -117,20 +175,6 @@ pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -
117175 None
118176}
119177
120- #[ allow( dead_code) ]
121- #[ derive( Clone , Copy , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
122- pub enum DiffActivity {
123- None ,
124- Const ,
125- Active ,
126- ActiveOnly ,
127- Dual ,
128- DualOnly ,
129- Duplicated ,
130- DuplicatedOnly ,
131- FakeActivitySize ,
132- }
133-
134178impl Display for DiffActivity {
135179 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
136180 match self {
@@ -180,30 +224,14 @@ impl FromStr for DiffActivity {
180224 }
181225}
182226
183- #[ allow( dead_code) ]
184- #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
185- pub struct AutoDiffAttrs {
186- pub mode : DiffMode ,
187- pub ret_activity : DiffActivity ,
188- pub input_activity : Vec < DiffActivity > ,
189- }
190-
191227impl AutoDiffAttrs {
192228 pub fn has_ret_activity ( & self ) -> bool {
193- match self . ret_activity {
194- DiffActivity :: None => false ,
195- _ => true ,
196- }
229+ self . ret_activity != DiffActivity :: None
197230 }
198231 pub fn has_active_only_ret ( & self ) -> bool {
199- match self . ret_activity {
200- DiffActivity :: ActiveOnly => true ,
201- _ => false ,
202- }
232+ self . ret_activity == DiffActivity :: ActiveOnly
203233 }
204- }
205234
206- impl AutoDiffAttrs {
207235 pub fn inactive ( ) -> Self {
208236 AutoDiffAttrs {
209237 mode : DiffMode :: Inactive ,
@@ -251,15 +279,6 @@ impl AutoDiffAttrs {
251279 }
252280}
253281
254- #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
255- pub struct AutoDiffItem {
256- pub source : String ,
257- pub target : String ,
258- pub attrs : AutoDiffAttrs ,
259- pub inputs : Vec < TypeTree > ,
260- pub output : TypeTree ,
261- }
262-
263282impl fmt:: Display for AutoDiffItem {
264283 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
265284 write ! ( f, "Differentiating {} -> {}" , self . source, self . target) ?;
0 commit comments