@@ -305,6 +305,7 @@ mod llvm_enzyme {
305305 let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
306306 let d_body = gen_enzyme_body (
307307 ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
308+ & generics,
308309 ) ;
309310
310311 // The first element of it is the name of the function to be generated
@@ -477,6 +478,7 @@ mod llvm_enzyme {
477478 new_decl_span : Span ,
478479 idents : & [ Ident ] ,
479480 errored : bool ,
481+ generics : & Generics ,
480482 ) -> ( P < ast:: Block > , P < ast:: Expr > , P < ast:: Expr > , P < ast:: Expr > ) {
481483 let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
482484 let noop = ast:: InlineAsm {
@@ -499,7 +501,7 @@ mod llvm_enzyme {
499501 } ;
500502 let unsf_expr = ecx. expr_block ( P ( unsf_block) ) ;
501503 let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
502- let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
504+ let primal_call = gen_primal_call ( ecx, span, primal, idents, generics ) ;
503505 let black_box_primal_call = ecx. expr_call (
504506 new_decl_span,
505507 blackbox_call_expr. clone ( ) ,
@@ -548,6 +550,7 @@ mod llvm_enzyme {
548550 sig_span : Span ,
549551 idents : Vec < Ident > ,
550552 errored : bool ,
553+ generics : & Generics ,
551554 ) -> P < ast:: Block > {
552555 let new_decl_span = d_sig. span ;
553556
@@ -568,6 +571,7 @@ mod llvm_enzyme {
568571 new_decl_span,
569572 & idents,
570573 errored,
574+ generics,
571575 ) ;
572576
573577 if !has_ret ( & d_sig. decl . output ) {
@@ -610,7 +614,6 @@ mod llvm_enzyme {
610614 panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
611615 }
612616 } ;
613-
614617 if x. mode . is_fwd ( ) {
615618 // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
616619 // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -670,8 +673,10 @@ mod llvm_enzyme {
670673 span : Span ,
671674 primal : Ident ,
672675 idents : & [ Ident ] ,
676+ generics : & Generics ,
673677 ) -> P < ast:: Expr > {
674678 let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
679+
675680 if has_self {
676681 let args: ThinVec < _ > =
677682 idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
@@ -680,7 +685,51 @@ mod llvm_enzyme {
680685 } else {
681686 let args: ThinVec < _ > =
682687 idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
683- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
688+ let mut primal_path = ecx. path_ident ( span, primal) ;
689+
690+ let is_generic = !generics. params . is_empty ( ) ;
691+
692+ match ( is_generic, primal_path. segments . last_mut ( ) ) {
693+ ( true , Some ( function_path) ) => {
694+ let primal_generic_types = generics
695+ . params
696+ . iter ( )
697+ . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
698+
699+ let generated_generic_types = primal_generic_types
700+ . map ( |type_param| {
701+ let generic_param = TyKind :: Path (
702+ None ,
703+ ast:: Path {
704+ span,
705+ segments : thin_vec ! [ ast:: PathSegment {
706+ ident: type_param. ident,
707+ args: None ,
708+ id: ast:: DUMMY_NODE_ID ,
709+ } ] ,
710+ tokens : None ,
711+ } ,
712+ ) ;
713+
714+ ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( P ( ast:: Ty {
715+ id : type_param. id ,
716+ span,
717+ kind : generic_param,
718+ tokens : None ,
719+ } ) ) )
720+ } )
721+ . collect ( ) ;
722+
723+ function_path. args =
724+ Some ( P ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
725+ span,
726+ args : generated_generic_types,
727+ } ) ) ) ;
728+ }
729+ _ => { }
730+ }
731+
732+ let primal_call_expr = ecx. expr_path ( primal_path) ;
684733 ecx. expr_call ( span, primal_call_expr, args)
685734 }
686735 }
0 commit comments