@@ -331,20 +331,23 @@ mod llvm_enzyme {
331331 . count ( ) as u32 ;
332332 let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
333333
334- // UNUSED
334+ // TODO(Sa4dUs): Remove this and all the related logic
335335 let _d_body = gen_enzyme_body (
336336 ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
337337 & generics,
338338 ) ;
339339
340+ let d_body =
341+ call_enzyme_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
342+
340343 // The first element of it is the name of the function to be generated
341344 let asdf = Box :: new ( ast:: Fn {
342345 defaultness : ast:: Defaultness :: Final ,
343346 sig : d_sig,
344347 ident : first_ident ( & meta_item_vec[ 0 ] ) ,
345- generics,
348+ generics : generics . clone ( ) ,
346349 contract : None ,
347- body : None , // This leads to an error when the ad function is inside a traits
350+ body : Some ( d_body ) ,
348351 define_opaque : None ,
349352 } ) ;
350353 let mut rustc_ad_attr =
@@ -431,18 +434,15 @@ mod llvm_enzyme {
431434 tokens : ts,
432435 } ) ;
433436
434- let rustc_intrinsic_attr =
435- P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_intrinsic) ) ) ;
436- let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
437- let intrinsic_attr = outer_normal_attr ( & rustc_intrinsic_attr, new_id, span) ;
437+ let vis_clone = vis. clone ( ) ;
438438
439439 let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
440440 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
441441 let d_annotatable = match & item {
442442 Annotatable :: AssocItem ( _, _) => {
443443 let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
444444 let d_fn = P ( ast:: AssocItem {
445- attrs : thin_vec ! [ d_attr, intrinsic_attr ] ,
445+ attrs : thin_vec ! [ d_attr] ,
446446 id : ast:: DUMMY_NODE_ID ,
447447 span,
448448 vis,
@@ -452,15 +452,13 @@ mod llvm_enzyme {
452452 Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
453453 }
454454 Annotatable :: Item ( _) => {
455- let mut d_fn =
456- ecx. item ( span, thin_vec ! [ d_attr, intrinsic_attr] , ItemKind :: Fn ( asdf) ) ;
455+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
457456 d_fn. vis = vis;
458457
459458 Annotatable :: Item ( d_fn)
460459 }
461460 Annotatable :: Stmt ( _) => {
462- let mut d_fn =
463- ecx. item ( span, thin_vec ! [ d_attr, intrinsic_attr] , ItemKind :: Fn ( asdf) ) ;
461+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
464462 d_fn. vis = vis;
465463
466464 Annotatable :: Stmt ( P ( ast:: Stmt {
@@ -474,7 +472,9 @@ mod llvm_enzyme {
474472 }
475473 } ;
476474
477- return vec ! [ orig_annotatable, d_annotatable] ;
475+ let dummy_const_annotatable = gen_dummy_const ( ecx, span, primal, sig, generics, vis_clone) ;
476+
477+ return vec ! [ orig_annotatable, dummy_const_annotatable, d_annotatable] ;
478478 }
479479
480480 // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -495,6 +495,123 @@ mod llvm_enzyme {
495495 ty
496496 }
497497
498+ // Generate `enzyme_autodiff` intrinsic call
499+ // ```
500+ // std::intrinsics::enzyme_autodiff(source, diff, (args))
501+ // ```
502+ fn call_enzyme_autodiff (
503+ ecx : & ExtCtxt < ' _ > ,
504+ primal : Ident ,
505+ diff : Ident ,
506+ span : Span ,
507+ d_sig : & FnSig ,
508+ ) -> P < ast:: Block > {
509+ let primal_path_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
510+ let diff_path_expr = ecx. expr_path ( ecx. path_ident ( span, diff) ) ;
511+
512+ let tuple_expr = ecx. expr_tuple (
513+ span,
514+ d_sig
515+ . decl
516+ . inputs
517+ . iter ( )
518+ . map ( |arg| match arg. pat . kind {
519+ PatKind :: Ident ( _, ident, _) => ecx. expr_path ( ecx. path_ident ( span, ident) ) ,
520+ _ => todo ! ( ) ,
521+ } )
522+ . collect :: < ThinVec < _ > > ( )
523+ . into ( ) ,
524+ ) ;
525+
526+ let enzyme_path = ecx. path (
527+ span,
528+ vec ! [
529+ Ident :: from_str( "std" ) ,
530+ Ident :: from_str( "intrinsics" ) ,
531+ Ident :: from_str( "enzyme_autodiff" ) ,
532+ ] ,
533+ ) ;
534+ let call_expr = ecx. expr_call (
535+ span,
536+ ecx. expr_path ( enzyme_path) ,
537+ vec ! [ primal_path_expr, diff_path_expr, tuple_expr] . into ( ) ,
538+ ) ;
539+
540+ let block = ecx. block_expr ( call_expr) ;
541+
542+ block
543+ }
544+
545+ // Generate dummy const to prevent primal function
546+ // from being optimized away before applying enzyme
547+ // ```
548+ // const _: () =
549+ // {
550+ // #[used]
551+ // pub static DUMMY_PTR: fn_type = primal_fn;
552+ // };
553+ // ```
554+ fn gen_dummy_const (
555+ ecx : & ExtCtxt < ' _ > ,
556+ span : Span ,
557+ primal : Ident ,
558+ sig : FnSig ,
559+ generics : Generics ,
560+ vis : Visibility ,
561+ ) -> Annotatable {
562+ // #[used]
563+ let used_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: used) ) ) ;
564+ let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
565+ let used_attr = outer_normal_attr ( & used_attr, new_id, span) ;
566+
567+ // static DUMMY_PTR: <fn_type> = <primal_ident>
568+ let static_ident = Ident :: from_str_and_span ( "DUMMY_PTR" , span) ;
569+ let fn_ptr_ty = ast:: TyKind :: BareFn ( Box :: new ( ast:: BareFnTy {
570+ safety : sig. header . safety ,
571+ ext : sig. header . ext ,
572+ generic_params : generics. params ,
573+ decl : sig. decl ,
574+ decl_span : sig. span ,
575+ } ) ) ;
576+ let static_ty = ecx. ty ( span, fn_ptr_ty) ;
577+
578+ let static_expr = ecx. expr_path ( ecx. path ( span, vec ! [ primal] ) ) ;
579+ let static_item_kind = ast:: ItemKind :: Static ( Box :: new ( ast:: StaticItem {
580+ ident : static_ident,
581+ ty : static_ty,
582+ safety : ast:: Safety :: Default ,
583+ mutability : ast:: Mutability :: Not ,
584+ expr : Some ( static_expr) ,
585+ define_opaque : None ,
586+ } ) ) ;
587+
588+ let static_item = ast:: Item {
589+ attrs : thin_vec ! [ used_attr] ,
590+ id : ast:: DUMMY_NODE_ID ,
591+ span,
592+ vis,
593+ kind : static_item_kind,
594+ tokens : None ,
595+ } ;
596+
597+ let block_expr = ecx. expr_block ( Box :: new ( ast:: Block {
598+ stmts : thin_vec ! [ ecx. stmt_item( span, P ( static_item) ) ] ,
599+ id : ast:: DUMMY_NODE_ID ,
600+ rules : ast:: BlockCheckMode :: Default ,
601+ span,
602+ tokens : None ,
603+ } ) ) ;
604+
605+ let const_item = ecx. item_const (
606+ span,
607+ Ident :: from_str_and_span ( "_" , span) ,
608+ ecx. ty ( span, ast:: TyKind :: Tup ( thin_vec ! [ ] ) ) ,
609+ block_expr,
610+ ) ;
611+
612+ Annotatable :: Item ( const_item)
613+ }
614+
498615 // Will generate a body of the type:
499616 // ```
500617 // {
0 commit comments