@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
201201 }
202202
203203 if attrs. width == 1 {
204- todo ! ( "Handle sret for scalar ad" ) ;
204+ // Enzyme returns a struct of style:
205+ // `{ original_ret(if requested), float, float, ... }`
206+ let mut struct_elements = vec ! [ ] ;
207+ if attrs. has_primal_ret ( ) {
208+ struct_elements. push ( inner_ret_ty) ;
209+ }
210+ // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
211+ // and therefore part of the return struct.
212+ let param_tys = cx. func_params_types ( fn_ty) ;
213+ for ( act, param_ty) in attrs. input_activity . iter ( ) . zip ( param_tys) {
214+ if matches ! ( act, DiffActivity :: Active ) {
215+ // Now find the float type at position i based on the fn_ty,
216+ // to know what (f16/f32/f64/...) to add to the struct.
217+ struct_elements. push ( param_ty) ;
218+ }
219+ }
220+ ret_ty = cx. type_struct ( & struct_elements, false ) ;
205221 } else {
206222 // First we check if we also have to deal with the primal return.
207223 match attrs. mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
388404 // now store the result of the enzyme call into the sret pointer.
389405 let sret_ptr = outer_args[ 0 ] ;
390406 let call_ty = cx. val_ty ( call) ;
391- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
407+ if attrs. width == 1 {
408+ assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Struct ) ;
409+ } else {
410+ assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
411+ }
392412 llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
393413 }
394414 builder. ret_void ( ) ;
0 commit comments