@@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
2525pub use generics:: * ;
2626pub use intrinsic:: IntrinsicDef ;
2727use rustc_abi:: { Align , FieldIdx , Integer , IntegerType , ReprFlags , ReprOptions , VariantIdx } ;
28+ use rustc_ast:: expand:: typetree:: { FncTree , Kind , Type , TypeTree } ;
2829use rustc_ast:: node_id:: NodeMap ;
2930pub use rustc_ast_ir:: { Movability , Mutability , try_visit} ;
3031use rustc_data_structures:: fx:: { FxHashMap , FxHashSet , FxIndexMap , FxIndexSet } ;
@@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> {
22162217 pub variant : Option < VariantIdx > ,
22172218 pub fields : & ' tcx [ ty:: Const < ' tcx > ] ,
22182219}
2220+
2221+ /// Generate TypeTree information for autodiff.
2222+ /// This function creates TypeTree metadata that describes the memory layout
2223+ /// of function parameters and return types for Enzyme autodiff.
2224+ pub fn fnc_typetrees < ' tcx > ( tcx : TyCtxt < ' tcx > , fn_ty : Ty < ' tcx > ) -> FncTree {
2225+ // Check if TypeTrees are disabled via NoTT flag
2226+ if tcx. sess . opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: NoTT ) {
2227+ return FncTree { args : vec ! [ ] , ret : TypeTree :: new ( ) } ;
2228+ }
2229+
2230+ // Check if this is actually a function type
2231+ if !fn_ty. is_fn ( ) {
2232+ return FncTree { args : vec ! [ ] , ret : TypeTree :: new ( ) } ;
2233+ }
2234+
2235+ // Get the function signature
2236+ let fn_sig = fn_ty. fn_sig ( tcx) ;
2237+ let sig = tcx. instantiate_bound_regions_with_erased ( fn_sig) ;
2238+
2239+ // Create TypeTrees for each input parameter
2240+ let mut args = vec ! [ ] ;
2241+ for ty in sig. inputs ( ) . iter ( ) {
2242+ let type_tree = typetree_from_ty ( tcx, * ty) ;
2243+ args. push ( type_tree) ;
2244+ }
2245+
2246+ // Create TypeTree for return type
2247+ let ret = typetree_from_ty ( tcx, sig. output ( ) ) ;
2248+
2249+ FncTree { args, ret }
2250+ }
2251+
2252+ /// Generate TypeTree for a specific type.
2253+ /// This function analyzes a Rust type and creates appropriate TypeTree metadata.
2254+ fn typetree_from_ty < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> TypeTree {
2255+ // Handle basic scalar types
2256+ if ty. is_scalar ( ) {
2257+ let ( kind, size) = if ty. is_integral ( ) || ty. is_char ( ) || ty. is_bool ( ) {
2258+ ( Kind :: Integer , ty. primitive_size ( tcx) . bytes_usize ( ) )
2259+ } else if ty. is_floating_point ( ) {
2260+ match ty {
2261+ x if x == tcx. types . f32 => ( Kind :: Float , 4 ) ,
2262+ x if x == tcx. types . f64 => ( Kind :: Double , 8 ) ,
2263+ _ => return TypeTree :: new ( ) , // Unknown float type
2264+ }
2265+ } else {
2266+ // TODO(KMJ-007): Handle other scalar types if needed
2267+ return TypeTree :: new ( ) ;
2268+ } ;
2269+
2270+ return TypeTree ( vec ! [ Type {
2271+ offset: -1 ,
2272+ size,
2273+ kind,
2274+ child: TypeTree :: new( )
2275+ } ] ) ;
2276+ }
2277+
2278+ // Handle references and pointers
2279+ if ty. is_ref ( ) || ty. is_raw_ptr ( ) || ty. is_box ( ) {
2280+ let inner_ty = if let Some ( inner) = ty. builtin_deref ( true ) {
2281+ inner
2282+ } else {
2283+ // TODO(KMJ-007): Handle complex pointer types
2284+ return TypeTree :: new ( ) ;
2285+ } ;
2286+
2287+ let child = typetree_from_ty ( tcx, inner_ty) ;
2288+ return TypeTree ( vec ! [ Type {
2289+ offset: -1 ,
2290+ size: 8 , // TODO(KMJ-007): Get actual pointer size from target
2291+ kind: Kind :: Pointer ,
2292+ child,
2293+ } ] ) ;
2294+ }
2295+
2296+ // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
2297+ TypeTree :: new ( )
2298+ }
0 commit comments