44//use crate::util::check_builtin_macro_attribute;
55//use crate::util::check_autodiff;
66
7- use std:: string:: String ;
87use crate :: errors;
8+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
99use rustc_ast:: ptr:: P ;
10- use rustc_ast:: { BindingAnnotation , ByRef } ;
11- use rustc_ast:: { self as ast, FnHeader , FnSig , Generics , StmtKind , NestedMetaItem , MetaItemKind } ;
12- use rustc_ast:: { Fn , ItemKind , Stmt , TyKind , Unsafe , PatKind } ;
10+ use rustc_ast:: token:: { Token , TokenKind } ;
1311use rustc_ast:: tokenstream:: * ;
12+ use rustc_ast:: { self as ast, FnHeader , FnSig , Generics , MetaItemKind , NestedMetaItem , StmtKind } ;
13+ use rustc_ast:: { BindingAnnotation , ByRef } ;
14+ use rustc_ast:: { Fn , ItemKind , PatKind , Stmt , TyKind , Unsafe } ;
1415use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
1516use rustc_span:: symbol:: { kw, sym, Ident } ;
1617use rustc_span:: Span ;
17- use thin_vec:: { thin_vec, ThinVec } ;
1818use rustc_span:: Symbol ;
19- use rustc_ast :: expand :: autodiff_attrs :: { AutoDiffAttrs , DiffActivity , DiffMode } ;
20- use rustc_ast :: token :: { Token , TokenKind } ;
19+ use std :: string :: String ;
20+ use thin_vec :: { thin_vec , ThinVec } ;
2121
2222fn first_ident ( x : & NestedMetaItem ) -> rustc_span:: symbol:: Ident {
2323 let segments = & x. meta_item ( ) . unwrap ( ) . path . segments ;
@@ -36,9 +36,7 @@ pub fn expand(
3636 let meta_item_vec: ThinVec < NestedMetaItem > = match meta_item. kind {
3737 ast:: MetaItemKind :: List ( ref vec) => vec. clone ( ) ,
3838 _ => {
39- ecx. sess
40- . dcx ( )
41- . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
39+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
4240 return vec ! [ item] ;
4341 }
4442 } ;
@@ -52,18 +50,19 @@ pub fn expand(
5250 {
5351 ( item, sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
5452 } else {
55- ecx. sess
56- . dcx ( )
57- . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
53+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
5854 return vec ! [ item] ;
5955 } ;
6056 // create TokenStream from vec elemtents:
6157 // meta_item doesn't have a .tokens field
62- let ts: Vec < Token > = meta_item_vec. clone ( ) [ 1 ..] . iter ( ) . map ( |x| {
63- let val = first_ident ( x) ;
64- let t = Token :: from_ast_ident ( val) ;
65- t
66- } ) . collect ( ) ;
58+ let ts: Vec < Token > = meta_item_vec. clone ( ) [ 1 ..]
59+ . iter ( )
60+ . map ( |x| {
61+ let val = first_ident ( x) ;
62+ let t = Token :: from_ast_ident ( val) ;
63+ t
64+ } )
65+ . collect ( ) ;
6766 let comma: Token = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
6867 let mut ts: Vec < TokenTree > = vec ! [ ] ;
6968 for t in meta_item_vec. clone ( ) [ 1 ..] . iter ( ) {
@@ -80,7 +79,18 @@ pub fn expand(
8079
8180 let ( d_sig, old_names, new_args, idents) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
8281 let new_decl_span = d_sig. span ;
83- let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span, & sig, & d_sig, idents) ;
82+ let d_body = gen_enzyme_body (
83+ ecx,
84+ primal,
85+ & old_names,
86+ & new_args,
87+ span,
88+ sig_span,
89+ new_decl_span,
90+ & sig,
91+ & d_sig,
92+ idents,
93+ ) ;
8494 let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
8595
8696 // The first element of it is the name of the function to be generated
@@ -90,7 +100,8 @@ pub fn expand(
90100 generics : Generics :: default ( ) ,
91101 body : Some ( d_body) ,
92102 } ) ) ;
93- let mut rustc_ad_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
103+ let mut rustc_ad_attr =
104+ P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
94105 let mut attr: ast:: Attribute = ast:: Attribute {
95106 kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
96107 id : ast:: AttrId :: from_u32 ( 0 ) ,
@@ -136,27 +147,33 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
136147 ty
137148}
138149
139-
140150// The body of our generated functions will consist of three black_Box calls.
141151// The first will call the primal function with the original arguments.
142152// The second will just take the shadow arguments.
143153// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
144154// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
145- fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span , sig : & ast:: FnSig , d_sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Block > {
155+ fn gen_enzyme_body (
156+ ecx : & ExtCtxt < ' _ > ,
157+ primal : Ident ,
158+ old_names : & [ String ] ,
159+ new_names : & [ String ] ,
160+ span : Span ,
161+ sig_span : Span ,
162+ new_decl_span : Span ,
163+ sig : & ast:: FnSig ,
164+ d_sig : & ast:: FnSig ,
165+ idents : Vec < Ident > ,
166+ ) -> P < ast:: Block > {
146167 let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
147168 let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
148169 let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
149170 let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
150171
151-
152172 let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
153173 let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
154174
155- let mem_zeroed_call: Stmt = ecx. stmt_expr ( ecx. expr_call (
156- span,
157- zeroed_call_expr. clone ( ) ,
158- thin_vec ! [ ] ,
159- ) ) ;
175+ let mem_zeroed_call: Stmt =
176+ ecx. stmt_expr ( ecx. expr_call ( span, zeroed_call_expr. clone ( ) , thin_vec ! [ ] ) ) ;
160177 let unsafe_block_with_zeroed_call: P < ast:: Expr > = ecx. expr_block ( P ( ast:: Block {
161178 stmts : thin_vec ! [ mem_zeroed_call] ,
162179 id : ast:: DUMMY_NODE_ID ,
@@ -167,19 +184,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
167184 } ) ) ;
168185 let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
169186 // create ::core::hint::black_box(array(arr));
170- let black_box0 = ecx. expr_call (
171- new_decl_span,
172- blackbox_call_expr. clone ( ) ,
173- thin_vec ! [ primal_call. clone( ) ] ,
174- ) ;
187+ let black_box0 =
188+ ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ primal_call. clone( ) ] ) ;
175189
176190 // create ::core::hint::black_box(grad_arr, tang_y));
177191 let black_box1 = ecx. expr_call (
178192 sig_span,
179193 blackbox_call_expr. clone ( ) ,
180- new_names. iter ( ) . map ( |arg| {
181- ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) )
182- } ) . collect ( ) ,
194+ new_names
195+ . iter ( )
196+ . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) ) )
197+ . collect ( ) ,
183198 ) ;
184199
185200 // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
@@ -189,7 +204,6 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
189204 thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
190205 ) ;
191206
192-
193207 let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
194208 body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
195209 body. stmts . push ( ecx. stmt_semi ( black_box0) ) ;
@@ -199,16 +213,16 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
199213 body
200214}
201215
202- fn gen_primal_call ( ecx : & ExtCtxt < ' _ > , span : Span , primal : Ident , sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Expr > {
216+ fn gen_primal_call (
217+ ecx : & ExtCtxt < ' _ > ,
218+ span : Span ,
219+ primal : Ident ,
220+ sig : & ast:: FnSig ,
221+ idents : Vec < Ident > ,
222+ ) -> P < ast:: Expr > {
203223 let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
204- let args = idents. iter ( ) . map ( |arg| {
205- ecx. expr_path ( ecx. path_ident ( span, * arg) )
206- } ) . collect ( ) ;
207- let primal_call = ecx. expr_call (
208- span,
209- primal_call_expr,
210- args,
211- ) ;
224+ let args = idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
225+ let primal_call = ecx. expr_call ( span, primal_call_expr, args) ;
212226 primal_call
213227}
214228
@@ -218,8 +232,13 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi
218232// zero-initialized by Enzyme). Active arguments are not handled yet.
219233// Each argument of the primal function (and the return type if existing) must be annotated with an
220234// activity.
221- fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
222- -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
235+ fn gen_enzyme_decl (
236+ _ecx : & ExtCtxt < ' _ > ,
237+ sig : & ast:: FnSig ,
238+ x : & AutoDiffAttrs ,
239+ span : Span ,
240+ _sig_span : Span ,
241+ ) -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
223242 assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
224243 assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
225244 let mut d_decl = sig. decl . clone ( ) ;
@@ -253,17 +272,16 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
253272 shadow_arg. pat = P ( ast:: Pat {
254273 // TODO: Check id
255274 id : ast:: DUMMY_NODE_ID ,
256- kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
257- ident,
258- None ,
259- ) ,
275+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
260276 span : shadow_arg. pat . span ,
261277 tokens : shadow_arg. pat . tokens . clone ( ) ,
262278 } ) ;
263279 //idents.push(ident);
264280 d_inputs. push ( shadow_arg) ;
265281 }
266- _ => { dbg ! ( & activity) ; } ,
282+ _ => {
283+ dbg ! ( & activity) ;
284+ }
267285 }
268286 }
269287
@@ -285,10 +303,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
285303 ty : ty. clone ( ) ,
286304 pat : P ( ast:: Pat {
287305 id : ast:: DUMMY_NODE_ID ,
288- kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
289- ident,
290- None ,
291- ) ,
306+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
292307 span : ty. span ,
293308 tokens : None ,
294309 } ) ,
@@ -302,10 +317,6 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
302317 }
303318 }
304319 d_decl. inputs = d_inputs. into ( ) ;
305- let d_sig = FnSig {
306- header : sig. header . clone ( ) ,
307- decl : d_decl,
308- span,
309- } ;
320+ let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
310321 ( d_sig, old_names, new_inputs, idents)
311322}
0 commit comments