11use std:: ptr;
2+
23use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
34use rustc_codegen_ssa:: ModuleCodegen ;
45use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
@@ -14,7 +15,6 @@ use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1415use crate :: llvm:: AttributePlace :: Function ;
1516use crate :: llvm:: { Metadata , True } ;
1617use crate :: value:: Value ;
17-
1818use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
1919
2020fn get_params ( fnc : & Value ) -> Vec < & Value > {
@@ -28,14 +28,14 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828 }
2929}
3030
31- fn has_sret ( fnc : & Value ) -> bool {
32- let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33- if num_args == 0 {
34- false
35- } else {
36- unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
37- }
31+ fn has_sret ( fnc : & Value ) -> bool {
32+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
33+ if num_args == 0 {
34+ false
35+ } else {
36+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
3837 }
38+ }
3939
4040// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
4141// original inputs, as well as metadata and the additional shadow arguments.
@@ -128,17 +128,22 @@ fn match_args_from_caller_to_enzyme<'ll>(
128128 for _ in 0 ..width {
129129 let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130130 let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131- assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) } == llvm:: TypeKind :: Pointer ) ;
131+ assert ! (
132+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) }
133+ == llvm:: TypeKind :: Pointer
134+ ) ;
132135 let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
133136 let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
134- assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) } == llvm:: TypeKind :: Integer ) ;
137+ assert ! (
138+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) }
139+ == llvm:: TypeKind :: Integer
140+ ) ;
135141 args. push ( next_outer_arg2) ;
136142 }
137143 args. push ( cx. get_metadata_value ( enzyme_const) ) ;
138144 args. push ( next_outer_arg) ;
139145 outer_pos += 2 + 2 * width as usize ;
140146 activity_pos += 2 ;
141-
142147 } else {
143148 // A duplicated pointer will have the following two outer_fn arguments:
144149 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
@@ -161,7 +166,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
161166 args. push ( next_outer_arg) ;
162167 outer_pos += 1 ;
163168 }
164-
165169 }
166170 } else {
167171 // We do not differentiate with resprect to this argument.
@@ -172,7 +176,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
172176 }
173177}
174178
175-
176179// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
177180// arguments. We do however need to declare them with their correct return type.
178181// We already figured the correct return type out in our frontend, when generating the outer_fn,
@@ -350,7 +353,14 @@ fn generate_enzyme_call<'ll>(
350353
351354 let has_sret = has_sret ( outer_fn) ;
352355 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
353- match_args_from_caller_to_enzyme ( & cx, attrs. width , & mut args, & attrs. input_activity , & outer_args, has_sret) ;
356+ match_args_from_caller_to_enzyme (
357+ & cx,
358+ attrs. width ,
359+ & mut args,
360+ & attrs. input_activity ,
361+ & outer_args,
362+ has_sret,
363+ ) ;
354364
355365 let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
356366
0 commit comments