55//use crate::util::check_autodiff;
66
77use crate :: errors;
8+ use rustc_ast:: FnRetTy ;
89use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
910use rustc_ast:: ptr:: P ;
1011use rustc_ast:: token:: { Token , TokenKind } ;
@@ -41,7 +42,6 @@ pub fn expand(
4142 }
4243 } ;
4344 let mut orig_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
44- //dbg!(&orig_item.tokens);
4545 let primal = orig_item. ident . clone ( ) ;
4646
4747 // Allow using `#[autodiff(...)]` only on a Fn
@@ -77,7 +77,7 @@ pub fn expand(
7777 dbg ! ( & x) ;
7878 let span = ecx. with_def_site_ctxt ( expand_span) ;
7979
80- let ( d_sig, old_names, new_args, idents) = gen_enzyme_decl ( ecx , & sig, & x, span, sig_span ) ;
80+ let ( d_sig, old_names, new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
8181 let new_decl_span = d_sig. span ;
8282 let d_body = gen_enzyme_body (
8383 ecx,
@@ -147,11 +147,11 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
147147 ty
148148}
149149
150- // The body of our generated functions will consist of three black_Box calls.
150+ // The body of our generated functions will consist of two black_Box calls.
151151// The first will call the primal function with the original arguments.
152- // The second will just take the shadow arguments.
153- // The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
154- // (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
152+ // The second will just take a tuple containing the new arguments.
153+ // This way we surpress rustc from optimizing any argument away.
154+ // The last line will 'loop {}', to match the return type of the new function
155155fn gen_enzyme_body (
156156 ecx : & ExtCtxt < ' _ > ,
157157 primal : Ident ,
@@ -184,31 +184,25 @@ fn gen_enzyme_body(
184184 } ) ) ;
185185 let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
186186 // create ::core::hint::black_box(array(arr));
187- let black_box0 =
187+ let black_box_primal_call =
188188 ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ primal_call. clone( ) ] ) ;
189189
190- // create ::core::hint::black_box(grad_arr, tang_y));
191- let black_box1 = ecx. expr_call (
192- sig_span,
193- blackbox_call_expr. clone ( ) ,
194- new_names
195- . iter ( )
196- . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) ) )
197- . collect ( ) ,
198- ) ;
190+ // create ::core::hint::black_box((grad_arr, tang_y));
191+ let tup_args = new_names
192+ . iter ( )
193+ . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) ) )
194+ . collect ( ) ;
199195
200- // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
201- let black_box2 = ecx. expr_call (
196+ let black_box_remaining_args = ecx. expr_call (
202197 sig_span,
203198 blackbox_call_expr. clone ( ) ,
204- thin_vec ! [ unsafe_block_with_zeroed_call . clone ( ) ] ,
199+ thin_vec ! [ ecx . expr_tuple ( sig_span , tup_args ) ] ,
205200 ) ;
206201
207202 let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
208203 body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
209- body. stmts . push ( ecx. stmt_semi ( black_box0) ) ;
210- body. stmts . push ( ecx. stmt_semi ( black_box1) ) ;
211- //body.stmts.push(ecx.stmt_semi(black_box2));
204+ body. stmts . push ( ecx. stmt_semi ( black_box_primal_call) ) ;
205+ body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
212206 body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
213207 body
214208}
@@ -233,11 +227,9 @@ fn gen_primal_call(
233227// Each argument of the primal function (and the return type if existing) must be annotated with an
234228// activity.
235229fn gen_enzyme_decl (
236- _ecx : & ExtCtxt < ' _ > ,
237230 sig : & ast:: FnSig ,
238231 x : & AutoDiffAttrs ,
239232 span : Span ,
240- _sig_span : Span ,
241233) -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
242234 assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
243235 assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
@@ -246,15 +238,19 @@ fn gen_enzyme_decl(
246238 let mut new_inputs = Vec :: new ( ) ;
247239 let mut old_names = Vec :: new ( ) ;
248240 let mut idents = Vec :: new ( ) ;
241+ let mut act_ret = ThinVec :: new ( ) ;
249242 for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
250243 d_inputs. push ( arg. clone ( ) ) ;
251244 match activity {
245+ DiffActivity :: Active => {
246+ assert ! ( x. mode == DiffMode :: Reverse ) ;
247+ act_ret. push ( arg. ty . clone ( ) ) ;
248+ }
252249 DiffActivity :: Duplicated | DiffActivity :: Dual => {
253250 let mut shadow_arg = arg. clone ( ) ;
254251 shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
255252 // adjust name depending on mode
256- let old_name = if let PatKind :: Ident ( _, ident, _) = shadow_arg. pat . kind {
257- idents. push ( ident. clone ( ) ) ;
253+ let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
258254 ident. name
259255 } else {
260256 dbg ! ( & shadow_arg. pat) ;
@@ -276,47 +272,72 @@ fn gen_enzyme_decl(
276272 span : shadow_arg. pat . span ,
277273 tokens : shadow_arg. pat . tokens . clone ( ) ,
278274 } ) ;
279- //idents.push(ident);
280275 d_inputs. push ( shadow_arg) ;
281276 }
282277 _ => {
283278 dbg ! ( & activity) ;
284279 }
285280 }
281+ if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
282+ idents. push ( ident. clone ( ) ) ;
283+ } else {
284+ panic ! ( "not an ident?" ) ;
285+ }
286286 }
287287
288288 // If we return a scalar in the primal and the scalar is active,
289289 // then add it as last arg to the inputs.
290- if x. mode == DiffMode :: Reverse {
291- match x. ret_activity {
292- DiffActivity :: Active => {
293- let ty = match d_decl. output {
294- rustc_ast:: FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
295- rustc_ast:: FnRetTy :: Default ( span) => {
296- panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
297- }
298- } ;
299- let name = "dret" . to_string ( ) ;
300- let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
301- let shadow_arg = ast:: Param {
302- attrs : ThinVec :: new ( ) ,
303- ty : ty. clone ( ) ,
304- pat : P ( ast:: Pat {
305- id : ast:: DUMMY_NODE_ID ,
306- kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
307- span : ty. span ,
308- tokens : None ,
309- } ) ,
290+ if let DiffMode :: Reverse = x. mode {
291+ if let DiffActivity :: Active = x. ret_activity {
292+ let ty = match d_decl. output {
293+ FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
294+ FnRetTy :: Default ( span) => {
295+ panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
296+ }
297+ } ;
298+ let name = "dret" . to_string ( ) ;
299+ let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
300+ let shadow_arg = ast:: Param {
301+ attrs : ThinVec :: new ( ) ,
302+ ty : ty. clone ( ) ,
303+ pat : P ( ast:: Pat {
310304 id : ast:: DUMMY_NODE_ID ,
305+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
311306 span : ty. span ,
312- is_placeholder : false ,
313- } ;
314- d_inputs. push ( shadow_arg) ;
315- }
316- _ => { }
307+ tokens : None ,
308+ } ) ,
309+ id : ast:: DUMMY_NODE_ID ,
310+ span : ty. span ,
311+ is_placeholder : false ,
312+ } ;
313+ d_inputs. push ( shadow_arg) ;
314+ new_inputs. push ( name) ;
317315 }
318316 }
319317 d_decl. inputs = d_inputs. into ( ) ;
318+
319+ // If we have an active input scalar, add it's gradient to the
320+ // return type. This might require changing the return type to a
321+ // tuple.
322+ if act_ret. len ( ) > 0 {
323+ let mut ret_ty = match d_decl. output {
324+ FnRetTy :: Ty ( ref ty) => {
325+ act_ret. insert ( 0 , ty. clone ( ) ) ;
326+ let kind = TyKind :: Tup ( act_ret) ;
327+ P ( rustc_ast:: Ty { kind, id : ty. id , span : ty. span , tokens : None } )
328+ }
329+ FnRetTy :: Default ( span) => {
330+ if act_ret. len ( ) == 1 {
331+ act_ret[ 0 ] . clone ( )
332+ } else {
333+ let kind = TyKind :: Tup ( act_ret. iter ( ) . map ( |arg| arg. clone ( ) ) . collect ( ) ) ;
334+ P ( rustc_ast:: Ty { kind, id : ast:: DUMMY_NODE_ID , span, tokens : None } )
335+ }
336+ }
337+ } ;
338+ d_decl. output = FnRetTy :: Ty ( ret_ty) ;
339+ }
340+
320341 let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
321342 ( d_sig, old_names, new_inputs, idents)
322343}
0 commit comments