11use std:: ptr;
2-
32use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
43use rustc_codegen_ssa:: ModuleCodegen ;
54use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6- use rustc_codegen_ssa :: traits :: BaseTypeCodegenMethods as _ ;
7- use rustc_errors :: FatalError ;
5+ use rustc_errors :: { DiagCtxt , FatalError } ;
6+ use rustc_middle :: bug ;
87use tracing:: { debug, trace} ;
98
109use crate :: back:: write:: llvm_err;
1110use crate :: builder:: SBuilder ;
1211use crate :: context:: SimpleCx ;
1312use crate :: declare:: declare_simple_fn;
14- use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
13+ use crate :: errors:: { AutoDiffUnusedArgs , AutoDiffWithoutEnable , LlvmError } ;
1514use crate :: llvm:: AttributePlace :: Function ;
1615use crate :: llvm:: { Metadata , True } ;
1716use crate :: value:: Value ;
17+
1818use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
1919
2020fn get_params ( fnc : & Value ) -> Vec < & Value > {
@@ -28,6 +28,25 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828 }
2929}
3030
31+ fn has_sret ( fnc : & Value ) -> bool {
32+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33+ if num_args == 0 {
34+ false
35+ } else {
36+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
37+ }
38+ }
39+
40+ // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41+ // original inputs, as well as metadata and the additional shadow arguments.
42+ // This function matches the arguments from the outer function to the inner enzyme call.
43+ //
44+ // This function also takes into account that Rust-level arguments do not always match the
45+ // llvm-ir level arguments. A slice, `&[f32]`, for example, is represented as a pointer and a
46+ // length on the llvm-ir level. The number of activities matches the number of Rust-level
47+ // arguments, so we need to match those.
48+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
49+ // using iterators and peek()?
3150fn match_args_from_caller_to_enzyme < ' ll > (
3251 cx : & SimpleCx < ' ll > ,
3352 args : & mut Vec < & ' ll llvm:: Value > ,
@@ -135,6 +154,78 @@ fn match_args_from_caller_to_enzyme<'ll>(
135154 }
136155}
137156
157+
158+ // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
159+ // arguments. We do however need to declare them with their correct return type.
160+ // We already figured the correct return type out in our frontend, when generating the outer_fn,
161+ // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
162+ // Beyond sret, this article describes our challenges nicely:
163+ // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
164+ // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
165+ fn compute_enzyme_fn_ty < ' ll > (
166+ cx : & SimpleCx < ' ll > ,
167+ attrs : & AutoDiffAttrs ,
168+ fn_to_diff : & ' ll Value ,
169+ outer_fn : & ' ll Value ,
170+ ) -> & ' ll llvm:: Type {
171+ let fn_ty = cx. get_type_of_global ( outer_fn) ;
172+ let mut ret_ty = cx. get_return_type ( fn_ty) ;
173+
174+ let has_sret = has_sret ( outer_fn) ;
175+
176+ if has_sret {
177+ // Now we don't just forward the return type, so we have to figure it out based on the
178+ // primal return type, in combination with the autodiff settings.
179+ let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
180+ let inner_ret_ty = cx. get_return_type ( fn_ty) ;
181+
182+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
183+ if inner_ret_ty == void_ty {
184+ dbg ! ( & fn_to_diff) ;
185+ // This indicates that even the inner function has an sret.
186+ // Right now I only look for an sret in the outer function.
187+ // This *probably* needs some extra handling, but I never ran
188+ // into such a case. So I'll wait for user reports to have a test case.
189+ bug ! ( "sret in inner function" ) ;
190+ }
191+
192+ if attrs. width == 1 {
193+ todo ! ( "Handle sret for scalar ad" ) ;
194+ } else {
195+ // First we check if we also have to deal with the primal return.
196+ if attrs. mode . is_fwd ( ) {
197+ match attrs. ret_activity {
198+ DiffActivity :: Dual => {
199+ let arr_ty =
200+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
201+ ret_ty = arr_ty;
202+ }
203+ DiffActivity :: DualOnly => {
204+ let arr_ty =
205+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
206+ ret_ty = arr_ty;
207+ }
208+ DiffActivity :: Const => {
209+ todo ! ( "Not sure, do we need to do something here?" ) ;
210+ }
211+ _ => {
212+ bug ! ( "unreachable" ) ;
213+ }
214+ }
215+ } else if attrs. mode . is_rev ( ) {
216+ todo ! ( "Handle sret for reverse mode" ) ;
217+ } else {
218+ bug ! ( "unreachable" ) ;
219+ }
220+ }
221+ }
222+
223+ dbg ! ( & outer_fn) ;
224+
225+ // LLVM can figure out the input types on it's own, so we take a shortcut here.
226+ unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
227+ }
228+
138229/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139230/// function with expected naming and calling conventions[^1] which will be
140231/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -145,6 +236,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
145236// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
146237// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
147238fn generate_enzyme_call < ' ll > (
239+ _dcx : & DiagCtxt ,
148240 cx : & SimpleCx < ' ll > ,
149241 fn_to_diff : & ' ll Value ,
150242 outer_fn : & ' ll Value ,
@@ -197,17 +289,9 @@ fn generate_enzyme_call<'ll>(
197289 // }
198290 // ```
199291 unsafe {
200- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201- // arguments. We do however need to declare them with their correct return type.
202- // We already figured the correct return type out in our frontend, when generating the outer_fn,
203- // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204- let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
205- let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
292+ let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
206293
207- // LLVM can figure out the input types on it's own, so we take a shortcut here.
208- let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
209-
210- //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
294+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211295 // think a bit more about what should go here.
212296 let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
213297 let ad_fn = declare_simple_fn (
@@ -268,12 +352,31 @@ fn generate_enzyme_call<'ll>(
268352 // Now that we copied the metadata, get rid of dummy code.
269353 llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
270354
271- if cx. val_ty ( call) == cx. type_void ( ) {
355+ let has_sret = has_sret ( outer_fn) ;
356+
357+ if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
358+ if has_sret {
359+ // This is what we already have in our outer_fn (shortened):
360+ // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
361+ // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
362+ // <Here we are, we want to add the following two lines>
363+ // store [4 x double] %7, ptr %0, align 8
364+ // ret void
365+ // }
366+
367+ // now store the result of the enzyme call into the sret pointer.
368+ let sret_ptr = outer_args[ 0 ] ;
369+ let call_ty = cx. val_ty ( call) ;
370+ assert ! ( llvm:: LLVMRustIsArrayTy ( call_ty) ) ;
371+ llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
372+ }
272373 builder. ret_void ( ) ;
273374 } else {
274375 builder. ret ( call) ;
275376 }
276377
378+ dbg ! ( & outer_fn) ;
379+
277380 // Let's crash in case that we messed something up above and generated invalid IR.
278381 llvm:: LLVMRustVerifyFunction (
279382 outer_fn,
@@ -300,8 +403,7 @@ pub(crate) fn differentiate<'ll>(
300403 if !diff_items. is_empty ( )
301404 && !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
302405 {
303- let dcx = cgcx. create_dcx ( ) ;
304- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
406+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
305407 }
306408
307409 // Before dumping the module, we want all the TypeTrees to become part of the module.
@@ -331,7 +433,7 @@ pub(crate) fn differentiate<'ll>(
331433 ) ) ;
332434 } ;
333435
334- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
436+ generate_enzyme_call ( & diag_handler , & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
335437 }
336438
337439 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments