@@ -4,13 +4,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
66use rustc_codegen_ssa:: common:: TypeKind ;
7- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7+ use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
88use rustc_errors:: FatalError ;
99use rustc_middle:: bug;
1010use tracing:: { debug, trace} ;
1111
1212use crate :: back:: write:: llvm_err;
13- use crate :: builder:: { SBuilder , UNNAMED } ;
13+ use crate :: builder:: { Builder , OperandRef , PlaceRef , UNNAMED } ;
1414use crate :: context:: SimpleCx ;
1515use crate :: declare:: declare_simple_fn;
1616use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
@@ -19,7 +19,7 @@ use crate::llvm::{Metadata, True};
1919use crate :: value:: Value ;
2020use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
2121
22- fn get_params ( fnc : & Value ) -> Vec < & Value > {
22+ fn _get_params ( fnc : & Value ) -> Vec < & Value > {
2323 let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
2424 let mut fnc_args: Vec < & Value > = vec ! [ ] ;
2525 fnc_args. reserve ( param_num) ;
@@ -49,9 +49,9 @@ fn has_sret(fnc: &Value) -> bool {
4949// need to match those.
5050// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
5151// using iterators and peek()?
52- fn match_args_from_caller_to_enzyme < ' ll > (
52+ fn match_args_from_caller_to_enzyme < ' ll , ' tcx > (
5353 cx : & SimpleCx < ' ll > ,
54- builder : & SBuilder < ' ll , ' ll > ,
54+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
5555 width : u32 ,
5656 args : & mut Vec < & ' ll llvm:: Value > ,
5757 inputs : & [ DiffActivity ] ,
@@ -289,11 +289,14 @@ fn compute_enzyme_fn_ty<'ll>(
289289/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
290290// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
291291// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
292- fn generate_enzyme_call < ' ll > (
292+ pub ( crate ) fn generate_enzyme_call < ' ll , ' tcx > (
293+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
293294 cx : & SimpleCx < ' ll > ,
294295 fn_to_diff : & ' ll Value ,
295296 outer_fn : & ' ll Value ,
297+ fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
296298 attrs : AutoDiffAttrs ,
299+ dest : PlaceRef < ' tcx , & ' ll Value > ,
297300) {
298301 // We have to pick the name depending on whether we want forward or reverse mode autodiff.
299302 let mut ad_name: String = match attrs. mode {
@@ -366,14 +369,6 @@ fn generate_enzyme_call<'ll>(
366369 let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
367370 attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
368371
369- // first, remove all calls from fnc
370- let entry = llvm:: LLVMGetFirstBasicBlock ( outer_fn) ;
371- let br = llvm:: LLVMRustGetTerminator ( entry) ;
372- llvm:: LLVMRustEraseInstFromParent ( br) ;
373-
374- let last_inst = llvm:: LLVMRustGetLastInstruction ( entry) . unwrap ( ) ;
375- let mut builder = SBuilder :: build ( cx, entry) ;
376-
377372 let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
378373 let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
379374 args. push ( fn_to_diff) ;
@@ -389,40 +384,20 @@ fn generate_enzyme_call<'ll>(
389384 }
390385
391386 let has_sret = has_sret ( outer_fn) ;
392- let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn ) ;
387+ let outer_args: Vec < & llvm:: Value > = fn_args . iter ( ) . map ( |op| op . immediate ( ) ) . collect ( ) ;
393388 match_args_from_caller_to_enzyme (
394389 & cx,
395- & builder,
390+ builder,
396391 attrs. width ,
397392 & mut args,
398393 & attrs. input_activity ,
399394 & outer_args,
400395 has_sret,
401396 ) ;
402397
403- let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
404-
405- // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
406- // metadata attached to it, but we just created this code oota. Given that the
407- // differentiated function already has partly confusing metadata, and given that this
408- // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
409- // dummy code which we inserted at a higher level.
410- // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
411- // and how to best improve it for enzyme core and rust-enzyme.
412- let md_ty = cx. get_md_kind_id ( "dbg" ) ;
413- if llvm:: LLVMRustHasMetadata ( last_inst, md_ty) {
414- let md = llvm:: LLVMRustDIGetInstMetadata ( last_inst)
415- . expect ( "failed to get instruction metadata" ) ;
416- let md_todiff = cx. get_metadata_value ( md) ;
417- llvm:: LLVMSetMetadata ( call, md_ty, md_todiff) ;
418- } else {
419- // We don't panic, since depending on whether we are in debug or release mode, we might
420- // have no debug info to copy, which would then be ok.
421- trace ! ( "no dbg info" ) ;
422- }
398+ let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
423399
424- // Now that we copied the metadata, get rid of dummy code.
425- llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
400+ builder. store_to_place ( call, dest. val ) ;
426401
427402 if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
428403 if has_sret {
@@ -445,10 +420,10 @@ fn generate_enzyme_call<'ll>(
445420 llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
446421 }
447422 builder. ret_void ( ) ;
448- } else {
449- builder. ret ( call) ;
450423 }
451424
425+ builder. store_to_place ( call, dest. val ) ;
426+
452427 // Let's crash in case that we messed something up above and generated invalid IR.
453428 llvm:: LLVMRustVerifyFunction (
454429 outer_fn,
@@ -463,6 +438,7 @@ pub(crate) fn differentiate<'ll>(
463438 diff_items : Vec < AutoDiffItem > ,
464439 _config : & ModuleConfig ,
465440) -> Result < ( ) , FatalError > {
441+ // TODO(Sa4dUs): delete all this logic
466442 for item in & diff_items {
467443 trace ! ( "{}" , item) ;
468444 }
@@ -482,7 +458,7 @@ pub(crate) fn differentiate<'ll>(
482458 for item in diff_items. iter ( ) {
483459 let name = item. source . clone ( ) ;
484460 let fn_def: Option < & llvm:: Value > = cx. get_function ( & name) ;
485- let Some ( fn_def ) = fn_def else {
461+ let Some ( _fn_def ) = fn_def else {
486462 return Err ( llvm_err (
487463 diag_handler. handle ( ) ,
488464 LlvmError :: PrepareAutoDiff {
@@ -494,7 +470,7 @@ pub(crate) fn differentiate<'ll>(
494470 } ;
495471 debug ! ( ?item. target) ;
496472 let fn_target: Option < & llvm:: Value > = cx. get_function ( & item. target ) ;
497- let Some ( fn_target ) = fn_target else {
473+ let Some ( _fn_target ) = fn_target else {
498474 return Err ( llvm_err (
499475 diag_handler. handle ( ) ,
500476 LlvmError :: PrepareAutoDiff {
@@ -505,7 +481,7 @@ pub(crate) fn differentiate<'ll>(
505481 ) ) ;
506482 } ;
507483
508- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
484+ // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
509485 }
510486
511487 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments