@@ -3,7 +3,6 @@ use std::cmp::Ordering;
33
44use rustc_abi:: { Align , BackendRepr , ExternAbi , Float , HasDataLayout , Primitive , Size } ;
55use rustc_codegen_ssa:: base:: { compare_simd_types, wants_msvc_seh, wants_wasm_eh} ;
6- use rustc_codegen_ssa:: codegen_attrs:: autodiff_attrs;
76use rustc_codegen_ssa:: common:: { IntPredicate , TypeKind } ;
87use rustc_codegen_ssa:: errors:: { ExpectedPointerMutability , InvalidMonomorphization } ;
98use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
@@ -1167,7 +1166,13 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11671166 Instance :: try_resolve ( tcx, bx. cx . typing_env ( ) , * diff_id, diff_args) . unwrap ( ) . unwrap ( ) ;
11681167 let diff_symbol = symbol_name_for_instance_in_crate ( tcx, fn_diff. clone ( ) , LOCAL_CRATE ) ;
11691168
1170- let diff_attrs = autodiff_attrs ( tcx, fn_diff. def_id ( ) ) ;
1169+ // TODO(Sa4dUs): Store autodiff items in a single pass and just get them here
1170+ // in a O(1) step
1171+ let diff_attrs = tcx
1172+ . collect_and_partition_mono_items ( ( ) )
1173+ . autodiff_items
1174+ . iter ( )
1175+ . find ( |item| item. target == diff_symbol) ;
11711176 let Some ( diff_attrs) = diff_attrs else { bug ! ( "could not find autodiff attrs" ) } ;
11721177
11731178 // Build body
@@ -1178,7 +1183,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11781183 & diff_symbol,
11791184 llret_ty,
11801185 & val_arr,
1181- diff_attrs. clone ( ) ,
1186+ diff_attrs. attrs . clone ( ) ,
11821187 result,
11831188 ) ;
11841189}
@@ -1195,11 +1200,22 @@ fn get_args_from_tuple<'ll, 'tcx>(
11951200 for i in 0 ..tuple_place. layout . layout . 0 . fields . count ( ) {
11961201 let field_place = tuple_place. project_field ( bx, i) ;
11971202 let field_layout = tuple_place. layout . field ( bx, i) ;
1203+ let field_ty = field_layout. ty ;
11981204 let llvm_ty = field_layout. llvm_type ( bx. cx ) ;
11991205
12001206 let field_val = bx. load ( llvm_ty, field_place. val . llval , field_place. val . align ) ;
12011207
1202- ret_arr. push ( field_val)
1208+ match field_ty. kind ( ) {
1209+ ty:: Ref ( _, inner_ty, _) if matches ! ( inner_ty. kind( ) , ty:: Slice ( _) ) => {
1210+ let ptr = bx. extract_value ( field_val, 0 ) ;
1211+ let len = bx. extract_value ( field_val, 1 ) ;
1212+ ret_arr. push ( ptr) ;
1213+ ret_arr. push ( len) ;
1214+ }
1215+ _ => {
1216+ ret_arr. push ( field_val) ;
1217+ }
1218+ }
12031219 }
12041220
12051221 ret_arr
0 commit comments