@@ -262,7 +262,6 @@ mod llvm_enzyme {
262262 } ;
263263
264264 let has_ret = has_ret ( & sig. decl . output ) ;
265- let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
266265
267266 // create TokenStream from vec elemtents:
268267 // meta_item doesn't have a .tokens field
@@ -331,24 +330,13 @@ mod llvm_enzyme {
331330 }
332331 let span = ecx. with_def_site_ctxt ( expand_span) ;
333332
334- let n_active: u32 = x
335- . input_activity
336- . iter ( )
337- . filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
338- . count ( ) as u32 ;
339- let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
333+ let ( d_sig, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
340334
341- // TODO(Sa4dUs): Remove this and all the related logic
342335 let d_body = gen_enzyme_body (
343336 ecx,
344- & x,
345- n_active,
346- & sig,
347337 & d_sig,
348338 primal,
349- & new_args,
350339 span,
351- sig_span,
352340 idents,
353341 errored,
354342 first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -361,7 +349,7 @@ mod llvm_enzyme {
361349 defaultness : ast:: Defaultness :: Final ,
362350 sig : d_sig,
363351 ident : first_ident ( & meta_item_vec[ 0 ] ) ,
364- generics : generics . clone ( ) ,
352+ generics,
365353 contract : None ,
366354 body : Some ( d_body) ,
367355 define_opaque : None ,
@@ -542,7 +530,7 @@ mod llvm_enzyme {
542530 vec ! [
543531 Ident :: from_str( "std" ) ,
544532 Ident :: from_str( "intrinsics" ) ,
545- Ident :: from_str ( " enzyme_autodiff" ) ,
533+ Ident :: with_dummy_span ( sym :: enzyme_autodiff) ,
546534 ] ,
547535 ) ;
548536 let call_expr = ecx. expr_call (
@@ -555,7 +543,7 @@ mod llvm_enzyme {
555543 }
556544
557545 // Generate turbofish expression from fn name and generics
558- // Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
546+ // Given `foo` and `<A, B, C>` params , gen `foo::<A, B, C>`
559547 fn gen_turbofish_expr (
560548 ecx : & ExtCtxt < ' _ > ,
561549 ident : Ident ,
@@ -597,43 +585,27 @@ mod llvm_enzyme {
597585
598586 // Will generate a body of the type:
599587 // ```
600- // {
601- // unsafe {
602- // asm!("NOP");
603- // }
604- // ::core::hint::black_box(primal(args));
605- // ::core::hint::black_box((args, ret));
606- // <This part remains to be done by following function>
588+ // primal(args);
589+ // std::intrinsics::enzyme_autodiff(primal, diff, (args))
607590 // }
608591 // ```
609592 fn init_body_helper (
610593 ecx : & ExtCtxt < ' _ > ,
611594 span : Span ,
612595 primal : Ident ,
613- _new_names : & [ String ] ,
614- _sig_span : Span ,
615- new_decl_span : Span ,
616596 idents : & [ Ident ] ,
617597 errored : bool ,
618598 generics : & Generics ,
619- ) -> ( P < ast:: Block > , P < ast:: Expr > , P < ast:: Expr > , P < ast:: Expr > ) {
620- let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
621- let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
599+ ) -> P < ast:: Block > {
622600 let primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
623- let black_box_primal_call = ecx. expr_call (
624- new_decl_span,
625- blackbox_call_expr. clone ( ) ,
626- thin_vec ! [ primal_call. clone( ) ] ,
627- ) ;
628-
629601 let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
630602
631603 // This uses primal args which won't be available if we errored before
632604 if !errored {
633605 body. stmts . push ( ecx. stmt_semi ( primal_call. clone ( ) ) ) ;
634606 }
635607
636- ( body, primal_call , black_box_primal_call , blackbox_call_expr )
608+ body
637609 }
638610
639611 /// We only want this function to type-check, since we will replace the body
@@ -646,14 +618,9 @@ mod llvm_enzyme {
646618 /// from optimizing any arguments away.
647619 fn gen_enzyme_body (
648620 ecx : & ExtCtxt < ' _ > ,
649- _x : & AutoDiffAttrs ,
650- _n_active : u32 ,
651- _sig : & ast:: FnSig ,
652621 d_sig : & ast:: FnSig ,
653622 primal : Ident ,
654- new_names : & [ String ] ,
655623 span : Span ,
656- sig_span : Span ,
657624 idents : Vec < Ident > ,
658625 errored : bool ,
659626 diff_ident : Ident ,
@@ -664,17 +631,7 @@ mod llvm_enzyme {
664631
665632 // Add a call to the primal function to prevent it from being inlined
666633 // and call `enzyme_autodiff` intrinsic (this also covers the return type)
667- let ( mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper (
668- ecx,
669- span,
670- primal,
671- new_names,
672- sig_span,
673- new_decl_span,
674- & idents,
675- errored,
676- generics,
677- ) ;
634+ let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
678635
679636 body. stmts . push ( call_enzyme_autodiff (
680637 ecx,
@@ -771,7 +728,7 @@ mod llvm_enzyme {
771728 sig : & ast:: FnSig ,
772729 x : & AutoDiffAttrs ,
773730 span : Span ,
774- ) -> ( ast:: FnSig , Vec < String > , Vec < Ident > , bool ) {
731+ ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
775732 let dcx = ecx. sess . dcx ( ) ;
776733 let has_ret = has_ret ( & sig. decl . output ) ;
777734 let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -783,7 +740,7 @@ mod llvm_enzyme {
783740 found : num_activities,
784741 } ) ;
785742 // This is not the right signature, but we can continue parsing.
786- return ( sig. clone ( ) , vec ! [ ] , vec ! [ ] , true ) ;
743+ return ( sig. clone ( ) , vec ! [ ] , true ) ;
787744 }
788745 assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
789746 assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -826,7 +783,7 @@ mod llvm_enzyme {
826783
827784 if errors {
828785 // This is not the right signature, but we can continue parsing.
829- return ( sig. clone ( ) , new_inputs , idents, true ) ;
786+ return ( sig. clone ( ) , idents, true ) ;
830787 }
831788
832789 let unsafe_activities = x
@@ -1034,7 +991,7 @@ mod llvm_enzyme {
1034991 }
1035992 let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
1036993 trace ! ( "Generated signature: {:?}" , d_sig) ;
1037- ( d_sig, new_inputs , idents, false )
994+ ( d_sig, idents, false )
1038995 }
1039996}
1040997
0 commit comments