@@ -145,27 +145,46 @@ mod llvm_enzyme {
145145 return vec ! [ item] ;
146146 }
147147 let dcx = ecx. sess . dcx ( ) ;
148- // first get the annotable item:
149- let ( primal, sig, is_impl) : ( Ident , FnSig , bool ) = match & item {
148+
149+ // first get information about the annotable item:
150+ let ( sig, vis, primal) = match & item {
150151 Annotatable :: Item ( iitem) => {
151- let ( ident , sig ) = match & iitem. kind {
152- ItemKind :: Fn ( box ast:: Fn { ident , sig , .. } ) => ( ident , sig ) ,
152+ let ( sig , ident ) = match & iitem. kind {
153+ ItemKind :: Fn ( box ast:: Fn { sig , ident , .. } ) => ( sig , ident ) ,
153154 _ => {
154155 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
155156 return vec ! [ item] ;
156157 }
157158 } ;
158- ( * ident , sig . clone ( ) , false )
159+ ( sig . clone ( ) , iitem . vis . clone ( ) , ident . clone ( ) )
159160 }
160161 Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
161- let ( ident, sig) = match & assoc_item. kind {
162- ast:: AssocItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
162+ let ( sig, ident) = match & assoc_item. kind {
163+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => ( sig, ident) ,
164+ _ => {
165+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
166+ return vec ! [ item] ;
167+ }
168+ } ;
169+ ( sig. clone ( ) , assoc_item. vis . clone ( ) , ident. clone ( ) )
170+ }
171+ Annotatable :: Stmt ( stmt) => {
172+ let ( sig, vis, ident) = match & stmt. kind {
173+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
174+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
175+ ( sig. clone ( ) , iitem. vis . clone ( ) , ident. clone ( ) )
176+ }
177+ _ => {
178+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
179+ return vec ! [ item] ;
180+ }
181+ } ,
163182 _ => {
164183 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
165184 return vec ! [ item] ;
166185 }
167186 } ;
168- ( * ident , sig . clone ( ) , true )
187+ ( sig , vis , ident )
169188 }
170189 _ => {
171190 dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -184,15 +203,6 @@ mod llvm_enzyme {
184203 let has_ret = has_ret ( & sig. decl . output ) ;
185204 let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
186205
187- let vis = match & item {
188- Annotatable :: Item ( iitem) => iitem. vis . clone ( ) ,
189- Annotatable :: AssocItem ( assoc_item, _) => assoc_item. vis . clone ( ) ,
190- _ => {
191- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
192- return vec ! [ item] ;
193- }
194- } ;
195-
196206 // create TokenStream from vec elemtents:
197207 // meta_item doesn't have a .tokens field
198208 let comma: Token = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
@@ -303,6 +313,22 @@ mod llvm_enzyme {
303313 }
304314 Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
305315 }
316+ Annotatable :: Stmt ( ref mut stmt) => {
317+ match stmt. kind {
318+ ast:: StmtKind :: Item ( ref mut iitem) => {
319+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
320+ iitem. attrs . push ( attr) ;
321+ }
322+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
323+ {
324+ iitem. attrs . push ( inline_never. clone ( ) ) ;
325+ }
326+ }
327+ _ => unreachable ! ( "stmt kind checked previously" ) ,
328+ } ;
329+
330+ Annotatable :: Stmt ( stmt. clone ( ) )
331+ }
306332 _ => {
307333 unreachable ! ( "annotatable kind checked previously" )
308334 }
@@ -313,22 +339,40 @@ mod llvm_enzyme {
313339 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
314340 tokens : ts,
315341 } ) ;
342+
316343 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
317- let d_annotatable = if is_impl {
318- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
319- let d_fn = P ( ast:: AssocItem {
320- attrs : thin_vec ! [ d_attr, inline_never] ,
321- id : ast:: DUMMY_NODE_ID ,
322- span,
323- vis,
324- kind : assoc_item,
325- tokens : None ,
326- } ) ;
327- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
328- } else {
329- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
330- d_fn. vis = vis;
331- Annotatable :: Item ( d_fn)
344+ let d_annotatable = match & item {
345+ Annotatable :: AssocItem ( _, _) => {
346+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
347+ let d_fn = P ( ast:: AssocItem {
348+ attrs : thin_vec ! [ d_attr, inline_never] ,
349+ id : ast:: DUMMY_NODE_ID ,
350+ span,
351+ vis,
352+ kind : assoc_item,
353+ tokens : None ,
354+ } ) ;
355+ Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
356+ }
357+ Annotatable :: Item ( _) => {
358+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
359+ d_fn. vis = vis;
360+
361+ Annotatable :: Item ( d_fn)
362+ }
363+ Annotatable :: Stmt ( _) => {
364+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
365+ d_fn. vis = vis;
366+
367+ Annotatable :: Stmt ( P ( ast:: Stmt {
368+ id : ast:: DUMMY_NODE_ID ,
369+ kind : ast:: StmtKind :: Item ( d_fn) ,
370+ span,
371+ } ) )
372+ }
373+ _ => {
374+ unreachable ! ( "item kind checked previously" )
375+ }
332376 } ;
333377
334378 return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments