@@ -146,7 +146,7 @@ mod llvm_enzyme {
146146 }
147147 let dcx = ecx. sess . dcx ( ) ;
148148 // first get the annotable item:
149- let ( sig, is_impl ) : ( FnSig , bool ) = match & item {
149+ let sig: FnSig = match & item {
150150 Annotatable :: Item ( iitem) => {
151151 let sig = match & iitem. kind {
152152 ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
@@ -155,7 +155,7 @@ mod llvm_enzyme {
155155 return vec ! [ item] ;
156156 }
157157 } ;
158- ( sig. clone ( ) , false )
158+ sig. clone ( )
159159 }
160160 Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
161161 let sig = match & assoc_item. kind {
@@ -165,7 +165,24 @@ mod llvm_enzyme {
165165 return vec ! [ item] ;
166166 }
167167 } ;
168- ( sig. clone ( ) , true )
168+ sig. clone ( )
169+ }
170+ Annotatable :: Stmt ( stmt) => {
171+ let sig = match & stmt. kind {
172+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
173+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
174+ _ => {
175+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
176+ return vec ! [ item] ;
177+ }
178+ } ,
179+ _ => {
180+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
181+ return vec ! [ item] ;
182+ }
183+ } ;
184+
185+ sig. clone ( )
169186 }
170187 _ => {
171188 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -189,6 +206,10 @@ mod llvm_enzyme {
189206 Annotatable :: AssocItem ( assoc_item, _) => {
190207 ( assoc_item. vis . clone ( ) , assoc_item. ident . clone ( ) )
191208 }
209+ Annotatable :: Stmt ( stmt) => match & stmt. kind {
210+ ast:: StmtKind :: Item ( iitem) => ( iitem. vis . clone ( ) , iitem. ident . clone ( ) ) ,
211+ _ => unreachable ! ( "stmt kind checked previously" ) ,
212+ } ,
192213 _ => {
193214 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
194215 return vec ! [ item] ;
@@ -305,6 +326,22 @@ mod llvm_enzyme {
305326 }
306327 Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
307328 }
329+ Annotatable :: Stmt ( ref mut stmt) => {
330+ match stmt. kind {
331+ ast:: StmtKind :: Item ( ref mut iitem) => {
332+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
333+ iitem. attrs . push ( attr) ;
334+ }
335+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
336+ {
337+ iitem. attrs . push ( inline_never. clone ( ) ) ;
338+ }
339+ }
340+ _ => unreachable ! ( "stmt kind checked previously" ) ,
341+ } ;
342+
343+ Annotatable :: Stmt ( stmt. clone ( ) )
344+ }
308345 _ => {
309346 unreachable ! ( "annotatable kind checked previously" )
310347 }
@@ -315,24 +352,43 @@ mod llvm_enzyme {
315352 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
316353 tokens : ts,
317354 } ) ;
355+
318356 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
319- let d_annotatable = if is_impl {
320- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
321- let d_fn = P ( ast:: AssocItem {
322- attrs : thin_vec ! [ d_attr, inline_never] ,
323- id : ast:: DUMMY_NODE_ID ,
324- span,
325- vis,
326- ident : d_ident,
327- kind : assoc_item,
328- tokens : None ,
329- } ) ;
330- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
331- } else {
332- let mut d_fn =
333- ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
334- d_fn. vis = vis;
335- Annotatable :: Item ( d_fn)
357+ let d_annotatable = match & item {
358+ Annotatable :: AssocItem ( _, _) => {
359+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
360+ let d_fn = P ( ast:: AssocItem {
361+ attrs : thin_vec ! [ d_attr, inline_never] ,
362+ id : ast:: DUMMY_NODE_ID ,
363+ span,
364+ vis,
365+ ident : d_ident,
366+ kind : assoc_item,
367+ tokens : None ,
368+ } ) ;
369+ Annotatable :: AssocItem ( d_fn, Impl )
370+ }
371+ Annotatable :: Item ( _) => {
372+ let mut d_fn =
373+ ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
374+ d_fn. vis = vis;
375+
376+ Annotatable :: Item ( d_fn)
377+ }
378+ Annotatable :: Stmt ( _) => {
379+ let mut d_fn =
380+ ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
381+ d_fn. vis = vis;
382+
383+ Annotatable :: Stmt ( P ( ast:: Stmt {
384+ id : ast:: DUMMY_NODE_ID ,
385+ kind : ast:: StmtKind :: Item ( d_fn) ,
386+ span,
387+ } ) )
388+ }
389+ _ => {
390+ unreachable ! ( "item kind checked previously" )
391+ }
336392 } ;
337393
338394 return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments