33//use crate::util::check_autodiff;
44
55use crate :: errors;
6- use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
6+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode , valid_input_activity } ;
77use rustc_ast:: ptr:: P ;
88use rustc_ast:: token:: { Token , TokenKind } ;
99use rustc_ast:: tokenstream:: * ;
@@ -80,7 +80,7 @@ pub fn expand(
8080 dbg ! ( & x) ;
8181 let span = ecx. with_def_site_ctxt ( expand_span) ;
8282
83- let ( d_sig, new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
83+ let ( d_sig, new_args, idents) = gen_enzyme_decl ( ecx , & sig, & x, span) ;
8484 let new_decl_span = d_sig. span ;
8585 let d_body = gen_enzyme_body (
8686 ecx,
@@ -175,6 +175,26 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
175175 ty
176176}
177177
178+ // TODO We should make this more robust to also
179+ // accept aliases of f32 and f64
180+ #[ cfg( llvm_enzyme) ]
181+ fn is_float ( ty : & ast:: Ty ) -> bool {
182+ match ty. kind {
183+ TyKind :: Path ( _, ref path) => {
184+ let last = path. segments . last ( ) . unwrap ( ) ;
185+ last. ident . name == sym:: f32 || last. ident . name == sym:: f64
186+ }
187+ _ => false ,
188+ }
189+ }
190+ #[ cfg( llvm_enzyme) ]
191+ fn is_ptr_or_ref ( ty : & ast:: Ty ) -> bool {
192+ match ty. kind {
193+ TyKind :: Ptr ( _) | TyKind :: Ref ( _, _) => true ,
194+ _ => false ,
195+ }
196+ }
197+
178198// The body of our generated functions will consist of two black_Box calls.
179199// The first will call the primal function with the original arguments.
180200// The second will just take a tuple containing the new arguments.
@@ -259,6 +279,7 @@ fn gen_primal_call(
259279// activity.
260280#[ cfg( llvm_enzyme) ]
261281fn gen_enzyme_decl (
282+ ecx : & ExtCtxt < ' _ > ,
262283 sig : & ast:: FnSig ,
263284 x : & AutoDiffAttrs ,
264285 span : Span ,
@@ -273,31 +294,50 @@ fn gen_enzyme_decl(
273294 let mut act_ret = ThinVec :: new ( ) ;
274295 for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
275296 d_inputs. push ( arg. clone ( ) ) ;
297+ if !valid_input_activity ( x. mode , * activity) {
298+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplicationModeAct {
299+ span,
300+ mode : x. mode . to_string ( ) ,
301+ act : activity. to_string ( )
302+ } ) ;
303+ }
276304 match activity {
277305 DiffActivity :: Active => {
278- assert ! ( x . mode == DiffMode :: Reverse ) ;
306+ assert ! ( is_float ( & arg . ty ) ) ;
279307 act_ret. push ( arg. ty . clone ( ) ) ;
280308 }
281- DiffActivity :: Duplicated | DiffActivity :: Dual => {
309+ DiffActivity :: Duplicated => {
310+ assert ! ( is_ptr_or_ref( & arg. ty) ) ;
282311 let mut shadow_arg = arg. clone ( ) ;
283312 // We += into the shadow in reverse mode.
284- // Otherwise copy mutability of the original argument.
285- if activity == & DiffActivity :: Duplicated {
286- shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
287- }
288- // adjust name depending on mode
313+ shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
289314 let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
290315 ident. name
291316 } else {
292317 dbg ! ( & shadow_arg. pat) ;
293318 panic ! ( "not an ident?" ) ;
294319 } ;
295- let name: String = match x. mode {
296- DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
297- DiffMode :: Forward => format ! ( "b{}" , old_name) ,
298- _ => panic ! ( "unsupported mode: {}" , old_name) ,
320+ let name: String = format ! ( "d{}" , old_name) ;
321+ new_inputs. push ( name. clone ( ) ) ;
322+ let ident = Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ;
323+ shadow_arg. pat = P ( ast:: Pat {
324+ // TODO: Check id
325+ id : ast:: DUMMY_NODE_ID ,
326+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
327+ span : shadow_arg. pat . span ,
328+ tokens : shadow_arg. pat . tokens . clone ( ) ,
329+ } ) ;
330+ d_inputs. push ( shadow_arg) ;
331+ }
332+ DiffActivity :: Dual => {
333+ let mut shadow_arg = arg. clone ( ) ;
334+ let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
335+ ident. name
336+ } else {
337+ dbg ! ( & shadow_arg. pat) ;
338+ panic ! ( "not an ident?" ) ;
299339 } ;
300- dbg ! ( & name ) ;
340+ let name : String = format ! ( "b{}" , old_name ) ;
301341 new_inputs. push ( name. clone ( ) ) ;
302342 let ident = Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ;
303343 shadow_arg. pat = P ( ast:: Pat {
@@ -311,6 +351,7 @@ fn gen_enzyme_decl(
311351 }
312352 _ => {
313353 dbg ! ( & activity) ;
354+ panic ! ( "Not implemented" ) ;
314355 }
315356 }
316357 if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
0 commit comments