@@ -359,30 +359,27 @@ mod llvm_enzyme {
359359 ty
360360 }
361361
362- /// We only want this function to type-check, since we will replace the body
363- /// later on llvm level. Using `loop {}` does not cover all return types anymore,
364- /// so instead we build something that should pass. We also add a inline_asm
365- /// line, as one more barrier for rustc to prevent inlining of this function.
366- /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
367- /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
368- /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
369- /// this function (which should never happen, since it is only a placeholder).
370- /// Finally, we also add back_box usages of all input arguments, to prevent rustc
371- /// from optimizing any arguments away.
372- fn gen_enzyme_body (
362+ // Will generate a body of the type:
363+ // ```
364+ // {
365+ // unsafe {
366+ // asm!("NOP");
367+ // }
368+ // ::core::hint::black_box(primal(args));
369+ // ::core::hint::black_box((args, ret));
370+ // <This part remains to be done by following function>
371+ // }
372+ // ```
373+ fn init_body_helper (
373374 ecx : & ExtCtxt < ' _ > ,
374- x : & AutoDiffAttrs ,
375- n_active : u32 ,
376- sig : & ast:: FnSig ,
377- d_sig : & ast:: FnSig ,
375+ span : Span ,
378376 primal : Ident ,
379377 new_names : & [ String ] ,
380- span : Span ,
381378 sig_span : Span ,
382379 new_decl_span : Span ,
383- idents : Vec < Ident > ,
380+ idents : & [ Ident ] ,
384381 errored : bool ,
385- ) -> P < ast:: Block > {
382+ ) -> ( P < ast:: Block > , P < ast :: Expr > , P < ast :: Expr > , P < ast :: Expr > ) {
386383 let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
387384 let noop = ast:: InlineAsm {
388385 asm_macro : ast:: AsmMacro :: Asm ,
@@ -431,6 +428,54 @@ mod llvm_enzyme {
431428 }
432429 body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
433430
431+ ( body, primal_call, black_box_primal_call, blackbox_call_expr)
432+ }
433+
434+ /// We only want this function to type-check, since we will replace the body
435+ /// later on llvm level. Using `loop {}` does not cover all return types anymore,
436+ /// so instead we build something that should pass. We also add a inline_asm
437+ /// line, as one more barrier for rustc to prevent inlining of this function.
438+ /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
439+ /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
440+ /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
441+ /// this function (which should never happen, since it is only a placeholder).
442+ /// Finally, we also add back_box usages of all input arguments, to prevent rustc
443+ /// from optimizing any arguments away.
444+ fn gen_enzyme_body (
445+ ecx : & ExtCtxt < ' _ > ,
446+ x : & AutoDiffAttrs ,
447+ n_active : u32 ,
448+ sig : & ast:: FnSig ,
449+ d_sig : & ast:: FnSig ,
450+ primal : Ident ,
451+ new_names : & [ String ] ,
452+ span : Span ,
453+ sig_span : Span ,
454+ _new_decl_span : Span ,
455+ idents : Vec < Ident > ,
456+ errored : bool ,
457+ ) -> P < ast:: Block > {
458+ let new_decl_span = d_sig. span ;
459+
460+ // Just adding some default inline-asm and black_box usages to prevent early inlining
461+ // and optimizations which alter the function signature.
462+ //
463+ // The bb_primal_call is the black_box call of the primal function. We keep it around,
464+ // since it has the convenient property of returning the type of the primal function,
465+ // Remember, we only care to match types here.
466+ // No matter which return we pick, we always wrap it into a std::hint::black_box call,
467+ // to prevent rustc from propagating it into the caller.
468+ let ( mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper (
469+ ecx,
470+ span,
471+ primal,
472+ new_names,
473+ sig_span,
474+ new_decl_span,
475+ & idents,
476+ errored,
477+ ) ;
478+
434479 if !has_ret ( & d_sig. decl . output ) {
435480 // there is no return type that we have to match, () works fine.
436481 return body;
@@ -442,7 +487,7 @@ mod llvm_enzyme {
442487
443488 if primal_ret && n_active == 0 && x. mode . is_rev ( ) {
444489 // We only have the primal ret.
445- body. stmts . push ( ecx. stmt_expr ( black_box_primal_call ) ) ;
490+ body. stmts . push ( ecx. stmt_expr ( bb_primal_call ) ) ;
446491 return body;
447492 }
448493
@@ -534,11 +579,11 @@ mod llvm_enzyme {
534579 return body;
535580 }
536581 [ arg] => {
537- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
582+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
538583 }
539584 args => {
540585 let ret_tuple: P < ast:: Expr > = ecx. expr_tuple ( span, args. into ( ) ) ;
541- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ ret_tuple] ) ;
586+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ ret_tuple] ) ;
542587 }
543588 }
544589 assert ! ( has_ret( & d_sig. decl. output) ) ;
@@ -551,7 +596,7 @@ mod llvm_enzyme {
551596 ecx : & ExtCtxt < ' _ > ,
552597 span : Span ,
553598 primal : Ident ,
554- idents : Vec < Ident > ,
599+ idents : & [ Ident ] ,
555600 ) -> P < ast:: Expr > {
556601 let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
557602 if has_self {
0 commit comments