@@ -91,7 +91,7 @@ pub fn expand(
9191 new_decl_span,
9292 idents,
9393 ) ;
94- let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
94+ let d_ident = first_ident ( & meta_item_vec[ 0 ] ) ;
9595
9696 // The first element of it is the name of the function to be generated
9797 let asdf = ItemKind :: Fn ( Box :: new ( ast:: Fn {
@@ -102,11 +102,12 @@ pub fn expand(
102102 } ) ) ;
103103 let mut rustc_ad_attr =
104104 P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
105- let attr: ast:: Attribute = ast:: Attribute {
105+ let mut attr: ast:: Attribute = ast:: Attribute {
106106 kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
107- id : ast:: AttrId :: from_u32 ( 0 ) ,
107+ //id: ast::DUMMY_TR_ID,
108+ id : ast:: AttrId :: from_u32 ( 12341 ) , // TODO: fix
108109 style : ast:: AttrStyle :: Outer ,
109- span : span ,
110+ span,
110111 } ;
111112 orig_item. attrs . push ( attr. clone ( ) ) ;
112113
@@ -116,21 +117,15 @@ pub fn expand(
116117 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
117118 tokens : ts,
118119 } ) ;
119- let attr2: ast:: Attribute = ast:: Attribute {
120- kind : ast:: AttrKind :: Normal ( rustc_ad_attr) ,
121- id : ast:: AttrId :: from_u32 ( 0 ) ,
122- style : ast:: AttrStyle :: Outer ,
123- span : span,
124- } ;
125- let attr_vec: rustc_ast:: AttrVec = thin_vec ! [ attr2] ;
126- let d_fn = ecx. item ( span, d_ident, attr_vec, asdf) ;
120+ attr. kind = ast:: AttrKind :: Normal ( rustc_ad_attr) ;
121+ let d_fn = ecx. item ( span, d_ident, thin_vec ! [ attr] , asdf) ;
127122
128- let orig_annotatable = Annotatable :: Item ( orig_item. clone ( ) ) ;
123+ let orig_annotatable = Annotatable :: Item ( orig_item) ;
129124 let d_annotatable = Annotatable :: Item ( d_fn) ;
130125 return vec ! [ orig_annotatable, d_annotatable] ;
131126}
132127
133- // shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
128+ // shadow arguments in reverse mode must be mutable references or ptrs, because Enzyme will write into them.
134129#[ cfg( llvm_enzyme) ]
135130fn assure_mut_ref ( ty : & ast:: Ty ) -> ast:: Ty {
136131 let mut ty = ty. clone ( ) ;
@@ -165,6 +160,25 @@ fn gen_enzyme_body(
165160) -> P < ast:: Block > {
166161 let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
167162 let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
163+ let noop = ast:: InlineAsm {
164+ template : vec ! [ ast:: InlineAsmTemplatePiece :: String ( "NOP" . to_string( ) ) ] ,
165+ template_strs : Box :: new ( [ ] ) ,
166+ operands : vec ! [ ] ,
167+ clobber_abis : vec ! [ ] ,
168+ options : ast:: InlineAsmOptions :: PURE & ast:: InlineAsmOptions :: NOMEM ,
169+ line_spans : vec ! [ ] ,
170+ } ;
171+ let noop_expr = ecx. expr_asm ( span, P ( noop) ) ;
172+ let unsf = ast:: BlockCheckMode :: Unsafe ( ast:: UnsafeSource :: CompilerGenerated ) ;
173+ let unsf_block = ast:: Block {
174+ stmts : thin_vec ! [ ecx. stmt_semi( noop_expr) ] ,
175+ id : ast:: DUMMY_NODE_ID ,
176+ tokens : None ,
177+ rules : unsf,
178+ span,
179+ could_be_bare_literal : false ,
180+ } ;
181+ let unsf_expr = ecx. expr_block ( P ( unsf_block) ) ;
168182 let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
169183 let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
170184 let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
@@ -185,7 +199,7 @@ fn gen_enzyme_body(
185199 ) ;
186200
187201 let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
188- body. stmts . push ( ecx. stmt_semi ( primal_call ) ) ;
202+ body. stmts . push ( ecx. stmt_semi ( unsf_expr ) ) ;
189203 body. stmts . push ( ecx. stmt_semi ( black_box_primal_call) ) ;
190204 body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
191205 body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
@@ -234,15 +248,18 @@ fn gen_enzyme_decl(
234248 }
235249 DiffActivity :: Duplicated | DiffActivity :: Dual => {
236250 let mut shadow_arg = arg. clone ( ) ;
237- shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
251+ // We += into the shadow in reverse mode.
252+ // Otherwise copy mutability of the original argument.
253+ if activity == & DiffActivity :: Duplicated {
254+ shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
255+ }
238256 // adjust name depending on mode
239257 let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
240258 ident. name
241259 } else {
242260 dbg ! ( & shadow_arg. pat) ;
243261 panic ! ( "not an ident?" ) ;
244262 } ;
245- //old_names.push(old_name.to_string());
246263 let name: String = match x. mode {
247264 DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
248265 DiffMode :: Forward => format ! ( "b{}" , old_name) ,
0 commit comments