@@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
164164 let mut activity_pos = 0 ;
165165 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
166166 while activity_pos < inputs. len ( ) {
167- let activity = inputs[ activity_pos as usize ] ;
167+ let diff_activity = inputs[ activity_pos as usize ] ;
168168 // Duplicated arguments received a shadow argument, into which enzyme will write the
169169 // gradient.
170- let ( activity, duplicated) : ( & Metadata , bool ) = match activity {
170+ let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
171171 DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
172172 DiffActivity :: Const => ( enzyme_const, false ) ,
173173 DiffActivity :: Active => ( enzyme_out, false ) ,
@@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
222222 // A duplicated pointer will have the following two outer_fn arguments:
223223 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
224224 // (..., metadata! enzyme_dup, ptr, ptr, ...).
225- assert ! ( llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer ) ;
225+ if matches ! ( diff_activity, DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly ) {
226+ assert ! (
227+ llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer
228+ ) ;
229+ }
230+ // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
226231 args. push ( next_outer_arg) ;
227232 outer_pos += 2 ;
228233 activity_pos += 1 ;
0 commit comments