@@ -48,16 +48,15 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
4848 }
4949}
5050
51- // The lowering of one `#[autodiff]` macro happens in multiple steps.
52- // First we transalte generate a new dummy function, who's llvm-ir we now have as outer_fn.
53- // We kept track of the original function to which the `#[autodiff]` macro was applied to, which we
54- // now have as fn_to_diff. In our current implementation, we use the enzyme pass to carry out the
55- // differentiation, following naming and calling conventions documented here: <https://enzyme.mit.edu/getting_started/CallingConvention/>
56- //
57- // Our `outer_fn` had some dummy code inserted at higher levels, so we first remove most of the
58- // existing body. We then insert an `__enzyme_<autodiff/fwddiff>_<unique_id>` call, which the pass
59- // will then pick up. FIXME(ZuseZ4): We will later want to upstream safety checks to the `outer_fn`,
60- // in order to cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
51+ /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
52+ /// function with expected naming and calling conventions[^1] which will be
53+ /// discovered by the enzyme LLVM pass and its body populated with the differentiated
54+ /// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
55+ /// function and handle the differences between the Rust calling convention and
56+ /// Enzyme.
57+ /// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
58+ // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
59+ // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
6160pub ( crate ) fn generate_enzyme_call < ' ll > (
6261 llmod : & ' ll llvm:: Module ,
6362 llcx : & ' ll llvm:: Context ,
@@ -69,7 +68,7 @@ pub(crate) fn generate_enzyme_call<'ll>(
6968 let output = attrs. ret_activity ;
7069
7170 // We have to pick the name depending on whether we want forward or reverse mode autodiff.
72- // FIXME(ZuseZ4): The new pass based approach should not need the * First method anymore, since
71+ // FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse} First method anymore, since
7372 // it will handle higher-order derivatives correctly automatically (in theory). Currently
7473 // higher-order derivatives fail, so we should debug that before adjusting this code.
7574 let mut ad_name: String = match attrs. mode {
@@ -87,16 +86,38 @@ pub(crate) fn generate_enzyme_call<'ll>(
8786 let outer_fn_name = std:: ffi:: CStr :: from_bytes_with_nul ( name) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
8887 ad_name. push_str ( outer_fn_name. to_string ( ) . as_str ( ) ) ;
8988
90- // Assuming that our fn_to_diff is the fnc square, want to generate the following llvm-ir, which
91- // would allow the enzyme pass to generate a function body for `__enzyme_autodiff_square`
89+ // Let us assume the user wrote the following function square:
9290 //
91+ // ```llvm
92+ // define double @square(double %x) {
93+ // entry:
94+ // %0 = fmul double %x, %x
95+ // ret double %0
96+ // }
97+ // ```
98+ //
99+ // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
100+ // Our macro generates the following placeholder code (slightly simplified):
101+ //
102+ // ```llvm
103+ // define double @dsquare(double %x) {
104+ // ; placeholder code
105+ // return 0.0;
106+ // }
107+ // ```
108+ //
109+ // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
110+ // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
111+ // Again, the arguments to all functions are slightly simplified.
112+ // ```llvm
93113 // declare double @__enzyme_autodiff_square(...)
94114 //
95115 // define double @dsquare(double %x) {
96116 // entry:
97117 // %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
98118 // ret double %0
99119 // }
120+ // ```
100121 unsafe {
101122 // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
102123 // arguments. We do however need to declare them with their correct return type.
0 commit comments