@@ -17,7 +17,7 @@ mod llvm_enzyme {
1717 use rustc_ast:: visit:: AssocCtxt :: * ;
1818 use rustc_ast:: {
1919 self as ast, AssocItemKind , BindingMode , ExprKind , FnRetTy , FnSig , Generics , ItemKind ,
20- MetaItemInner , PatKind , QSelf , TyKind ,
20+ MetaItemInner , PatKind , QSelf , TyKind , Visibility ,
2121 } ;
2222 use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
2323 use rustc_span:: { Ident , Span , Symbol , kw, sym} ;
@@ -72,6 +72,16 @@ mod llvm_enzyme {
7272 }
7373 }
7474
75+ // Get information about the function the macro is applied to
76+ fn extract_item_info ( iitem : & P < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident ) > {
77+ match & iitem. kind {
78+ ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
79+ Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
80+ }
81+ _ => None ,
82+ }
83+ }
84+
7585 pub ( crate ) fn from_ast (
7686 ecx : & mut ExtCtxt < ' _ > ,
7787 meta_item : & ThinVec < MetaItemInner > ,
@@ -199,32 +209,26 @@ mod llvm_enzyme {
199209 return vec ! [ item] ;
200210 }
201211 let dcx = ecx. sess . dcx ( ) ;
202- // first get the annotable item:
203- let ( primal, sig, is_impl) : ( Ident , FnSig , bool ) = match & item {
204- Annotatable :: Item ( iitem) => {
205- let ( ident, sig) = match & iitem. kind {
206- ItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
207- _ => {
208- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
209- return vec ! [ item] ;
210- }
211- } ;
212- ( * ident, sig. clone ( ) , false )
213- }
212+
213+ // first get information about the annotable item:
214+ let Some ( ( vis, sig, primal) ) = ( match & item {
215+ Annotatable :: Item ( iitem) => extract_item_info ( iitem) ,
216+ Annotatable :: Stmt ( stmt) => match & stmt. kind {
217+ ast:: StmtKind :: Item ( iitem) => extract_item_info ( iitem) ,
218+ _ => None ,
219+ } ,
214220 Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
215- let ( ident, sig) = match & assoc_item. kind {
216- ast:: AssocItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
217- _ => {
218- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
219- return vec ! [ item] ;
221+ match & assoc_item. kind {
222+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
223+ Some ( ( assoc_item. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
220224 }
221- } ;
222- ( * ident, sig. clone ( ) , true )
223- }
224- _ => {
225- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
226- return vec ! [ item] ;
225+ _ => None ,
226+ }
227227 }
228+ _ => None ,
229+ } ) else {
230+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
231+ return vec ! [ item] ;
228232 } ;
229233
230234 let meta_item_vec: ThinVec < MetaItemInner > = match meta_item. kind {
@@ -238,15 +242,6 @@ mod llvm_enzyme {
238242 let has_ret = has_ret ( & sig. decl . output ) ;
239243 let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
240244
241- let vis = match & item {
242- Annotatable :: Item ( iitem) => iitem. vis . clone ( ) ,
243- Annotatable :: AssocItem ( assoc_item, _) => assoc_item. vis . clone ( ) ,
244- _ => {
245- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
246- return vec ! [ item] ;
247- }
248- } ;
249-
250245 // create TokenStream from vec elemtents:
251246 // meta_item doesn't have a .tokens field
252247 let mut ts: Vec < TokenTree > = vec ! [ ] ;
@@ -379,6 +374,22 @@ mod llvm_enzyme {
379374 }
380375 Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
381376 }
377+ Annotatable :: Stmt ( ref mut stmt) => {
378+ match stmt. kind {
379+ ast:: StmtKind :: Item ( ref mut iitem) => {
380+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
381+ iitem. attrs . push ( attr) ;
382+ }
383+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
384+ {
385+ iitem. attrs . push ( inline_never. clone ( ) ) ;
386+ }
387+ }
388+ _ => unreachable ! ( "stmt kind checked previously" ) ,
389+ } ;
390+
391+ Annotatable :: Stmt ( stmt. clone ( ) )
392+ }
382393 _ => {
383394 unreachable ! ( "annotatable kind checked previously" )
384395 }
@@ -389,22 +400,40 @@ mod llvm_enzyme {
389400 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
390401 tokens : ts,
391402 } ) ;
403+
392404 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
393- let d_annotatable = if is_impl {
394- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
395- let d_fn = P ( ast:: AssocItem {
396- attrs : thin_vec ! [ d_attr, inline_never] ,
397- id : ast:: DUMMY_NODE_ID ,
398- span,
399- vis,
400- kind : assoc_item,
401- tokens : None ,
402- } ) ;
403- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
404- } else {
405- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
406- d_fn. vis = vis;
407- Annotatable :: Item ( d_fn)
405+ let d_annotatable = match & item {
406+ Annotatable :: AssocItem ( _, _) => {
407+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
408+ let d_fn = P ( ast:: AssocItem {
409+ attrs : thin_vec ! [ d_attr, inline_never] ,
410+ id : ast:: DUMMY_NODE_ID ,
411+ span,
412+ vis,
413+ kind : assoc_item,
414+ tokens : None ,
415+ } ) ;
416+ Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
417+ }
418+ Annotatable :: Item ( _) => {
419+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
420+ d_fn. vis = vis;
421+
422+ Annotatable :: Item ( d_fn)
423+ }
424+ Annotatable :: Stmt ( _) => {
425+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
426+ d_fn. vis = vis;
427+
428+ Annotatable :: Stmt ( P ( ast:: Stmt {
429+ id : ast:: DUMMY_NODE_ID ,
430+ kind : ast:: StmtKind :: Item ( d_fn) ,
431+ span,
432+ } ) )
433+ }
434+ _ => {
435+ unreachable ! ( "item kind checked previously" )
436+ }
408437 } ;
409438
410439 return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments