@@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
6666use rustc_hir:: Node ;
6767use rustc_middle:: dep_graph:: DepContext ;
6868use rustc_middle:: ty:: print:: with_no_trimmed_paths;
69+ use rustc_middle:: ty:: relate:: { self , RelateResult , TypeRelation } ;
6970use rustc_middle:: ty:: {
7071 self , error:: TypeError , Binder , List , Region , Subst , Ty , TyCtxt , TypeFoldable ,
7172 TypeSuperVisitable , TypeVisitable ,
@@ -2661,67 +2662,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
26612662 /// Float types, respectively). When comparing two ADTs, these rules apply recursively.
26622663 pub fn same_type_modulo_infer ( & self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> bool {
26632664 let ( a, b) = self . resolve_vars_if_possible ( ( a, b) ) ;
2664- match ( a. kind ( ) , b. kind ( ) ) {
2665- ( & ty:: Adt ( def_a, substs_a) , & ty:: Adt ( def_b, substs_b) ) => {
2666- if def_a != def_b {
2667- return false ;
2668- }
2665+ SameTypeModuloInfer ( self ) . relate ( a, b) . is_ok ( )
2666+ }
2667+ }
26692668
2670- substs_a
2671- . types ( )
2672- . zip ( substs_b. types ( ) )
2673- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2674- }
2675- ( & ty:: FnDef ( did_a, substs_a) , & ty:: FnDef ( did_b, substs_b) ) => {
2676- if did_a != did_b {
2677- return false ;
2678- }
2669+ struct SameTypeModuloInfer < ' a , ' tcx > ( & ' a InferCtxt < ' a , ' tcx > ) ;
26792670
2680- substs_a
2681- . types ( )
2682- . zip ( substs_b. types ( ) )
2683- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2684- }
2685- ( & ty:: Int ( _) | & ty:: Uint ( _) , & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
2671+ impl < ' tcx > TypeRelation < ' tcx > for SameTypeModuloInfer < ' _ , ' tcx > {
2672+ fn tcx ( & self ) -> TyCtxt < ' tcx > {
2673+ self . 0 . tcx
2674+ }
2675+
2676+ fn param_env ( & self ) -> ty:: ParamEnv < ' tcx > {
2677+ // Unused, only for consts which we treat as always equal
2678+ ty:: ParamEnv :: empty ( )
2679+ }
2680+
2681+ fn tag ( & self ) -> & ' static str {
2682+ "SameTypeModuloInfer"
2683+ }
2684+
2685+ fn a_is_expected ( & self ) -> bool {
2686+ true
2687+ }
2688+
2689+ fn relate_with_variance < T : relate:: Relate < ' tcx > > (
2690+ & mut self ,
2691+ _variance : ty:: Variance ,
2692+ _info : ty:: VarianceDiagInfo < ' tcx > ,
2693+ a : T ,
2694+ b : T ,
2695+ ) -> relate:: RelateResult < ' tcx , T > {
2696+ self . relate ( a, b)
2697+ }
2698+
2699+ fn tys ( & mut self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> RelateResult < ' tcx , Ty < ' tcx > > {
2700+ match ( a. kind ( ) , b. kind ( ) ) {
2701+ ( ty:: Int ( _) | ty:: Uint ( _) , ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
26862702 | (
2687- & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2688- & ty:: Int ( _) | & ty:: Uint ( _) | & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2703+ ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2704+ ty:: Int ( _) | ty:: Uint ( _) | ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
26892705 )
2690- | ( & ty:: Float ( _) , & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
2706+ | ( ty:: Float ( _) , ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
26912707 | (
2692- & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2693- & ty:: Float ( _) | & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2708+ ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2709+ ty:: Float ( _) | ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
26942710 )
2695- | ( & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2696- | ( _, & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => true ,
2697- ( & ty:: Ref ( _, ty_a, mut_a) , & ty:: Ref ( _, ty_b, mut_b) ) => {
2698- mut_a == mut_b && self . same_type_modulo_infer ( ty_a, ty_b)
2699- }
2700- ( & ty:: RawPtr ( a) , & ty:: RawPtr ( b) ) => {
2701- a. mutbl == b. mutbl && self . same_type_modulo_infer ( a. ty , b. ty )
2702- }
2703- ( & ty:: Slice ( a) , & ty:: Slice ( b) ) => self . same_type_modulo_infer ( a, b) ,
2704- ( & ty:: Array ( a_ty, a_ct) , & ty:: Array ( b_ty, b_ct) ) => {
2705- self . same_type_modulo_infer ( a_ty, b_ty) && a_ct == b_ct
2706- }
2707- ( & ty:: Tuple ( a) , & ty:: Tuple ( b) ) => {
2708- if a. len ( ) != b. len ( ) {
2709- return false ;
2710- }
2711- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2712- }
2713- ( & ty:: FnPtr ( a) , & ty:: FnPtr ( b) ) => {
2714- let a = a. skip_binder ( ) . inputs_and_output ;
2715- let b = b. skip_binder ( ) . inputs_and_output ;
2716- if a. len ( ) != b. len ( ) {
2717- return false ;
2718- }
2719- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2720- }
2721- // FIXME(compiler-errors): This needs to be generalized more
2722- _ => a == b,
2711+ | ( ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2712+ | ( _, ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => Ok ( a) ,
2713+ ( ty:: Infer ( _) , _) | ( _, ty:: Infer ( _) ) => Err ( TypeError :: Mismatch ) ,
2714+ _ => relate:: super_relate_tys ( self , a, b) ,
27232715 }
27242716 }
2717+
2718+ fn regions (
2719+ & mut self ,
2720+ a : ty:: Region < ' tcx > ,
2721+ b : ty:: Region < ' tcx > ,
2722+ ) -> RelateResult < ' tcx , ty:: Region < ' tcx > > {
2723+ if ( a. is_var ( ) && b. is_free_or_static ( ) ) || ( b. is_var ( ) && a. is_free_or_static ( ) ) || a == b
2724+ {
2725+ Ok ( a)
2726+ } else {
2727+ Err ( TypeError :: Mismatch )
2728+ }
2729+ }
2730+
2731+ fn binders < T > (
2732+ & mut self ,
2733+ a : ty:: Binder < ' tcx , T > ,
2734+ b : ty:: Binder < ' tcx , T > ,
2735+ ) -> relate:: RelateResult < ' tcx , ty:: Binder < ' tcx , T > >
2736+ where
2737+ T : relate:: Relate < ' tcx > ,
2738+ {
2739+ Ok ( ty:: Binder :: dummy ( self . relate ( a. skip_binder ( ) , b. skip_binder ( ) ) ?) )
2740+ }
2741+
2742+ fn consts (
2743+ & mut self ,
2744+ a : ty:: Const < ' tcx > ,
2745+ _b : ty:: Const < ' tcx > ,
2746+ ) -> relate:: RelateResult < ' tcx , ty:: Const < ' tcx > > {
2747+ // FIXME(compiler-errors): This could at least do some first-order
2748+ // relation
2749+ Ok ( a)
2750+ }
27252751}
27262752
27272753impl < ' a , ' tcx > InferCtxt < ' a , ' tcx > {
0 commit comments