@@ -3,7 +3,6 @@ use std::cmp::Ordering;
33
44use rustc_abi:: { Align , BackendRepr , ExternAbi , Float , HasDataLayout , Primitive , Size } ;
55use rustc_codegen_ssa:: base:: { compare_simd_types, wants_msvc_seh, wants_wasm_eh} ;
6- use rustc_codegen_ssa:: codegen_attrs:: autodiff_attrs;
76use rustc_codegen_ssa:: common:: { IntPredicate , TypeKind } ;
87use rustc_codegen_ssa:: errors:: { ExpectedPointerMutability , InvalidMonomorphization } ;
98use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
@@ -1165,7 +1164,13 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11651164 Instance :: try_resolve ( tcx, bx. cx . typing_env ( ) , * diff_id, diff_args) . unwrap ( ) . unwrap ( ) ;
11661165 let diff_symbol = symbol_name_for_instance_in_crate ( tcx, fn_diff. clone ( ) , LOCAL_CRATE ) ;
11671166
1168- let diff_attrs = autodiff_attrs ( tcx, fn_diff. def_id ( ) ) ;
1167+ // TODO(Sa4dUs): Store autodiff items in a single pass and just get them here
1168+ // in a O(1) step
1169+ let diff_attrs = tcx
1170+ . collect_and_partition_mono_items ( ( ) )
1171+ . autodiff_items
1172+ . iter ( )
1173+ . find ( |item| item. target == diff_symbol) ;
11691174 let Some ( diff_attrs) = diff_attrs else { bug ! ( "could not find autodiff attrs" ) } ;
11701175
11711176 // Build body
@@ -1176,7 +1181,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11761181 & diff_symbol,
11771182 llret_ty,
11781183 & val_arr,
1179- diff_attrs. clone ( ) ,
1184+ diff_attrs. attrs . clone ( ) ,
11801185 result,
11811186 ) ;
11821187}
@@ -1193,11 +1198,22 @@ fn get_args_from_tuple<'ll, 'tcx>(
11931198 for i in 0 ..tuple_place. layout . layout . 0 . fields . count ( ) {
11941199 let field_place = tuple_place. project_field ( bx, i) ;
11951200 let field_layout = tuple_place. layout . field ( bx, i) ;
1201+ let field_ty = field_layout. ty ;
11961202 let llvm_ty = field_layout. llvm_type ( bx. cx ) ;
11971203
11981204 let field_val = bx. load ( llvm_ty, field_place. val . llval , field_place. val . align ) ;
11991205
1200- ret_arr. push ( field_val)
1206+ match field_ty. kind ( ) {
1207+ ty:: Ref ( _, inner_ty, _) if matches ! ( inner_ty. kind( ) , ty:: Slice ( _) ) => {
1208+ let ptr = bx. extract_value ( field_val, 0 ) ;
1209+ let len = bx. extract_value ( field_val, 1 ) ;
1210+ ret_arr. push ( ptr) ;
1211+ ret_arr. push ( len) ;
1212+ }
1213+ _ => {
1214+ ret_arr. push ( field_val) ;
1215+ }
1216+ }
12011217 }
12021218
12031219 ret_arr
0 commit comments