@@ -72,20 +72,15 @@ pub fn expand(
7272 ts. push ( TokenTree :: Token ( t, Spacing :: Joint ) ) ;
7373 ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
7474 }
75- dbg ! ( & ts) ;
7675 let ts: TokenStream = TokenStream :: from_iter ( ts) ;
77- dbg ! ( & ts) ;
7876
7977 let x: AutoDiffAttrs = AutoDiffAttrs :: from_ast ( & meta_item_vec, has_ret) ;
8078 dbg ! ( & x) ;
81- //let span = ecx.with_def_site_ctxt(sig_span);
8279 let span = ecx. with_def_site_ctxt ( expand_span) ;
83- //let span = ecx.with_def_site_ctxt(fn_item.span);
8480
85- let ( d_sig, old_names, new_args) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
81+ let ( d_sig, old_names, new_args, idents ) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
8682 let new_decl_span = d_sig. span ;
87- //let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span);
88- let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span) ;
83+ let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span, & sig, & d_sig, idents) ;
8984 let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
9085
9186 // The first element of it is the name of the function to be generated
@@ -147,14 +142,13 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
147142// The second will just take the shadow arguments.
148143// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
149144// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
150- fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span ) -> P < ast:: Block > {
145+ fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span , sig : & ast :: FnSig , d_sig : & ast :: FnSig , idents : Vec < Ident > ) -> P < ast:: Block > {
151146 let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
152147 let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
153148 let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
154149 let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
155150
156151
157- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
158152 let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
159153 let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
160154
@@ -172,18 +166,11 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
172166 could_be_bare_literal : false ,
173167 } ) ) ;
174168 // create ::core::hint::black_box(array(arr));
175- let primal_call = ecx. expr_call (
176- new_decl_span,
177- primal_call_expr,
178- old_names. iter ( ) . map ( |name| {
179- ecx. expr_path ( ecx. path_ident ( new_decl_span, Ident :: from_str ( name) ) )
180- } ) . collect ( ) ,
181- ) ;
182- let black_box0 = ecx. expr_call (
183- new_decl_span,
184- blackbox_call_expr. clone ( ) ,
185- thin_vec ! [ primal_call. clone( ) ] ,
186- ) ;
169+ //let black_box0 = ecx.expr_call(
170+ // new_decl_span,
171+ // blackbox_call_expr.clone(),
172+ // thin_vec![primal_call.clone()],
173+ //);
187174
188175 // create ::core::hint::black_box(grad_arr, tang_y));
189176 let black_box1 = ecx. expr_call (
@@ -201,30 +188,54 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
201188 thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
202189 ) ;
203190
191+ let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
192+
204193 let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
205- // body.stmts.push(ecx.stmt_expr (primal_call));
194+ body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
206195 //body.stmts.push(ecx.stmt_expr(black_box0));
207196 //body.stmts.push(ecx.stmt_expr(black_box1));
208- body. stmts . push ( ecx. stmt_expr ( black_box2) ) ;
197+ // body.stmts.push(ecx.stmt_expr(black_box2));
209198 body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
210199 body
211200}
212201
202+ fn gen_primal_call ( ecx : & ExtCtxt < ' _ > , span : Span , primal : Ident , sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Expr > {
203+ //pub struct Param {
204+ // pub attrs: AttrVec,
205+ // pub ty: P<Ty>,
206+ // pub pat: P<Pat>,
207+ // pub id: NodeId,
208+ // pub span: Span,
209+ // pub is_placeholder: bool,
210+ //}
211+ let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
212+ let args = idents. iter ( ) . map ( |arg| {
213+ ecx. expr_path ( ecx. path_ident ( span, * arg) )
214+ } ) . collect ( ) ;
215+ let primal_call = ecx. expr_call (
216+ span,
217+ primal_call_expr,
218+ args,
219+ ) ;
220+ primal_call
221+ }
222+
213223// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
214224// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
215225// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
216226// zero-initialized by Enzyme). Active arguments are not handled yet.
217227// Each argument of the primal function (and the return type if existing) must be annotated with an
218228// activity.
219229fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
220- -> ( ast:: FnSig , Vec < String > , Vec < String > ) {
230+ -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
221231 let decl: P < ast:: FnDecl > = sig. decl . clone ( ) ;
222232 assert ! ( decl. inputs. len( ) == x. input_activity. len( ) ) ;
223233 assert ! ( decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
224234 let mut d_decl = decl. clone ( ) ;
225235 let mut d_inputs = Vec :: new ( ) ;
226236 let mut new_inputs = Vec :: new ( ) ;
227237 let mut old_names = Vec :: new ( ) ;
238+ let mut idents = Vec :: new ( ) ;
228239 for ( arg, activity) in decl. inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
229240 //dbg!(&arg);
230241 d_inputs. push ( arg. clone ( ) ) ;
@@ -234,6 +245,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
234245 shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
235246 // adjust name depending on mode
236247 let old_name = if let PatKind :: Ident ( _, ident, _) = shadow_arg. pat . kind {
248+ idents. push ( ident. clone ( ) ) ;
237249 ident. name
238250 } else {
239251 dbg ! ( & shadow_arg. pat) ;
@@ -247,17 +259,18 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
247259 } ;
248260 dbg ! ( & name) ;
249261 new_inputs. push ( name. clone ( ) ) ;
262+ let ident = Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ;
250263 shadow_arg. pat = P ( ast:: Pat {
251264 // TODO: Check id
252265 id : ast:: DUMMY_NODE_ID ,
253266 kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
254- Ident :: from_str_and_span ( & name , shadow_arg . pat . span ) ,
267+ ident ,
255268 None ,
256269 ) ,
257270 span : shadow_arg. pat . span ,
258271 tokens : shadow_arg. pat . tokens . clone ( ) ,
259272 } ) ;
260-
273+ //idents.push(ident);
261274 d_inputs. push ( shadow_arg) ;
262275 }
263276 _ => { } ,
@@ -269,5 +282,5 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
269282 decl : d_decl,
270283 span,
271284 } ;
272- ( d_sig, old_names, new_inputs)
285+ ( d_sig, old_names, new_inputs, idents )
273286}
0 commit comments