@@ -3,6 +3,8 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6+ use rustc_codegen_ssa:: common:: TypeKind ;
7+ use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
68use rustc_errors:: FatalError ;
79use rustc_middle:: bug;
810use tracing:: { debug, trace} ;
@@ -18,18 +20,18 @@ use crate::value::Value;
1820use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
1921
2022fn get_params ( fnc : & Value ) -> Vec < & Value > {
23+ let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
24+ let mut fnc_args: Vec < & Value > = vec ! [ ] ;
25+ fnc_args. reserve ( param_num) ;
2126 unsafe {
22- let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
23- let mut fnc_args: Vec < & Value > = vec ! [ ] ;
24- fnc_args. reserve ( param_num) ;
2527 llvm:: LLVMGetParams ( fnc, fnc_args. as_mut_ptr ( ) ) ;
2628 fnc_args. set_len ( param_num) ;
27- fnc_args
2829 }
30+ fnc_args
2931}
3032
3133fn has_sret ( fnc : & Value ) -> bool {
32- let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
34+ let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
3335 if num_args == 0 {
3436 false
3537 } else {
@@ -121,23 +123,15 @@ fn match_args_from_caller_to_enzyme<'ll>(
121123 // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
122124 // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
123125 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
124- assert ! ( unsafe {
125- llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
126- } ) ;
126+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Integer ) ;
127127
128128 for _ in 0 ..width {
129129 let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130130 let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131- assert ! (
132- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) }
133- == llvm:: TypeKind :: Pointer
134- ) ;
131+ assert_eq ! ( cx. type_kind( next_outer_ty2) , TypeKind :: Pointer ) ;
135132 let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
136133 let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
137- assert ! (
138- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) }
139- == llvm:: TypeKind :: Integer
140- ) ;
134+ assert_eq ! ( cx. type_kind( next_outer_ty3) , TypeKind :: Integer ) ;
141135 args. push ( next_outer_arg2) ;
142136 }
143137 args. push ( cx. get_metadata_value ( enzyme_const) ) ;
@@ -150,10 +144,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
150144 // (..., metadata! enzyme_dup, ptr, ptr, ...).
151145 if matches ! ( diff_activity, DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly )
152146 {
153- assert ! (
154- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty) }
155- == llvm:: TypeKind :: Pointer
156- ) ;
147+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Pointer ) ;
157148 }
158149 // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
159150 args. push ( next_outer_arg) ;
@@ -213,8 +204,8 @@ fn compute_enzyme_fn_ty<'ll>(
213204 todo ! ( "Handle sret for scalar ad" ) ;
214205 } else {
215206 // First we check if we also have to deal with the primal return.
216- if attrs. mode . is_fwd ( ) {
217- match attrs. ret_activity {
207+ match attrs. mode {
208+ DiffMode :: Forward => match attrs. ret_activity {
218209 DiffActivity :: Dual => {
219210 let arr_ty =
220211 unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
@@ -231,11 +222,13 @@ fn compute_enzyme_fn_ty<'ll>(
231222 _ => {
232223 bug ! ( "unreachable" ) ;
233224 }
225+ } ,
226+ DiffMode :: Reverse => {
227+ todo ! ( "Handle sret for reverse mode" ) ;
228+ }
229+ _ => {
230+ bug ! ( "unreachable" ) ;
234231 }
235- } else if attrs. mode . is_rev ( ) {
236- todo ! ( "Handle sret for reverse mode" ) ;
237- } else {
238- bug ! ( "unreachable" ) ;
239232 }
240233 }
241234 }
@@ -395,7 +388,7 @@ fn generate_enzyme_call<'ll>(
395388 // now store the result of the enzyme call into the sret pointer.
396389 let sret_ptr = outer_args[ 0 ] ;
397390 let call_ty = cx. val_ty ( call) ;
398- assert ! ( llvm :: LLVMRustIsArrayTy ( call_ty) ) ;
391+ assert_eq ! ( cx . type_kind ( call_ty) , TypeKind :: Array ) ;
399392 llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
400393 }
401394 builder. ret_void ( ) ;
0 commit comments