@@ -49,9 +49,11 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
4949// using iterators and peek()?
5050fn match_args_from_caller_to_enzyme < ' ll > (
5151 cx : & SimpleCx < ' ll > ,
52+ width : u32 ,
5253 args : & mut Vec < & ' ll llvm:: Value > ,
5354 inputs : & [ DiffActivity ] ,
5455 outer_args : & [ & ' ll llvm:: Value ] ,
56+ has_sret : bool ,
5557) {
5658 debug ! ( "matching autodiff arguments" ) ;
5759 // We now handle the issue that Rust level arguments do not always match the llvm-ir level
@@ -63,6 +65,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
6365 let mut outer_pos: usize = 0 ;
6466 let mut activity_pos = 0 ;
6567
68+ if has_sret {
69+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
70+ // inner function will still return something. We increase our outer_pos by one,
71+ // and once we're done with all other args we will take the return of the inner call and
72+ // update the sret pointer with it
73+ outer_pos = 1 ;
74+ }
75+
6676 let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
6777 let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
6878 let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -114,21 +124,21 @@ fn match_args_from_caller_to_enzyme<'ll>(
114124 assert ! ( unsafe {
115125 llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
116126 } ) ;
117- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
118- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
119- assert ! ( unsafe {
120- llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
121- } ) ;
122- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
123- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
124- assert ! ( unsafe {
125- llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
126- } ) ;
127- args. push ( next_outer_arg2) ;
127+
128+ for _ in 0 ..width {
129+ let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131+ assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) } == llvm:: TypeKind :: Pointer ) ;
132+ let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
133+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
134+ assert ! ( unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) } == llvm:: TypeKind :: Integer ) ;
135+ args. push ( next_outer_arg2) ;
136+ }
128137 args. push ( cx. get_metadata_value ( enzyme_const) ) ;
129138 args. push ( next_outer_arg) ;
130- outer_pos += 4 ;
139+ outer_pos += 2 + 2 * width as usize ;
131140 activity_pos += 2 ;
141+
132142 } else {
133143 // A duplicated pointer will have the following two outer_fn arguments:
134144 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
@@ -144,6 +154,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
144154 args. push ( next_outer_arg) ;
145155 outer_pos += 2 ;
146156 activity_pos += 1 ;
157+
158+ // Now, if width > 1, we need to account for that
159+ for _ in 1 ..width {
160+ let next_outer_arg = outer_args[ outer_pos] ;
161+ args. push ( next_outer_arg) ;
162+ outer_pos += 1 ;
163+ }
164+
147165 }
148166 } else {
149167 // We do not differentiate with respect to this argument.
@@ -324,14 +342,20 @@ fn generate_enzyme_call<'ll>(
324342 if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
325343 args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
326344 }
345+ if attrs. width > 1 {
346+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
347+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
348+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
349+ }
327350
351+ let has_sret = has_sret ( outer_fn) ;
328352 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
329- match_args_from_caller_to_enzyme ( & cx, & mut args, & attrs. input_activity , & outer_args) ;
353+ match_args_from_caller_to_enzyme ( & cx, attrs . width , & mut args, & attrs. input_activity , & outer_args, has_sret ) ;
330354
331355 let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
332356
333357 // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
334- // metadata attachted to it, but we just created this code oota. Given that the
358+ // metadata attached to it, but we just created this code oota. Given that the
335359 // differentiated function already has partly confusing metadata, and given that this
336360 // affects nothing but the autodiff IR, we take a shortcut and just steal metadata from the
337361 // dummy code which we inserted at a higher level.
@@ -352,8 +376,6 @@ fn generate_enzyme_call<'ll>(
352376 // Now that we copied the metadata, get rid of dummy code.
353377 llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
354378
355- let has_sret = has_sret ( outer_fn) ;
356-
357379 if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
358380 if has_sret {
359381 // This is what we already have in our outer_fn (shortened):
0 commit comments