@@ -199,27 +199,46 @@ mod llvm_enzyme {
199199 return vec ! [ item] ;
200200 }
201201 let dcx = ecx. sess . dcx ( ) ;
202- // first get the annotable item:
203- let ( primal, sig, is_impl) : ( Ident , FnSig , bool ) = match & item {
202+
203+ // first get information about the annotable item:
204+ let ( sig, vis, primal) = match & item {
204205 Annotatable :: Item ( iitem) => {
205- let ( ident , sig ) = match & iitem. kind {
206- ItemKind :: Fn ( box ast:: Fn { ident , sig , .. } ) => ( ident , sig ) ,
206+ let ( sig , ident ) = match & iitem. kind {
207+ ItemKind :: Fn ( box ast:: Fn { sig , ident , .. } ) => ( sig , ident ) ,
207208 _ => {
208209 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
209210 return vec ! [ item] ;
210211 }
211212 } ;
212- ( * ident , sig . clone ( ) , false )
213+ ( sig . clone ( ) , iitem . vis . clone ( ) , ident . clone ( ) )
213214 }
214215 Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
215- let ( ident, sig) = match & assoc_item. kind {
216- ast:: AssocItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
216+ let ( sig, ident) = match & assoc_item. kind {
217+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => ( sig, ident) ,
218+ _ => {
219+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
220+ return vec ! [ item] ;
221+ }
222+ } ;
223+ ( sig. clone ( ) , assoc_item. vis . clone ( ) , ident. clone ( ) )
224+ }
225+ Annotatable :: Stmt ( stmt) => {
226+ let ( sig, vis, ident) = match & stmt. kind {
227+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
228+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
229+ ( sig. clone ( ) , iitem. vis . clone ( ) , ident. clone ( ) )
230+ }
231+ _ => {
232+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
233+ return vec ! [ item] ;
234+ }
235+ } ,
217236 _ => {
218237 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
219238 return vec ! [ item] ;
220239 }
221240 } ;
222- ( * ident , sig . clone ( ) , true )
241+ ( sig , vis , ident )
223242 }
224243 _ => {
225244 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -238,15 +257,6 @@ mod llvm_enzyme {
238257 let has_ret = has_ret ( & sig. decl . output ) ;
239258 let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
240259
241- let vis = match & item {
242- Annotatable :: Item ( iitem) => iitem. vis . clone ( ) ,
243- Annotatable :: AssocItem ( assoc_item, _) => assoc_item. vis . clone ( ) ,
244- _ => {
245- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
246- return vec ! [ item] ;
247- }
248- } ;
249-
250260 // create TokenStream from vec elemtents:
251261 // meta_item doesn't have a .tokens field
252262 let mut ts: Vec < TokenTree > = vec ! [ ] ;
@@ -379,6 +389,22 @@ mod llvm_enzyme {
379389 }
380390 Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
381391 }
392+ Annotatable :: Stmt ( ref mut stmt) => {
393+ match stmt. kind {
394+ ast:: StmtKind :: Item ( ref mut iitem) => {
395+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
396+ iitem. attrs . push ( attr) ;
397+ }
398+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
399+ {
400+ iitem. attrs . push ( inline_never. clone ( ) ) ;
401+ }
402+ }
403+ _ => unreachable ! ( "stmt kind checked previously" ) ,
404+ } ;
405+
406+ Annotatable :: Stmt ( stmt. clone ( ) )
407+ }
382408 _ => {
383409 unreachable ! ( "annotatable kind checked previously" )
384410 }
@@ -389,22 +415,40 @@ mod llvm_enzyme {
389415 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
390416 tokens : ts,
391417 } ) ;
418+
392419 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
393- let d_annotatable = if is_impl {
394- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
395- let d_fn = P ( ast:: AssocItem {
396- attrs : thin_vec ! [ d_attr, inline_never] ,
397- id : ast:: DUMMY_NODE_ID ,
398- span,
399- vis,
400- kind : assoc_item,
401- tokens : None ,
402- } ) ;
403- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
404- } else {
405- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
406- d_fn. vis = vis;
407- Annotatable :: Item ( d_fn)
420+ let d_annotatable = match & item {
421+ Annotatable :: AssocItem ( _, _) => {
422+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
423+ let d_fn = P ( ast:: AssocItem {
424+ attrs : thin_vec ! [ d_attr, inline_never] ,
425+ id : ast:: DUMMY_NODE_ID ,
426+ span,
427+ vis,
428+ kind : assoc_item,
429+ tokens : None ,
430+ } ) ;
431+ Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
432+ }
433+ Annotatable :: Item ( _) => {
434+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
435+ d_fn. vis = vis;
436+
437+ Annotatable :: Item ( d_fn)
438+ }
439+ Annotatable :: Stmt ( _) => {
440+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
441+ d_fn. vis = vis;
442+
443+ Annotatable :: Stmt ( P ( ast:: Stmt {
444+ id : ast:: DUMMY_NODE_ID ,
445+ kind : ast:: StmtKind :: Item ( d_fn) ,
446+ span,
447+ } ) )
448+ }
449+ _ => {
450+ unreachable ! ( "item kind checked previously" )
451+ }
408452 } ;
409453
410454 return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments