@@ -90,23 +90,23 @@ pub fn expand(
9090 generics : Generics :: default ( ) ,
9191 body : Some ( d_body) ,
9292 } ) ) ;
93- let mut tmp = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
93+ let mut rustc_ad_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
9494 let mut attr: ast:: Attribute = ast:: Attribute {
95- kind : ast:: AttrKind :: Normal ( tmp . clone ( ) ) ,
95+ kind : ast:: AttrKind :: Normal ( rustc_ad_attr . clone ( ) ) ,
9696 id : ast:: AttrId :: from_u32 ( 0 ) ,
9797 style : ast:: AttrStyle :: Outer ,
9898 span : span,
9999 } ;
100- orig_item. attrs . push ( attr) ;
100+ orig_item. attrs . push ( attr. clone ( ) ) ;
101101
102102 // Now update for d_fn
103- tmp . item . args = rustc_ast:: AttrArgs :: Delimited ( rustc_ast:: DelimArgs {
103+ rustc_ad_attr . item . args = rustc_ast:: AttrArgs :: Delimited ( rustc_ast:: DelimArgs {
104104 dspan : DelimSpan :: dummy ( ) ,
105105 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
106106 tokens : ts,
107107 } ) ;
108108 let mut attr2: ast:: Attribute = ast:: Attribute {
109- kind : ast:: AttrKind :: Normal ( tmp ) ,
109+ kind : ast:: AttrKind :: Normal ( rustc_ad_attr ) ,
110110 id : ast:: AttrId :: from_u32 ( 0 ) ,
111111 style : ast:: AttrStyle :: Outer ,
112112 span : span,
@@ -165,12 +165,13 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
165165 tokens : None ,
166166 could_be_bare_literal : false ,
167167 } ) ) ;
168+ let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
168169 // create ::core::hint::black_box(array(arr));
169- // let black_box0 = ecx.expr_call(
170- // new_decl_span,
171- // blackbox_call_expr.clone(),
172- // thin_vec![primal_call.clone()],
173- // );
170+ let black_box0 = ecx. expr_call (
171+ new_decl_span,
172+ blackbox_call_expr. clone ( ) ,
173+ thin_vec ! [ primal_call. clone( ) ] ,
174+ ) ;
174175
175176 // create ::core::hint::black_box(grad_arr, tang_y));
176177 let black_box1 = ecx. expr_call (
@@ -188,26 +189,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
188189 thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
189190 ) ;
190191
191- let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
192192
193193 let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
194194 body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
195- // body.stmts.push(ecx.stmt_expr (black_box0));
196- // body.stmts.push(ecx.stmt_expr (black_box1));
197- //body.stmts.push(ecx.stmt_expr (black_box2));
195+ body. stmts . push ( ecx. stmt_semi ( black_box0) ) ;
196+ body. stmts . push ( ecx. stmt_semi ( black_box1) ) ;
197+ //body.stmts.push(ecx.stmt_semi (black_box2));
198198 body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
199199 body
200200}
201201
202202fn gen_primal_call ( ecx : & ExtCtxt < ' _ > , span : Span , primal : Ident , sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Expr > {
203- //pub struct Param {
204- // pub attrs: AttrVec,
205- // pub ty: P<Ty>,
206- // pub pat: P<Pat>,
207- // pub id: NodeId,
208- // pub span: Span,
209- // pub is_placeholder: bool,
210- //}
211203 let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
212204 let args = idents. iter ( ) . map ( |arg| {
213205 ecx. expr_path ( ecx. path_ident ( span, * arg) )
@@ -228,16 +220,14 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi
228220// activity.
229221fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
230222 -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
231- let decl: P < ast:: FnDecl > = sig. decl . clone ( ) ;
232- assert ! ( decl. inputs. len( ) == x. input_activity. len( ) ) ;
233- assert ! ( decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
234- let mut d_decl = decl. clone ( ) ;
223+ assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
224+ assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
225+ let mut d_decl = sig. decl . clone ( ) ;
235226 let mut d_inputs = Vec :: new ( ) ;
236227 let mut new_inputs = Vec :: new ( ) ;
237228 let mut old_names = Vec :: new ( ) ;
238229 let mut idents = Vec :: new ( ) ;
239- for ( arg, activity) in decl. inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
240- //dbg!(&arg);
230+ for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
241231 d_inputs. push ( arg. clone ( ) ) ;
242232 match activity {
243233 DiffActivity :: Duplicated | DiffActivity :: Dual => {
@@ -273,7 +263,42 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
273263 //idents.push(ident);
274264 d_inputs. push ( shadow_arg) ;
275265 }
276- _ => { } ,
266+ _ => { dbg ! ( & activity) ; } ,
267+ }
268+ }
269+
270+ // If we return a scalar in the primal and the scalar is active,
271+ // then add it as last arg to the inputs.
272+ if x. mode == DiffMode :: Reverse {
273+ match x. ret_activity {
274+ DiffActivity :: Active => {
275+ let ty = match d_decl. output {
276+ rustc_ast:: FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
277+ rustc_ast:: FnRetTy :: Default ( span) => {
278+ panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
279+ }
280+ } ;
281+ let name = "dret" . to_string ( ) ;
282+ let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
283+ let shadow_arg = ast:: Param {
284+ attrs : ThinVec :: new ( ) ,
285+ ty : ty. clone ( ) ,
286+ pat : P ( ast:: Pat {
287+ id : ast:: DUMMY_NODE_ID ,
288+ kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
289+ ident,
290+ None ,
291+ ) ,
292+ span : ty. span ,
293+ tokens : None ,
294+ } ) ,
295+ id : ast:: DUMMY_NODE_ID ,
296+ span : ty. span ,
297+ is_placeholder : false ,
298+ } ;
299+ d_inputs. push ( shadow_arg) ;
300+ }
301+ _ => { }
277302 }
278303 }
279304 d_decl. inputs = d_inputs. into ( ) ;
0 commit comments