@@ -16,8 +16,9 @@ mod llvm_enzyme {
1616 use rustc_ast:: tokenstream:: * ;
1717 use rustc_ast:: visit:: AssocCtxt :: * ;
1818 use rustc_ast:: {
19- self as ast, AssocItemKind , BindingMode , ExprKind , FnRetTy , FnSig , Generics , ItemKind ,
20- MetaItemInner , PatKind , QSelf , TyKind , Visibility ,
19+ self as ast, AngleBracketedArg , AngleBracketedArgs , AssocItemKind , BindingMode , ExprKind ,
20+ FnRetTy , FnSig , GenericArg , GenericArgs , Generics , ItemKind , MetaItemInner , PatKind , Path ,
21+ PathSegment , QSelf , TyKind , Visibility ,
2122 } ;
2223 use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
2324 use rustc_span:: { Ident , Span , Symbol , kw, sym} ;
@@ -337,8 +338,14 @@ mod llvm_enzyme {
337338 & generics,
338339 ) ;
339340
340- let d_body =
341- call_enzyme_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
341+ let d_body = call_enzyme_autodiff (
342+ ecx,
343+ primal,
344+ first_ident ( & meta_item_vec[ 0 ] ) ,
345+ span,
346+ & d_sig,
347+ & generics,
348+ ) ;
342349
343350 // The first element of it is the name of the function to be generated
344351 let asdf = Box :: new ( ast:: Fn {
@@ -505,9 +512,10 @@ mod llvm_enzyme {
505512 diff : Ident ,
506513 span : Span ,
507514 d_sig : & FnSig ,
515+ generics : & Generics ,
508516 ) -> P < ast:: Block > {
509- let primal_path_expr = ecx . expr_path ( ecx. path_ident ( span , primal) ) ;
510- let diff_path_expr = ecx . expr_path ( ecx. path_ident ( span , diff) ) ;
517+ let primal_path_expr = gen_turbofish_expr ( ecx, primal, generics , span ) ;
518+ let diff_path_expr = gen_turbofish_expr ( ecx, diff, generics , span ) ;
511519
512520 let tuple_expr = ecx. expr_tuple (
513521 span,
@@ -542,6 +550,37 @@ mod llvm_enzyme {
542550 block
543551 }
544552
553+ // Generate turbofish expression from fn name and generics
554+ // Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
555+ fn gen_turbofish_expr (
556+ ecx : & ExtCtxt < ' _ > ,
557+ ident : Ident ,
558+ generics : & Generics ,
559+ span : Span ,
560+ ) -> P < ast:: Expr > {
561+ let generic_args = generics
562+ . params
563+ . iter ( )
564+ . map ( |p| {
565+ let path = ast:: Path :: from_ident ( p. ident ) ;
566+ let ty = ecx. ty_path ( path) ;
567+ AngleBracketedArg :: Arg ( GenericArg :: Type ( ty) )
568+ } )
569+ . collect :: < ThinVec < _ > > ( ) ;
570+
571+ let args = AngleBracketedArgs { span, args : generic_args } ;
572+
573+ let segment = PathSegment {
574+ ident,
575+ id : ast:: DUMMY_NODE_ID ,
576+ args : Some ( P ( GenericArgs :: AngleBracketed ( args) ) ) ,
577+ } ;
578+
579+ let path = Path { span, segments : thin_vec ! [ segment] , tokens : None } ;
580+
581+ ecx. expr_path ( path)
582+ }
583+
545584 // Generate dummy const to prevent primal function
546585 // from being optimized away before applying enzyme
547586 // ```
0 commit comments