@@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
6666use rustc_hir:: Node ;
6767use rustc_middle:: dep_graph:: DepContext ;
6868use rustc_middle:: ty:: print:: with_no_trimmed_paths;
69+ use rustc_middle:: ty:: relate:: { self , RelateResult , TypeRelation } ;
6970use rustc_middle:: ty:: {
7071 self , error:: TypeError , Binder , List , Region , Subst , Ty , TyCtxt , TypeFoldable ,
7172 TypeSuperVisitable , TypeVisitable ,
@@ -2660,67 +2661,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
26602661 /// Float types, respectively). When comparing two ADTs, these rules apply recursively.
26612662 pub fn same_type_modulo_infer ( & self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> bool {
26622663 let ( a, b) = self . resolve_vars_if_possible ( ( a, b) ) ;
2663- match ( a. kind ( ) , b. kind ( ) ) {
2664- ( & ty:: Adt ( def_a, substs_a) , & ty:: Adt ( def_b, substs_b) ) => {
2665- if def_a != def_b {
2666- return false ;
2667- }
2664+ SameTypeModuloInfer ( self ) . relate ( a, b) . is_ok ( )
2665+ }
2666+ }
26682667
2669- substs_a
2670- . types ( )
2671- . zip ( substs_b. types ( ) )
2672- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2673- }
2674- ( & ty:: FnDef ( did_a, substs_a) , & ty:: FnDef ( did_b, substs_b) ) => {
2675- if did_a != did_b {
2676- return false ;
2677- }
2668+ struct SameTypeModuloInfer < ' a , ' tcx > ( & ' a InferCtxt < ' a , ' tcx > ) ;
26782669
2679- substs_a
2680- . types ( )
2681- . zip ( substs_b. types ( ) )
2682- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2683- }
2684- ( & ty:: Int ( _) | & ty:: Uint ( _) , & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
2670+ impl < ' tcx > TypeRelation < ' tcx > for SameTypeModuloInfer < ' _ , ' tcx > {
2671+ fn tcx ( & self ) -> TyCtxt < ' tcx > {
2672+ self . 0 . tcx
2673+ }
2674+
2675+ fn param_env ( & self ) -> ty:: ParamEnv < ' tcx > {
2676+ // Unused, only for consts which we treat as always equal
2677+ ty:: ParamEnv :: empty ( )
2678+ }
2679+
2680+ fn tag ( & self ) -> & ' static str {
2681+ "SameTypeModuloInfer"
2682+ }
2683+
2684+ fn a_is_expected ( & self ) -> bool {
2685+ true
2686+ }
2687+
2688+ fn relate_with_variance < T : relate:: Relate < ' tcx > > (
2689+ & mut self ,
2690+ _variance : ty:: Variance ,
2691+ _info : ty:: VarianceDiagInfo < ' tcx > ,
2692+ a : T ,
2693+ b : T ,
2694+ ) -> relate:: RelateResult < ' tcx , T > {
2695+ self . relate ( a, b)
2696+ }
2697+
2698+ fn tys ( & mut self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> RelateResult < ' tcx , Ty < ' tcx > > {
2699+ match ( a. kind ( ) , b. kind ( ) ) {
2700+ ( ty:: Int ( _) | ty:: Uint ( _) , ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
26852701 | (
2686- & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2687- & ty:: Int ( _) | & ty:: Uint ( _) | & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2702+ ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2703+ ty:: Int ( _) | ty:: Uint ( _) | ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
26882704 )
2689- | ( & ty:: Float ( _) , & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
2705+ | ( ty:: Float ( _) , ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
26902706 | (
2691- & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2692- & ty:: Float ( _) | & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2707+ ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2708+ ty:: Float ( _) | ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
26932709 )
2694- | ( & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2695- | ( _, & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => true ,
2696- ( & ty:: Ref ( _, ty_a, mut_a) , & ty:: Ref ( _, ty_b, mut_b) ) => {
2697- mut_a == mut_b && self . same_type_modulo_infer ( ty_a, ty_b)
2698- }
2699- ( & ty:: RawPtr ( a) , & ty:: RawPtr ( b) ) => {
2700- a. mutbl == b. mutbl && self . same_type_modulo_infer ( a. ty , b. ty )
2701- }
2702- ( & ty:: Slice ( a) , & ty:: Slice ( b) ) => self . same_type_modulo_infer ( a, b) ,
2703- ( & ty:: Array ( a_ty, a_ct) , & ty:: Array ( b_ty, b_ct) ) => {
2704- self . same_type_modulo_infer ( a_ty, b_ty) && a_ct == b_ct
2705- }
2706- ( & ty:: Tuple ( a) , & ty:: Tuple ( b) ) => {
2707- if a. len ( ) != b. len ( ) {
2708- return false ;
2709- }
2710- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2711- }
2712- ( & ty:: FnPtr ( a) , & ty:: FnPtr ( b) ) => {
2713- let a = a. skip_binder ( ) . inputs_and_output ;
2714- let b = b. skip_binder ( ) . inputs_and_output ;
2715- if a. len ( ) != b. len ( ) {
2716- return false ;
2717- }
2718- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2719- }
2720- // FIXME(compiler-errors): This needs to be generalized more
2721- _ => a == b,
2710+ | ( ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2711+ | ( _, ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => Ok ( a) ,
2712+ ( ty:: Infer ( _) , _) | ( _, ty:: Infer ( _) ) => Err ( TypeError :: Mismatch ) ,
2713+ _ => relate:: super_relate_tys ( self , a, b) ,
27222714 }
27232715 }
2716+
2717+ fn regions (
2718+ & mut self ,
2719+ a : ty:: Region < ' tcx > ,
2720+ b : ty:: Region < ' tcx > ,
2721+ ) -> RelateResult < ' tcx , ty:: Region < ' tcx > > {
2722+ if ( a. is_var ( ) && b. is_free_or_static ( ) ) || ( b. is_var ( ) && a. is_free_or_static ( ) ) || a == b
2723+ {
2724+ Ok ( a)
2725+ } else {
2726+ Err ( TypeError :: Mismatch )
2727+ }
2728+ }
2729+
2730+ fn binders < T > (
2731+ & mut self ,
2732+ a : ty:: Binder < ' tcx , T > ,
2733+ b : ty:: Binder < ' tcx , T > ,
2734+ ) -> relate:: RelateResult < ' tcx , ty:: Binder < ' tcx , T > >
2735+ where
2736+ T : relate:: Relate < ' tcx > ,
2737+ {
2738+ Ok ( ty:: Binder :: dummy ( self . relate ( a. skip_binder ( ) , b. skip_binder ( ) ) ?) )
2739+ }
2740+
2741+ fn consts (
2742+ & mut self ,
2743+ a : ty:: Const < ' tcx > ,
2744+ _b : ty:: Const < ' tcx > ,
2745+ ) -> relate:: RelateResult < ' tcx , ty:: Const < ' tcx > > {
2746+ // FIXME(compiler-errors): This could at least do some first-order
2747+ // relation
2748+ Ok ( a)
2749+ }
27242750}
27252751
27262752impl < ' a , ' tcx > InferCtxt < ' a , ' tcx > {
0 commit comments