|
12 | 12 | #![allow(rustc::usage_of_ty_tykind)] |
13 | 13 | #![allow(unused_imports)] |
14 | 14 |
|
15 | | -use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree}; |
16 | 15 | use rustc_target::abi::FieldsShape; |
17 | 16 |
|
18 | 17 | pub use self::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable}; |
@@ -75,6 +74,7 @@ pub use rustc_type_ir::ConstKind::{ |
75 | 74 | }; |
76 | 75 | pub use rustc_type_ir::*; |
77 | 76 |
|
| 77 | +pub use self::typetree::*; |
78 | 78 | pub use self::binding::BindingMode; |
79 | 79 | pub use self::binding::BindingMode::*; |
80 | 80 | pub use self::closure::{ |
@@ -127,6 +127,7 @@ pub mod util; |
127 | 127 | pub mod visit; |
128 | 128 | pub mod vtable; |
129 | 129 | pub mod walk; |
| 130 | +pub mod typetree; |
130 | 131 |
|
131 | 132 | mod adt; |
132 | 133 | mod assoc; |
@@ -2721,306 +2722,3 @@ mod size_asserts { |
2721 | 2722 | // tidy-alphabetical-end |
2722 | 2723 | } |
2723 | 2724 |
|
2724 | | -pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { |
2725 | | - let mut visited = vec![]; |
2726 | | - let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None); |
2727 | | - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; |
2728 | | - return TypeTree(vec![tt]); |
2729 | | -} |
2730 | | - |
2731 | | -use rustc_ast::expand::autodiff_attrs::DiffActivity; |
2732 | | - |
2733 | | -// This function combines three tasks. To avoid traversing each type 3x, we combine them. |
2734 | | -// 1. Create a TypeTree from a Ty. This is the main task. |
2735 | | -// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM |
2736 | | -// lowering. E.g. fat ptr are going to introduce an extra int. |
2737 | | -// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an |
2738 | | -// autodiff macro on top). Here we want to make sure that shadows are mutable internally. |
2739 | | -// We know the outermost ref/ptr indirection is mutability - we generate it like that. |
2740 | | -// We now have to make sure that inner ptr/ref are mutable too, or issue a warning. |
2741 | | -// Not an error, becaues it only causes issues if they are actually read, which we don't check |
2742 | | -// yet. We should add such analysis to relibably either issue an error or accept without warning. |
2743 | | -// If there only were some reasearch to do that... |
2744 | | -pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>, span: Option<Span>) -> FncTree { |
2745 | | - if !fn_ty.is_fn() { |
2746 | | - return FncTree { args: vec![], ret: TypeTree::new() }; |
2747 | | - } |
2748 | | - let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); |
2749 | | - |
2750 | | - // If rustc compiles the unmodified primal, we know that this copy of the function |
2751 | | - // also has correct lifetimes. We know that Enzyme won't free the shadow too early |
2752 | | - // (or actually at all), so let's strip lifetimes when computing the layout. |
2753 | | - // Recommended by compiler-errors: |
2754 | | - // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751 |
2755 | | - let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); |
2756 | | - |
2757 | | - let mut new_activities = vec![]; |
2758 | | - let mut new_positions = vec![]; |
2759 | | - let mut visited = vec![]; |
2760 | | - let mut args = vec![]; |
2761 | | - for (i, ty) in x.inputs().iter().enumerate() { |
2762 | | - // We care about safety checks, if an argument get's duplicated and we write into the |
2763 | | - // shadow. That's equivalent to Duplicated or DuplicatedOnly. |
2764 | | - let safety = if !da.is_empty() { |
2765 | | - assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len()); |
2766 | | - // If we have Activities, we also have spans |
2767 | | - assert!(span.is_some()); |
2768 | | - match da[i] { |
2769 | | - DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true, |
2770 | | - _ => false, |
2771 | | - } |
2772 | | - } else { |
2773 | | - false |
2774 | | - }; |
2775 | | - |
2776 | | - visited.clear(); |
2777 | | - if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { |
2778 | | - if ty.is_fn_ptr() { |
2779 | | - unimplemented!("what to do whith fn ptr?"); |
2780 | | - } |
2781 | | - let inner_ty = ty.builtin_deref(true).unwrap().ty; |
2782 | | - if inner_ty.is_slice() { |
2783 | | - // We know that the lenght will be passed as extra arg. |
2784 | | - let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span); |
2785 | | - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; |
2786 | | - args.push(TypeTree(vec![tt])); |
2787 | | - let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() }; |
2788 | | - args.push(TypeTree(vec![i64_tt])); |
2789 | | - if !da.is_empty() { |
2790 | | - // We are looking at a slice. The length of that slice will become an |
2791 | | - // extra integer on llvm level. Integers are always const. |
2792 | | - // However, if the slice get's duplicated, we want to know to later check the |
2793 | | - // size. So we mark the new size argument as FakeActivitySize. |
2794 | | - let activity = match da[i] { |
2795 | | - DiffActivity::DualOnly | DiffActivity::Dual | |
2796 | | - DiffActivity::DuplicatedOnly | DiffActivity::Duplicated |
2797 | | - => DiffActivity::FakeActivitySize, |
2798 | | - DiffActivity::Const => DiffActivity::Const, |
2799 | | - _ => panic!("unexpected activity for ptr/ref"), |
2800 | | - }; |
2801 | | - new_activities.push(activity); |
2802 | | - new_positions.push(i + 1); |
2803 | | - } |
2804 | | - trace!("ABI MATCHING!"); |
2805 | | - continue; |
2806 | | - } |
2807 | | - } |
2808 | | - let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span); |
2809 | | - args.push(arg_tt); |
2810 | | - } |
2811 | | - |
2812 | | - // now add the extra activities coming from slices |
2813 | | - // Reverse order to not invalidate the indices |
2814 | | - for _ in 0..new_activities.len() { |
2815 | | - let pos = new_positions.pop().unwrap(); |
2816 | | - let activity = new_activities.pop().unwrap(); |
2817 | | - da.insert(pos, activity); |
2818 | | - } |
2819 | | - |
2820 | | - visited.clear(); |
2821 | | - let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span); |
2822 | | - |
2823 | | - FncTree { args, ret } |
2824 | | -} |
2825 | | - |
2826 | | -fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec<Ty<'a>>, span: Option<Span>) -> TypeTree { |
2827 | | - if depth > 20 { |
2828 | | - trace!("depth > 20 for ty: {}", &ty); |
2829 | | - } |
2830 | | - if visited.contains(&ty) { |
2831 | | - // recursive type |
2832 | | - trace!("recursive type: {}", &ty); |
2833 | | - return TypeTree::new(); |
2834 | | - } |
2835 | | - visited.push(ty); |
2836 | | - |
2837 | | - if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { |
2838 | | - if ty.is_fn_ptr() { |
2839 | | - unimplemented!("what to do whith fn ptr?"); |
2840 | | - } |
2841 | | - |
2842 | | - let inner_ty_and_mut = ty.builtin_deref(true).unwrap(); |
2843 | | - let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut; |
2844 | | - let inner_ty = inner_ty_and_mut.ty; |
2845 | | - |
2846 | | - // Now account for inner mutability. |
2847 | | - if !is_mut && depth > 0 && safety { |
2848 | | - let ptr_ty: String = if ty.is_ref() { |
2849 | | - "ref" |
2850 | | - } else if ty.is_unsafe_ptr() { |
2851 | | - "ptr" |
2852 | | - } else { |
2853 | | - assert!(ty.is_box()); |
2854 | | - "box" |
2855 | | - }.to_string(); |
2856 | | - |
2857 | | - // If we have mutability, we also have a span |
2858 | | - assert!(span.is_some()); |
2859 | | - let span = span.unwrap(); |
2860 | | - |
2861 | | - tcx.sess |
2862 | | - .dcx() |
2863 | | - .emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty}); |
2864 | | - } |
2865 | | - |
2866 | | - //visited.push(inner_ty); |
2867 | | - let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); |
2868 | | - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; |
2869 | | - visited.pop(); |
2870 | | - return TypeTree(vec![tt]); |
2871 | | - } |
2872 | | - |
2873 | | - |
2874 | | - if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() { |
2875 | | - visited.pop(); |
2876 | | - return TypeTree::new(); |
2877 | | - } |
2878 | | - |
2879 | | - if ty.is_scalar() { |
2880 | | - let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { |
2881 | | - (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) |
2882 | | - } else if ty.is_floating_point() { |
2883 | | - match ty { |
2884 | | - x if x == tcx.types.f32 => (Kind::Float, 4), |
2885 | | - x if x == tcx.types.f64 => (Kind::Double, 8), |
2886 | | - _ => panic!("floatTy scalar that is neither f32 nor f64"), |
2887 | | - } |
2888 | | - } else { |
2889 | | - panic!("scalar that is neither integral nor floating point"); |
2890 | | - }; |
2891 | | - visited.pop(); |
2892 | | - return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]); |
2893 | | - } |
2894 | | - |
2895 | | - let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; |
2896 | | - |
2897 | | - let layout = tcx.layout_of(param_env_and); |
2898 | | - assert!(layout.is_ok()); |
2899 | | - |
2900 | | - let layout = layout.unwrap().layout; |
2901 | | - let fields = layout.fields(); |
2902 | | - let max_size = layout.size(); |
2903 | | - |
2904 | | - |
2905 | | - |
2906 | | - if ty.is_adt() && !ty.is_simd() { |
2907 | | - let adt_def = ty.ty_adt_def().unwrap(); |
2908 | | - |
2909 | | - if adt_def.is_struct() { |
2910 | | - let (offsets, _memory_index) = match fields { |
2911 | | - // Manuel TODO: |
2912 | | - FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), |
2913 | | - FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later |
2914 | | - FieldsShape::Union(_) => {return TypeTree::new();}, |
2915 | | - FieldsShape::Primitive => {return TypeTree::new();}, |
2916 | | - }; |
2917 | | - |
2918 | | - let substs = match ty.kind() { |
2919 | | - Adt(_, subst_ref) => subst_ref, |
2920 | | - _ => panic!(""), |
2921 | | - }; |
2922 | | - |
2923 | | - let fields = adt_def.all_fields(); |
2924 | | - let fields = fields |
2925 | | - .into_iter() |
2926 | | - .zip(offsets.into_iter()) |
2927 | | - .filter_map(|(field, offset)| { |
2928 | | - let field_ty: Ty<'_> = field.ty(tcx, substs); |
2929 | | - let field_ty: Ty<'_> = |
2930 | | - tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); |
2931 | | - |
2932 | | - if field_ty.is_phantom_data() { |
2933 | | - return None; |
2934 | | - } |
2935 | | - |
2936 | | - //visited.push(field_ty); |
2937 | | - let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0; |
2938 | | - |
2939 | | - for c in &mut child { |
2940 | | - if c.offset == -1 { |
2941 | | - c.offset = offset.bytes() as isize |
2942 | | - } else { |
2943 | | - c.offset += offset.bytes() as isize; |
2944 | | - } |
2945 | | - } |
2946 | | - |
2947 | | - Some(child) |
2948 | | - }) |
2949 | | - .flatten() |
2950 | | - .collect::<Vec<Type>>(); |
2951 | | - |
2952 | | - visited.pop(); |
2953 | | - let ret_tt = TypeTree(fields); |
2954 | | - return ret_tt; |
2955 | | - } else if adt_def.is_enum() { |
2956 | | - // Enzyme can't represent enums, so let it figure it out itself, without seeeding |
2957 | | - // typetree |
2958 | | - //unimplemented!("adt that is an enum"); |
2959 | | - } else { |
2960 | | - //let ty_name = tcx.def_path_debug_str(adt_def.did()); |
2961 | | - //tcx.sess.emit_fatal(UnsupportedUnion { ty_name }); |
2962 | | - } |
2963 | | - } |
2964 | | - |
2965 | | - if ty.is_simd() { |
2966 | | - trace!("simd"); |
2967 | | - let (_size, inner_ty) = ty.simd_size_and_type(tcx); |
2968 | | - //visited.push(inner_ty); |
2969 | | - let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); |
2970 | | - //let tt = TypeTree( |
2971 | | - // std::iter::repeat(subtt) |
2972 | | - // .take(*count as usize) |
2973 | | - // .enumerate() |
2974 | | - // .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) |
2975 | | - // .flatten() |
2976 | | - // .collect(), |
2977 | | - //); |
2978 | | - // TODO |
2979 | | - visited.pop(); |
2980 | | - return TypeTree::new(); |
2981 | | - } |
2982 | | - |
2983 | | - if ty.is_array() { |
2984 | | - let (stride, count) = match fields { |
2985 | | - FieldsShape::Array { stride: s, count: c } => (s, c), |
2986 | | - _ => panic!(""), |
2987 | | - }; |
2988 | | - let byte_stride = stride.bytes_usize(); |
2989 | | - let byte_max_size = max_size.bytes_usize(); |
2990 | | - |
2991 | | - assert!(byte_stride * *count as usize == byte_max_size); |
2992 | | - if (*count as usize) == 0 { |
2993 | | - return TypeTree::new(); |
2994 | | - } |
2995 | | - let sub_ty = ty.builtin_index().unwrap(); |
2996 | | - //visited.push(sub_ty); |
2997 | | - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); |
2998 | | - |
2999 | | - // calculate size of subtree |
3000 | | - let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; |
3001 | | - let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; |
3002 | | - let tt = TypeTree( |
3003 | | - std::iter::repeat(subtt) |
3004 | | - .take(*count as usize) |
3005 | | - .enumerate() |
3006 | | - .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) |
3007 | | - .flatten() |
3008 | | - .collect(), |
3009 | | - ); |
3010 | | - |
3011 | | - visited.pop(); |
3012 | | - return tt; |
3013 | | - } |
3014 | | - |
3015 | | - if ty.is_slice() { |
3016 | | - let sub_ty = ty.builtin_index().unwrap(); |
3017 | | - //visited.push(sub_ty); |
3018 | | - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); |
3019 | | - |
3020 | | - visited.pop(); |
3021 | | - return subtt; |
3022 | | - } |
3023 | | - |
3024 | | - visited.pop(); |
3025 | | - TypeTree::new() |
3026 | | -} |
0 commit comments