11#![ allow( unused_imports) ]
2- #![ allow( unused_variables) ]
3- #![ allow( unused_mut) ]
42//use crate::util::check_builtin_macro_attribute;
53//use crate::util::check_autodiff;
64
@@ -20,12 +18,25 @@ use rustc_span::Symbol;
2018use std:: string:: String ;
2119use thin_vec:: { thin_vec, ThinVec } ;
2220
21+ #[ cfg( llvm_enzyme) ]
2322fn first_ident ( x : & NestedMetaItem ) -> rustc_span:: symbol:: Ident {
2423 let segments = & x. meta_item ( ) . unwrap ( ) . path . segments ;
2524 assert ! ( segments. len( ) == 1 ) ;
2625 segments[ 0 ] . ident
2726}
2827
28+ #[ cfg( not( llvm_enzyme) ) ]
29+ pub fn expand (
30+ ecx : & mut ExtCtxt < ' _ > ,
31+ _expand_span : Span ,
32+ meta_item : & ast:: MetaItem ,
33+ item : Annotatable ,
34+ ) -> Vec < Annotatable > {
35+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
36+ return vec ! [ item] ;
37+ }
38+
39+ #[ cfg( llvm_enzyme) ]
2940pub fn expand (
3041 ecx : & mut ExtCtxt < ' _ > ,
3142 expand_span : Span ,
@@ -45,24 +56,16 @@ pub fn expand(
4556 let primal = orig_item. ident . clone ( ) ;
4657
4758 // Allow using `#[autodiff(...)]` only on a Fn
48- let ( fn_item , has_ret, sig, sig_span) = if let Annotatable :: Item ( item) = & item
59+ let ( has_ret, sig, sig_span) = if let Annotatable :: Item ( item) = & item
4960 && let ItemKind :: Fn ( box ast:: Fn { sig, .. } ) = & item. kind
5061 {
51- ( item , sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
62+ ( sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
5263 } else {
5364 ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
5465 return vec ! [ item] ;
5566 } ;
5667 // create TokenStream from vec elemtents:
5768 // meta_item doesn't have a .tokens field
58- let ts: Vec < Token > = meta_item_vec. clone ( ) [ 1 ..]
59- . iter ( )
60- . map ( |x| {
61- let val = first_ident ( x) ;
62- let t = Token :: from_ast_ident ( val) ;
63- t
64- } )
65- . collect ( ) ;
6669 let comma: Token = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
6770 let mut ts: Vec < TokenTree > = vec ! [ ] ;
6871 for t in meta_item_vec. clone ( ) [ 1 ..] . iter ( ) {
@@ -77,18 +80,15 @@ pub fn expand(
7780 dbg ! ( & x) ;
7881 let span = ecx. with_def_site_ctxt ( expand_span) ;
7982
80- let ( d_sig, old_names , new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
83+ let ( d_sig, new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
8184 let new_decl_span = d_sig. span ;
8285 let d_body = gen_enzyme_body (
8386 ecx,
8487 primal,
85- & old_names,
8688 & new_args,
8789 span,
8890 sig_span,
8991 new_decl_span,
90- & sig,
91- & d_sig,
9292 idents,
9393 ) ;
9494 let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
@@ -102,7 +102,7 @@ pub fn expand(
102102 } ) ) ;
103103 let mut rustc_ad_attr =
104104 P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
105- let mut attr: ast:: Attribute = ast:: Attribute {
105+ let attr: ast:: Attribute = ast:: Attribute {
106106 kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
107107 id : ast:: AttrId :: from_u32 ( 0 ) ,
108108 style : ast:: AttrStyle :: Outer ,
@@ -116,7 +116,7 @@ pub fn expand(
116116 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
117117 tokens : ts,
118118 } ) ;
119- let mut attr2: ast:: Attribute = ast:: Attribute {
119+ let attr2: ast:: Attribute = ast:: Attribute {
120120 kind : ast:: AttrKind :: Normal ( rustc_ad_attr) ,
121121 id : ast:: AttrId :: from_u32 ( 0 ) ,
122122 style : ast:: AttrStyle :: Outer ,
@@ -131,6 +131,7 @@ pub fn expand(
131131}
132132
133133// shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
134+ #[ cfg( llvm_enzyme) ]
134135fn assure_mut_ref ( ty : & ast:: Ty ) -> ast:: Ty {
135136 let mut ty = ty. clone ( ) ;
136137 match ty. kind {
@@ -152,37 +153,21 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
152153// The second will just take a tuple containing the new arguments.
153154// This way we surpress rustc from optimizing any argument away.
154155// The last line will 'loop {}', to match the return type of the new function
156+ #[ cfg( llvm_enzyme) ]
155157fn gen_enzyme_body (
156158 ecx : & ExtCtxt < ' _ > ,
157159 primal : Ident ,
158- old_names : & [ String ] ,
159160 new_names : & [ String ] ,
160161 span : Span ,
161162 sig_span : Span ,
162163 new_decl_span : Span ,
163- sig : & ast:: FnSig ,
164- d_sig : & ast:: FnSig ,
165164 idents : Vec < Ident > ,
166165) -> P < ast:: Block > {
167166 let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
168- let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
169167 let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
170168 let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
171-
172169 let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
173- let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
174-
175- let mem_zeroed_call: Stmt =
176- ecx. stmt_expr ( ecx. expr_call ( span, zeroed_call_expr. clone ( ) , thin_vec ! [ ] ) ) ;
177- let unsafe_block_with_zeroed_call: P < ast:: Expr > = ecx. expr_block ( P ( ast:: Block {
178- stmts : thin_vec ! [ mem_zeroed_call] ,
179- id : ast:: DUMMY_NODE_ID ,
180- rules : ast:: BlockCheckMode :: Unsafe ( ast:: UserProvided ) ,
181- span : sig_span,
182- tokens : None ,
183- could_be_bare_literal : false ,
184- } ) ) ;
185- let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
170+ let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
186171 // create ::core::hint::black_box(array(arr));
187172 let black_box_primal_call =
188173 ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ primal_call. clone( ) ] ) ;
@@ -207,11 +192,11 @@ fn gen_enzyme_body(
207192 body
208193}
209194
195+ #[ cfg( llvm_enzyme) ]
210196fn gen_primal_call (
211197 ecx : & ExtCtxt < ' _ > ,
212198 span : Span ,
213199 primal : Ident ,
214- sig : & ast:: FnSig ,
215200 idents : Vec < Ident > ,
216201) -> P < ast:: Expr > {
217202 let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
@@ -226,17 +211,18 @@ fn gen_primal_call(
226211// zero-initialized by Enzyme). Active arguments are not handled yet.
227212// Each argument of the primal function (and the return type if existing) must be annotated with an
228213// activity.
214+ #[ cfg( llvm_enzyme) ]
229215fn gen_enzyme_decl (
230216 sig : & ast:: FnSig ,
231217 x : & AutoDiffAttrs ,
232218 span : Span ,
233- ) -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
219+ ) -> ( ast:: FnSig , Vec < String > , Vec < Ident > ) {
234220 assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
235221 assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
236222 let mut d_decl = sig. decl . clone ( ) ;
237223 let mut d_inputs = Vec :: new ( ) ;
238224 let mut new_inputs = Vec :: new ( ) ;
239- let mut old_names = Vec :: new ( ) ;
225+ // let mut old_names = Vec::new();
240226 let mut idents = Vec :: new ( ) ;
241227 let mut act_ret = ThinVec :: new ( ) ;
242228 for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
@@ -256,7 +242,7 @@ fn gen_enzyme_decl(
256242 dbg ! ( & shadow_arg. pat) ;
257243 panic ! ( "not an ident?" ) ;
258244 } ;
259- old_names. push ( old_name. to_string ( ) ) ;
245+ // old_names.push(old_name.to_string());
260246 let name: String = match x. mode {
261247 DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
262248 DiffMode :: Forward => format ! ( "b{}" , old_name) ,
@@ -320,7 +306,7 @@ fn gen_enzyme_decl(
320306 // return type. This might require changing the return type to a
321307 // tuple.
322308 if act_ret. len ( ) > 0 {
323- let mut ret_ty = match d_decl. output {
309+ let ret_ty = match d_decl. output {
324310 FnRetTy :: Ty ( ref ty) => {
325311 act_ret. insert ( 0 , ty. clone ( ) ) ;
326312 let kind = TyKind :: Tup ( act_ret) ;
@@ -339,5 +325,5 @@ fn gen_enzyme_decl(
339325 }
340326
341327 let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
342- ( d_sig, old_names , new_inputs, idents)
328+ ( d_sig, new_inputs, idents)
343329}
0 commit comments