@@ -20,7 +20,7 @@ mod llvm_enzyme {
2020 MetaItemInner , PatKind , Path , PathSegment , TyKind , Visibility ,
2121 } ;
2222 use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
23- use rustc_span:: { Ident , Span , Symbol , kw , sym} ;
23+ use rustc_span:: { Ident , Span , Symbol , sym} ;
2424 use thin_vec:: { ThinVec , thin_vec} ;
2525 use tracing:: { debug, trace} ;
2626
@@ -73,9 +73,7 @@ mod llvm_enzyme {
7373 }
7474
7575 // Get information about the function the macro is applied to
76- fn extract_item_info (
77- iitem : & Box < ast:: Item > ,
78- ) -> Option < ( Visibility , FnSig , Ident , Generics , bool ) > {
76+ fn extract_item_info ( iitem : & Box < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident , Generics ) > {
7977 match & iitem. kind {
8078 ItemKind :: Fn ( box ast:: Fn { sig, ident, generics, .. } ) => {
8179 Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics. clone ( ) , false ) )
@@ -182,11 +180,8 @@ mod llvm_enzyme {
182180 }
183181
184182 /// We expand the autodiff macro to generate a new placeholder function which passes
185- /// type-checking and can be called by users. The function body of the placeholder function will
186- /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
187- /// should just prevent early inlining and optimizations which alter the function signature.
188- /// The exact signature of the generated function depends on the configuration provided by the
189- /// user, but here is an example:
183+ /// type-checking and can be called by users. The exact signature of the generated function
184+ /// depends on the configuration provided by the user, but here is an example:
190185 ///
191186 /// ```
192187 /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -202,14 +197,8 @@ mod llvm_enzyme {
202197 /// f32::sin(**x)
203198 /// }
204199 /// #[rustc_autodiff(Reverse, Duplicated, Active)]
205- /// #[inline(never)]
206200 /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
207- /// unsafe {
208- /// asm!("NOP");
209- /// };
210- /// ::core::hint::black_box(sin(x));
211- /// ::core::hint::black_box((dx, dret));
212- /// ::core::hint::black_box(sin(x))
201+ /// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret))
213202 /// }
214203 /// ```
215204 /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -329,22 +318,20 @@ mod llvm_enzyme {
329318 }
330319 let span = ecx. with_def_site_ctxt ( expand_span) ;
331320
332- let ( d_sig, idents , errored ) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
321+ let d_sig = gen_enzyme_decl ( ecx, & sig, & x, span) ;
333322
334323 let d_body = gen_enzyme_body (
335324 ecx,
336325 & d_sig,
337326 primal,
338327 span,
339- idents,
340- errored,
341328 first_ident ( & meta_item_vec[ 0 ] ) ,
342329 & generics,
343330 impl_of_trait,
344331 ) ;
345332
346333 // The first element of it is the name of the function to be generated
347- let asdf = Box :: new ( ast:: Fn {
334+ let d_fn = Box :: new ( ast:: Fn {
348335 defaultness : ast:: Defaultness :: Final ,
349336 sig : d_sig,
350337 ident : first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -453,13 +440,13 @@ mod llvm_enzyme {
453440 Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
454441 }
455442 Annotatable :: Item ( _) => {
456- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
443+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
457444 d_fn. vis = vis;
458445
459446 Annotatable :: Item ( d_fn)
460447 }
461448 Annotatable :: Stmt ( _) => {
462- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
449+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
463450 d_fn. vis = vis;
464451
465452 Annotatable :: Stmt ( Box :: new ( ast:: Stmt {
@@ -524,14 +511,8 @@ mod llvm_enzyme {
524511 . into ( ) ,
525512 ) ;
526513
527- let enzyme_path = ecx. path (
528- span,
529- vec ! [
530- Ident :: from_str( "std" ) ,
531- Ident :: from_str( "intrinsics" ) ,
532- Ident :: with_dummy_span( sym:: enzyme_autodiff) ,
533- ] ,
534- ) ;
514+ let enzyme_path_idents = ecx. std_path ( & [ sym:: intrinsics, sym:: enzyme_autodiff] ) ;
515+ let enzyme_path = ecx. path ( span, enzyme_path_idents) ;
535516 let call_expr = ecx. expr_call (
536517 span,
537518 ecx. expr_path ( enzyme_path) ,
@@ -549,7 +530,7 @@ mod llvm_enzyme {
549530 generics : & Generics ,
550531 span : Span ,
551532 is_impl : bool ,
552- ) -> P < ast:: Expr > {
533+ ) -> Box < ast:: Expr > {
553534 let generic_args = generics
554535 . params
555536 . iter ( )
@@ -573,7 +554,7 @@ mod llvm_enzyme {
573554 let segment = PathSegment {
574555 ident,
575556 id : ast:: DUMMY_NODE_ID ,
576- args : Some ( P ( GenericArgs :: AngleBracketed ( args) ) ) ,
557+ args : Some ( Box :: new ( GenericArgs :: AngleBracketed ( args) ) ) ,
577558 } ;
578559
579560 let segments = if is_impl {
@@ -590,25 +571,6 @@ mod llvm_enzyme {
590571 ecx. expr_path ( path)
591572 }
592573
593- // Will generate a body of the type:
594- // ```
595- // primal(args);
596- // std::intrinsics::enzyme_autodiff(primal, diff, (args))
597- // }
598- // ```
599- fn init_body_helper (
600- ecx : & ExtCtxt < ' _ > ,
601- span : Span ,
602- primal : Ident ,
603- idents : & [ Ident ] ,
604- _errored : bool ,
605- generics : & Generics ,
606- ) -> Box < ast:: Block > {
607- let _primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
608- let body = ecx. block ( span, ThinVec :: new ( ) ) ;
609- body
610- }
611-
612574 /// We only want this function to type-check, since we will replace the body
613575 /// later on llvm level. Using `loop {}` does not cover all return types anymore,
614576 /// so instead we manually build something that should pass the type checker.
@@ -622,8 +584,6 @@ mod llvm_enzyme {
622584 d_sig : & ast:: FnSig ,
623585 primal : Ident ,
624586 span : Span ,
625- idents : Vec < Ident > ,
626- errored : bool ,
627587 diff_ident : Ident ,
628588 generics : & Generics ,
629589 is_impl : bool ,
@@ -632,87 +592,22 @@ mod llvm_enzyme {
632592
633593 // Add a call to the primal function to prevent it from being inlined
634594 // and call `enzyme_autodiff` intrinsic (this also covers the return type)
635- let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
636-
637- body. stmts . push ( call_enzyme_autodiff (
638- ecx,
639- primal,
640- diff_ident,
641- new_decl_span,
642- d_sig,
643- generics,
644- is_impl,
645- ) ) ;
595+ let body = ecx. block (
596+ span,
597+ thin_vec ! [ call_enzyme_autodiff(
598+ ecx,
599+ primal,
600+ diff_ident,
601+ new_decl_span,
602+ d_sig,
603+ generics,
604+ is_impl,
605+ ) ] ,
606+ ) ;
646607
647608 body
648609 }
649610
650- fn gen_primal_call (
651- ecx : & ExtCtxt < ' _ > ,
652- span : Span ,
653- primal : Ident ,
654- idents : & [ Ident ] ,
655- generics : & Generics ,
656- ) -> Box < ast:: Expr > {
657- let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
658-
659- if has_self {
660- let args: ThinVec < _ > =
661- idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
662- let self_expr = ecx. expr_self ( span) ;
663- ecx. expr_method_call ( span, self_expr, primal, args)
664- } else {
665- let args: ThinVec < _ > =
666- idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
667- let mut primal_path = ecx. path_ident ( span, primal) ;
668-
669- let is_generic = !generics. params . is_empty ( ) ;
670-
671- match ( is_generic, primal_path. segments . last_mut ( ) ) {
672- ( true , Some ( function_path) ) => {
673- let primal_generic_types = generics
674- . params
675- . iter ( )
676- . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
677-
678- let generated_generic_types = primal_generic_types
679- . map ( |type_param| {
680- let generic_param = TyKind :: Path (
681- None ,
682- ast:: Path {
683- span,
684- segments : thin_vec ! [ ast:: PathSegment {
685- ident: type_param. ident,
686- args: None ,
687- id: ast:: DUMMY_NODE_ID ,
688- } ] ,
689- tokens : None ,
690- } ,
691- ) ;
692-
693- ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( Box :: new ( ast:: Ty {
694- id : type_param. id ,
695- span,
696- kind : generic_param,
697- tokens : None ,
698- } ) ) )
699- } )
700- . collect ( ) ;
701-
702- function_path. args =
703- Some ( Box :: new ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
704- span,
705- args : generated_generic_types,
706- } ) ) ) ;
707- }
708- _ => { }
709- }
710-
711- let primal_call_expr = ecx. expr_path ( primal_path) ;
712- ecx. expr_call ( span, primal_call_expr, args)
713- }
714- }
715-
716611 // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
717612 // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
718613 // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
@@ -729,7 +624,7 @@ mod llvm_enzyme {
729624 sig : & ast:: FnSig ,
730625 x : & AutoDiffAttrs ,
731626 span : Span ,
732- ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
627+ ) -> ast:: FnSig {
733628 let dcx = ecx. sess . dcx ( ) ;
734629 let has_ret = has_ret ( & sig. decl . output ) ;
735630 let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -741,7 +636,7 @@ mod llvm_enzyme {
741636 found : num_activities,
742637 } ) ;
743638 // This is not the right signature, but we can continue parsing.
744- return ( sig. clone ( ) , vec ! [ ] , true ) ;
639+ return sig. clone ( ) ;
745640 }
746641 assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
747642 assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -784,7 +679,7 @@ mod llvm_enzyme {
784679
785680 if errors {
786681 // This is not the right signature, but we can continue parsing.
787- return ( sig. clone ( ) , idents , true ) ;
682+ return sig. clone ( ) ;
788683 }
789684
790685 let unsafe_activities = x
@@ -998,7 +893,7 @@ mod llvm_enzyme {
998893 }
999894 let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
1000895 trace ! ( "Generated signature: {:?}" , d_sig) ;
1001- ( d_sig, idents , false )
896+ d_sig
1002897 }
1003898}
1004899
0 commit comments