@@ -3,8 +3,8 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods as _;
76use rustc_errors:: FatalError ;
7+ use rustc_middle:: bug;
88use tracing:: { debug, trace} ;
99
1010use crate :: back:: write:: llvm_err;
@@ -28,11 +28,32 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828 }
2929}
3030
31+ fn has_sret ( fnc : & Value ) -> bool {
32+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33+ if num_args == 0 {
34+ false
35+ } else {
36+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
37+ }
38+ }
39+
40+ // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
41+ // original inputs, as well as metadata and the additional shadow arguments.
42+ // This function matches the arguments from the outer function to the inner enzyme call.
43+ //
44+ // This function also considers that Rust level arguments not always match the llvm-ir level
45+ // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
46+ // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
47+ // need to match those.
48+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
49+ // using iterators and peek()?
3150fn match_args_from_caller_to_enzyme < ' ll > (
3251 cx : & SimpleCx < ' ll > ,
52+ width : u32 ,
3353 args : & mut Vec < & ' ll llvm:: Value > ,
3454 inputs : & [ DiffActivity ] ,
3555 outer_args : & [ & ' ll llvm:: Value ] ,
56+ has_sret : bool ,
3657) {
3758 debug ! ( "matching autodiff arguments" ) ;
3859 // We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -44,6 +65,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
4465 let mut outer_pos: usize = 0 ;
4566 let mut activity_pos = 0 ;
4667
68+ if has_sret {
69+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
70+ // inner function will still return something. We increase our outer_pos by one,
71+ // and once we're done with all other args we will take the return of the inner call and
72+ // update the sret pointer with it
73+ outer_pos = 1 ;
74+ }
75+
4776 let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
4877 let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
4978 let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -95,20 +124,25 @@ fn match_args_from_caller_to_enzyme<'ll>(
95124 assert ! ( unsafe {
96125 llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
97126 } ) ;
98- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
99- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
100- assert ! ( unsafe {
101- llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
102- } ) ;
103- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
104- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
105- assert ! ( unsafe {
106- llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
107- } ) ;
108- args. push ( next_outer_arg2) ;
127+
128+ for _ in 0 ..width {
129+ let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131+ assert ! (
132+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) }
133+ == llvm:: TypeKind :: Pointer
134+ ) ;
135+ let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
136+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
137+ assert ! (
138+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) }
139+ == llvm:: TypeKind :: Integer
140+ ) ;
141+ args. push ( next_outer_arg2) ;
142+ }
109143 args. push ( cx. get_metadata_value ( enzyme_const) ) ;
110144 args. push ( next_outer_arg) ;
111- outer_pos += 4 ;
145+ outer_pos += 2 + 2 * width as usize ;
112146 activity_pos += 2 ;
113147 } else {
114148 // A duplicated pointer will have the following two outer_fn arguments:
@@ -125,6 +159,13 @@ fn match_args_from_caller_to_enzyme<'ll>(
125159 args. push ( next_outer_arg) ;
126160 outer_pos += 2 ;
127161 activity_pos += 1 ;
162+
163+ // Now, if width > 1, we need to account for that
164+ for _ in 1 ..width {
165+ let next_outer_arg = outer_args[ outer_pos] ;
166+ args. push ( next_outer_arg) ;
167+ outer_pos += 1 ;
168+ }
128169 }
129170 } else {
130171 // We do not differentiate with resprect to this argument.
@@ -135,6 +176,74 @@ fn match_args_from_caller_to_enzyme<'ll>(
135176 }
136177}
137178
179+ // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
180+ // arguments. We do however need to declare them with their correct return type.
181+ // We already figured the correct return type out in our frontend, when generating the outer_fn,
182+ // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
183+ // Beyond sret, this article describes our challenges nicely:
184+ // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
185+ // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
186+ fn compute_enzyme_fn_ty < ' ll > (
187+ cx : & SimpleCx < ' ll > ,
188+ attrs : & AutoDiffAttrs ,
189+ fn_to_diff : & ' ll Value ,
190+ outer_fn : & ' ll Value ,
191+ ) -> & ' ll llvm:: Type {
192+ let fn_ty = cx. get_type_of_global ( outer_fn) ;
193+ let mut ret_ty = cx. get_return_type ( fn_ty) ;
194+
195+ let has_sret = has_sret ( outer_fn) ;
196+
197+ if has_sret {
198+ // Now we don't just forward the return type, so we have to figure it out based on the
199+ // primal return type, in combination with the autodiff settings.
200+ let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
201+ let inner_ret_ty = cx. get_return_type ( fn_ty) ;
202+
203+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
204+ if inner_ret_ty == void_ty {
205+ // This indicates that even the inner function has an sret.
206+ // Right now I only look for an sret in the outer function.
207+ // This *probably* needs some extra handling, but I never ran
208+ // into such a case. So I'll wait for user reports to have a test case.
209+ bug ! ( "sret in inner function" ) ;
210+ }
211+
212+ if attrs. width == 1 {
213+ todo ! ( "Handle sret for scalar ad" ) ;
214+ } else {
215+ // First we check if we also have to deal with the primal return.
216+ if attrs. mode . is_fwd ( ) {
217+ match attrs. ret_activity {
218+ DiffActivity :: Dual => {
219+ let arr_ty =
220+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
221+ ret_ty = arr_ty;
222+ }
223+ DiffActivity :: DualOnly => {
224+ let arr_ty =
225+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
226+ ret_ty = arr_ty;
227+ }
228+ DiffActivity :: Const => {
229+ todo ! ( "Not sure, do we need to do something here?" ) ;
230+ }
231+ _ => {
232+ bug ! ( "unreachable" ) ;
233+ }
234+ }
235+ } else if attrs. mode . is_rev ( ) {
236+ todo ! ( "Handle sret for reverse mode" ) ;
237+ } else {
238+ bug ! ( "unreachable" ) ;
239+ }
240+ }
241+ }
242+
243+ // LLVM can figure out the input types on it's own, so we take a shortcut here.
244+ unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
245+ }
246+
138247/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139248/// function with expected naming and calling conventions[^1] which will be
140249/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +306,9 @@ fn generate_enzyme_call<'ll>(
197306 // }
198307 // ```
199308 unsafe {
200- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201- // arguments. We do however need to declare them with their correct return type.
202- // We already figured the correct return type out in our frontend, when generating the outer_fn,
203- // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204- let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
205- let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
309+ let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
206310
207- // LLVM can figure out the input types on it's own, so we take a shortcut here.
208- let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
209-
210- //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
311+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211312 // think a bit more about what should go here.
212313 let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
213314 let ad_fn = declare_simple_fn (
@@ -240,14 +341,27 @@ fn generate_enzyme_call<'ll>(
240341 if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
241342 args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
242343 }
344+ if attrs. width > 1 {
345+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
346+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
347+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
348+ }
243349
350+ let has_sret = has_sret ( outer_fn) ;
244351 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
245- match_args_from_caller_to_enzyme ( & cx, & mut args, & attrs. input_activity , & outer_args) ;
352+ match_args_from_caller_to_enzyme (
353+ & cx,
354+ attrs. width ,
355+ & mut args,
356+ & attrs. input_activity ,
357+ & outer_args,
358+ has_sret,
359+ ) ;
246360
247361 let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
248362
249363 // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
250- // metadata attachted to it, but we just created this code oota. Given that the
364+ // metadata attached to it, but we just created this code oota. Given that the
251365 // differentiated function already has partly confusing metadata, and given that this
252366 // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
253367 // dummy code which we inserted at a higher level.
@@ -268,7 +382,22 @@ fn generate_enzyme_call<'ll>(
268382 // Now that we copied the metadata, get rid of dummy code.
269383 llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
270384
271- if cx. val_ty ( call) == cx. type_void ( ) {
385+ if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
386+ if has_sret {
387+ // This is what we already have in our outer_fn (shortened):
388+ // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
389+ // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
390+ // <Here we are, we want to add the following two lines>
391+ // store [4 x double] %7, ptr %0, align 8
392+ // ret void
393+ // }
394+
395+ // now store the result of the enzyme call into the sret pointer.
396+ let sret_ptr = outer_args[ 0 ] ;
397+ let call_ty = cx. val_ty ( call) ;
398+ assert ! ( llvm:: LLVMRustIsArrayTy ( call_ty) ) ;
399+ llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
400+ }
272401 builder. ret_void ( ) ;
273402 } else {
274403 builder. ret ( call) ;
@@ -300,8 +429,7 @@ pub(crate) fn differentiate<'ll>(
300429 if !diff_items. is_empty ( )
301430 && !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
302431 {
303- let dcx = cgcx. create_dcx ( ) ;
304- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
432+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
305433 }
306434
307435 // Before dumping the module, we want all the TypeTrees to become part of the module.
0 commit comments