@@ -11,6 +11,7 @@ mod llvm_enzyme {
1111 AutoDiffAttrs , DiffActivity , DiffMode , valid_input_activity, valid_ret_activity,
1212 valid_ty_for_activity,
1313 } ;
14+ use rustc_ast:: expand:: typetree:: { TypeTree , Type , Kind } ;
1415 use rustc_ast:: ptr:: P ;
1516 use rustc_ast:: token:: { Lit , LitKind , Token , TokenKind } ;
1617 use rustc_ast:: tokenstream:: * ;
@@ -25,7 +26,6 @@ mod llvm_enzyme {
2526 use tracing:: { debug, trace} ;
2627
2728 use crate :: errors;
28- use crate :: expand:: typetree:: TypeTree ;
2929
3030 pub ( crate ) fn outer_normal_attr (
3131 kind : & P < rustc_ast:: NormalAttr > ,
@@ -325,10 +325,9 @@ mod llvm_enzyme {
325325 }
326326 let span = ecx. with_def_site_ctxt ( expand_span) ;
327327
328- // Prepare placeholder type trees for now
329- let num_args = sig. decl . inputs . len ( ) ;
330- let inputs = vec ! [ TypeTree :: new( ) ; num_args] ;
331- let output = TypeTree :: new ( ) ;
328+ // Construct real type trees from function signature
329+ let ( inputs, output) = construct_typetree_from_fnsig ( & sig) ;
330+
332331 // Use the new into_item method to construct the AutoDiffItem
333332 let autodiff_item = x. clone ( ) . into_item (
334333 primal. to_string ( ) ,
@@ -1059,4 +1058,105 @@ mod llvm_enzyme {
10591058 }
10601059}
10611060
1061+ #[ cfg( llvm_enzyme) ]
1062+ fn construct_typetree_from_ty ( ty : & ast:: Ty ) -> TypeTree {
1063+ match & ty. kind {
1064+ TyKind :: Path ( ..) => {
1065+ // Handle basic types like f32, f64, i32, etc.
1066+ // For now, we'll use a simple heuristic based on the path
1067+ // In a full implementation, this would need to be more sophisticated
1068+ TypeTree ( vec ! [ Type {
1069+ offset: 0 ,
1070+ size: 8 , // Default size, should be computed properly
1071+ kind: Kind :: Float , // Default to float, should be determined from type
1072+ child: TypeTree :: new( ) ,
1073+ } ] )
1074+ }
1075+ TyKind :: Ptr ( ptr_ty) => {
1076+ TypeTree ( vec ! [ Type {
1077+ offset: 0 ,
1078+ size: 8 , // Pointer size
1079+ kind: Kind :: Pointer ,
1080+ child: construct_typetree_from_ty( & ptr_ty. ty) ,
1081+ } ] )
1082+ }
1083+ TyKind :: Ref ( _, ref_ty) => {
1084+ TypeTree ( vec ! [ Type {
1085+ offset: 0 ,
1086+ size: 8 , // Reference size
1087+ kind: Kind :: Pointer ,
1088+ child: construct_typetree_from_ty( & ref_ty. ty) ,
1089+ } ] )
1090+ }
1091+ TyKind :: Slice ( slice_ty) => {
1092+ TypeTree ( vec ! [ Type {
1093+ offset: 0 ,
1094+ size: 16 , // Slice is (ptr, len)
1095+ kind: Kind :: Pointer ,
1096+ child: construct_typetree_from_ty( & slice_ty. ty) ,
1097+ } ] )
1098+ }
1099+ TyKind :: Array ( array_ty) => {
1100+ // For arrays, we need to handle the size properly
1101+ let elem_ty = construct_typetree_from_ty ( & array_ty. ty ) ;
1102+ TypeTree ( vec ! [ Type {
1103+ offset: 0 ,
1104+ size: 8 , // Array size depends on element type and count
1105+ kind: Kind :: Pointer ,
1106+ child: elem_ty,
1107+ } ] )
1108+ }
1109+ TyKind :: Tup ( tuple_types) => {
1110+ let mut types = Vec :: new ( ) ;
1111+ let mut offset = 0 ;
1112+ for ( i, tuple_ty) in tuple_types. iter ( ) . enumerate ( ) {
1113+ let elem_ty = construct_typetree_from_ty ( tuple_ty) ;
1114+ // For tuples, we need to handle alignment and padding
1115+ // This is a simplified version
1116+ types. push ( Type {
1117+ offset : offset as isize ,
1118+ size : 8 , // Should be computed based on actual type
1119+ kind : Kind :: Float , // Default
1120+ child : elem_ty,
1121+ } ) ;
1122+ offset += 8 ; // Simplified alignment
1123+ }
1124+ TypeTree ( types)
1125+ }
1126+ _ => {
1127+ // Default case for unknown types
1128+ TypeTree ( vec ! [ Type {
1129+ offset: 0 ,
1130+ size: 8 ,
1131+ kind: Kind :: Float ,
1132+ child: TypeTree :: new( ) ,
1133+ } ] )
1134+ }
1135+ }
1136+ }
1137+
1138+ #[ cfg( llvm_enzyme) ]
1139+ fn construct_typetree_from_fnsig ( sig : & ast:: FnSig ) -> ( Vec < TypeTree > , TypeTree ) {
1140+ // Construct type trees for input arguments
1141+ let inputs: Vec < TypeTree > = sig. decl . inputs . iter ( )
1142+ . map ( |param| construct_typetree_from_ty ( & param. ty ) )
1143+ . collect ( ) ;
1144+
1145+ // Construct type tree for return type
1146+ let output = match & sig. decl . output {
1147+ FnRetTy :: Default ( span) => {
1148+ // Unit type ()
1149+ TypeTree ( vec ! [ Type {
1150+ offset: 0 ,
1151+ size: 0 ,
1152+ kind: Kind :: Integer ,
1153+ child: TypeTree :: new( ) ,
1154+ } ] )
1155+ }
1156+ FnRetTy :: Ty ( ty) => construct_typetree_from_ty ( ty) ,
1157+ } ;
1158+
1159+ ( inputs, output)
1160+ }
1161+
10621162pub ( crate ) use llvm_enzyme:: { expand_forward, expand_reverse} ;
0 commit comments