@@ -329,17 +329,22 @@ mod llvm_enzyme {
329329 . filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
330330 . count ( ) as u32 ;
331331 let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
332- let d_body = gen_enzyme_body (
332+
333+ // TODO(Sa4dUs): Remove this and all the related logic
334+ let _d_body = gen_enzyme_body (
333335 ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
334336 & generics,
335337 ) ;
336338
339+ let d_body =
340+ call_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
341+
337342 // The first element of it is the name of the function to be generated
338343 let asdf = Box :: new ( ast:: Fn {
339344 defaultness : ast:: Defaultness :: Final ,
340345 sig : d_sig,
341346 ident : first_ident ( & meta_item_vec[ 0 ] ) ,
342- generics,
347+ generics : generics . clone ( ) ,
343348 contract : None ,
344349 body : Some ( d_body) ,
345350 define_opaque : None ,
@@ -428,12 +433,15 @@ mod llvm_enzyme {
428433 tokens : ts,
429434 } ) ;
430435
436+ let vis_clone = vis. clone ( ) ;
437+
438+ let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
431439 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
432440 let d_annotatable = match & item {
433441 Annotatable :: AssocItem ( _, _) => {
434442 let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
435443 let d_fn = Box :: new ( ast:: AssocItem {
436- attrs : thin_vec ! [ d_attr, inline_never ] ,
444+ attrs : thin_vec ! [ d_attr] ,
437445 id : ast:: DUMMY_NODE_ID ,
438446 span,
439447 vis,
@@ -443,13 +451,13 @@ mod llvm_enzyme {
443451 Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
444452 }
445453 Annotatable :: Item ( _) => {
446- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never ] , ItemKind :: Fn ( asdf) ) ;
454+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
447455 d_fn. vis = vis;
448456
449457 Annotatable :: Item ( d_fn)
450458 }
451459 Annotatable :: Stmt ( _) => {
452- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never ] , ItemKind :: Fn ( asdf) ) ;
460+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
453461 d_fn. vis = vis;
454462
455463 Annotatable :: Stmt ( Box :: new ( ast:: Stmt {
@@ -463,7 +471,9 @@ mod llvm_enzyme {
463471 }
464472 } ;
465473
466- return vec ! [ orig_annotatable, d_annotatable] ;
474+ let dummy_const_annotatable = gen_dummy_const ( ecx, span, primal, sig, generics, vis_clone) ;
475+
476+ return vec ! [ orig_annotatable, dummy_const_annotatable, d_annotatable] ;
467477 }
468478
469479 // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -484,6 +494,123 @@ mod llvm_enzyme {
484494 ty
485495 }
486496
497+ // Generate `autodiff` intrinsic call
498+ // ```
499+ // std::intrinsics::autodiff(source, diff, (args))
500+ // ```
501+ fn call_autodiff (
502+ ecx : & ExtCtxt < ' _ > ,
503+ primal : Ident ,
504+ diff : Ident ,
505+ span : Span ,
506+ d_sig : & FnSig ,
507+ ) -> P < ast:: Block > {
508+ let primal_path_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
509+ let diff_path_expr = ecx. expr_path ( ecx. path_ident ( span, diff) ) ;
510+
511+ let tuple_expr = ecx. expr_tuple (
512+ span,
513+ d_sig
514+ . decl
515+ . inputs
516+ . iter ( )
517+ . map ( |arg| match arg. pat . kind {
518+ PatKind :: Ident ( _, ident, _) => ecx. expr_path ( ecx. path_ident ( span, ident) ) ,
519+ _ => todo ! ( ) ,
520+ } )
521+ . collect :: < ThinVec < _ > > ( )
522+ . into ( ) ,
523+ ) ;
524+
525+ let enzyme_path = ecx. path (
526+ span,
527+ vec ! [
528+ Ident :: from_str( "std" ) ,
529+ Ident :: from_str( "intrinsics" ) ,
530+ Ident :: from_str( "autodiff" ) ,
531+ ] ,
532+ ) ;
533+ let call_expr = ecx. expr_call (
534+ span,
535+ ecx. expr_path ( enzyme_path) ,
536+ vec ! [ primal_path_expr, diff_path_expr, tuple_expr] . into ( ) ,
537+ ) ;
538+
539+ let block = ecx. block_expr ( call_expr) ;
540+
541+ block
542+ }
543+
544+ // Generate dummy const to prevent primal function
545+ // from being optimized away before applying enzyme
546+ // ```
547+ // const _: () =
548+ // {
549+ // #[used]
550+ // pub static DUMMY_PTR: fn_type = primal_fn;
551+ // };
552+ // ```
553+ fn gen_dummy_const (
554+ ecx : & ExtCtxt < ' _ > ,
555+ span : Span ,
556+ primal : Ident ,
557+ sig : FnSig ,
558+ generics : Generics ,
559+ vis : Visibility ,
560+ ) -> Annotatable {
561+ // #[used]
562+ let used_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: used) ) ) ;
563+ let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
564+ let used_attr = outer_normal_attr ( & used_attr, new_id, span) ;
565+
566+ // static DUMMY_PTR: <fn_type> = <primal_ident>
567+ let static_ident = Ident :: from_str_and_span ( "DUMMY_PTR" , span) ;
568+ let fn_ptr_ty = ast:: TyKind :: BareFn ( Box :: new ( ast:: BareFnTy {
569+ safety : sig. header . safety ,
570+ ext : sig. header . ext ,
571+ generic_params : generics. params ,
572+ decl : sig. decl ,
573+ decl_span : sig. span ,
574+ } ) ) ;
575+ let static_ty = ecx. ty ( span, fn_ptr_ty) ;
576+
577+ let static_expr = ecx. expr_path ( ecx. path ( span, vec ! [ primal] ) ) ;
578+ let static_item_kind = ast:: ItemKind :: Static ( Box :: new ( ast:: StaticItem {
579+ ident : static_ident,
580+ ty : static_ty,
581+ safety : ast:: Safety :: Default ,
582+ mutability : ast:: Mutability :: Not ,
583+ expr : Some ( static_expr) ,
584+ define_opaque : None ,
585+ } ) ) ;
586+
587+ let static_item = ast:: Item {
588+ attrs : thin_vec ! [ used_attr] ,
589+ id : ast:: DUMMY_NODE_ID ,
590+ span,
591+ vis,
592+ kind : static_item_kind,
593+ tokens : None ,
594+ } ;
595+
596+ let block_expr = ecx. expr_block ( Box :: new ( ast:: Block {
597+ stmts : thin_vec ! [ ecx. stmt_item( span, P ( static_item) ) ] ,
598+ id : ast:: DUMMY_NODE_ID ,
599+ rules : ast:: BlockCheckMode :: Default ,
600+ span,
601+ tokens : None ,
602+ } ) ) ;
603+
604+ let const_item = ecx. item_const (
605+ span,
606+ Ident :: from_str_and_span ( "_" , span) ,
607+ ecx. ty ( span, ast:: TyKind :: Tup ( thin_vec ! [ ] ) ) ,
608+ block_expr,
609+ ) ;
610+
611+ Annotatable :: Item ( const_item)
612+ }
613+
487614 // Will generate a body of the type:
488615 // ```
489616 // {
0 commit comments