@@ -25,8 +25,11 @@ pub(crate) mod unify;
2525use std:: { convert:: identity, iter, ops:: Index } ;
2626
2727use chalk_ir:: {
28- cast:: Cast , fold:: TypeFoldable , interner:: HasInterner , DebruijnIndex , Mutability , Safety ,
29- Scalar , TyKind , TypeFlags , Variance ,
28+ cast:: Cast ,
29+ fold:: TypeFoldable ,
30+ interner:: HasInterner ,
31+ visit:: { TypeSuperVisitable , TypeVisitable , TypeVisitor } ,
32+ DebruijnIndex , Mutability , Safety , Scalar , TyKind , TypeFlags , Variance ,
3033} ;
3134use either:: Either ;
3235use hir_def:: {
@@ -53,14 +56,14 @@ use triomphe::Arc;
5356use crate :: {
5457 db:: HirDatabase ,
5558 fold_tys,
56- infer:: coerce:: CoerceMany ,
59+ infer:: { coerce:: CoerceMany , unify :: InferenceTable } ,
5760 lower:: ImplTraitLoweringMode ,
5861 static_lifetime, to_assoc_type_id,
5962 traits:: FnTrait ,
6063 utils:: { InTypeConstIdMetadata , UnevaluatedConstEvaluatorFolder } ,
6164 AliasEq , AliasTy , Binders , ClosureId , Const , DomainGoal , GenericArg , Goal , ImplTraitId ,
62- InEnvironment , Interner , Lifetime , ProjectionTy , RpitId , Substitution , TraitEnvironment ,
63- TraitRef , Ty , TyBuilder , TyExt ,
65+ ImplTraitIdx , InEnvironment , Interner , Lifetime , OpaqueTyId , ProjectionTy , Substitution ,
66+ TraitEnvironment , Ty , TyBuilder , TyExt ,
6467} ;
6568
6669// This lint has a false positive here. See the link below for details.
@@ -422,7 +425,7 @@ pub struct InferenceResult {
422425 /// unresolved or missing subpatterns or subpatterns of mismatched types.
423426 pub type_of_pat : ArenaMap < PatId , Ty > ,
424427 pub type_of_binding : ArenaMap < BindingId , Ty > ,
425- pub type_of_rpit : ArenaMap < RpitId , Ty > ,
428+ pub type_of_rpit : ArenaMap < ImplTraitIdx , Ty > ,
426429 /// Type of the result of `.into_iter()` on the for. `ExprId` is the one of the whole for loop.
427430 pub type_of_for_iterator : FxHashMap < ExprId , Ty > ,
428431 type_mismatches : FxHashMap < ExprOrPatId , TypeMismatch > ,
@@ -752,7 +755,12 @@ impl<'a> InferenceContext<'a> {
752755 }
753756
754757 fn collect_const ( & mut self , data : & ConstData ) {
755- self . return_ty = self . make_ty ( & data. type_ref ) ;
758+ let return_ty = self . make_ty ( & data. type_ref ) ;
759+
760+ // Constants might be associated items that define ATPITs.
761+ self . insert_atpit_coercion_table ( iter:: once ( & return_ty) ) ;
762+
763+ self . return_ty = return_ty;
756764 }
757765
758766 fn collect_static ( & mut self , data : & StaticData ) {
@@ -785,11 +793,13 @@ impl<'a> InferenceContext<'a> {
785793 self . write_binding_ty ( self_param, ty) ;
786794 }
787795 }
796+ let mut params_and_ret_tys = Vec :: new ( ) ;
788797 for ( ty, pat) in param_tys. zip ( & * self . body . params ) {
789798 let ty = self . insert_type_vars ( ty) ;
790799 let ty = self . normalize_associated_types_in ( ty) ;
791800
792801 self . infer_top_pat ( * pat, & ty) ;
802+ params_and_ret_tys. push ( ty) ;
793803 }
794804 let return_ty = & * data. ret_type ;
795805
@@ -801,8 +811,11 @@ impl<'a> InferenceContext<'a> {
801811 let return_ty = if let Some ( rpits) = self . db . return_type_impl_traits ( func) {
802812 // RPIT opaque types use substitution of their parent function.
803813 let fn_placeholders = TyBuilder :: placeholder_subst ( self . db , func) ;
804- let result =
805- self . insert_inference_vars_for_rpit ( return_ty, rpits. clone ( ) , fn_placeholders) ;
814+ let result = self . insert_inference_vars_for_impl_trait (
815+ return_ty,
816+ rpits. clone ( ) ,
817+ fn_placeholders,
818+ ) ;
806819 let rpits = rpits. skip_binders ( ) ;
807820 for ( id, _) in rpits. impl_traits . iter ( ) {
808821 if let Entry :: Vacant ( e) = self . result . type_of_rpit . entry ( id) {
@@ -817,13 +830,19 @@ impl<'a> InferenceContext<'a> {
817830
818831 self . return_ty = self . normalize_associated_types_in ( return_ty) ;
819832 self . return_coercion = Some ( CoerceMany :: new ( self . return_ty . clone ( ) ) ) ;
833+
834+ // Functions might be associated items that define ATPITs.
835+ // To define an ATPITs, that ATPIT must appear in the function's signitures.
836+ // So, it suffices to check for params and return types.
837+ params_and_ret_tys. push ( self . return_ty . clone ( ) ) ;
838+ self . insert_atpit_coercion_table ( params_and_ret_tys. iter ( ) ) ;
820839 }
821840
822- fn insert_inference_vars_for_rpit < T > (
841+ fn insert_inference_vars_for_impl_trait < T > (
823842 & mut self ,
824843 t : T ,
825- rpits : Arc < chalk_ir:: Binders < crate :: ReturnTypeImplTraits > > ,
826- fn_placeholders : Substitution ,
844+ rpits : Arc < chalk_ir:: Binders < crate :: ImplTraits > > ,
845+ placeholders : Substitution ,
827846 ) -> T
828847 where
829848 T : crate :: HasInterner < Interner = Interner > + crate :: TypeFoldable < Interner > ,
@@ -837,22 +856,22 @@ impl<'a> InferenceContext<'a> {
837856 } ;
838857 let idx = match self . db . lookup_intern_impl_trait_id ( opaque_ty_id. into ( ) ) {
839858 ImplTraitId :: ReturnTypeImplTrait ( _, idx) => idx,
859+ ImplTraitId :: AssociatedTypeImplTrait ( _, idx) => idx,
840860 _ => unreachable ! ( ) ,
841861 } ;
842862 let bounds =
843863 ( * rpits) . map_ref ( |rpits| rpits. impl_traits [ idx] . bounds . map_ref ( |it| it. iter ( ) ) ) ;
844864 let var = self . table . new_type_var ( ) ;
845865 let var_subst = Substitution :: from1 ( Interner , var. clone ( ) ) ;
846866 for bound in bounds {
847- let predicate =
848- bound. map ( |it| it. cloned ( ) ) . substitute ( Interner , & fn_placeholders) ;
867+ let predicate = bound. map ( |it| it. cloned ( ) ) . substitute ( Interner , & placeholders) ;
849868 let ( var_predicate, binders) =
850869 predicate. substitute ( Interner , & var_subst) . into_value_and_skipped_binders ( ) ;
851870 always ! ( binders. is_empty( Interner ) ) ; // quantified where clauses not yet handled
852- let var_predicate = self . insert_inference_vars_for_rpit (
871+ let var_predicate = self . insert_inference_vars_for_impl_trait (
853872 var_predicate,
854873 rpits. clone ( ) ,
855- fn_placeholders . clone ( ) ,
874+ placeholders . clone ( ) ,
856875 ) ;
857876 self . push_obligation ( var_predicate. cast ( Interner ) ) ;
858877 }
@@ -863,6 +882,106 @@ impl<'a> InferenceContext<'a> {
863882 )
864883 }
865884
885+ /// The coercion of a non-inference var into an opaque type should fail,
886+ /// but not in the defining sites of the ATPITs.
887+ /// In such cases, we insert an proxy inference var for each ATPIT,
888+ /// and coerce into it instead of ATPIT itself.
889+ ///
890+ /// The inference var stretagy is effective because;
891+ ///
892+ /// - It can still unify types that coerced into ATPIT
893+ /// - We are pushing `impl Trait` bounds into it
894+ ///
895+ /// This function inserts a map that maps the opaque type to that proxy inference var.
896+ fn insert_atpit_coercion_table < ' b > ( & mut self , tys : impl Iterator < Item = & ' b Ty > ) {
897+ struct OpaqueTyCollector < ' a , ' b > {
898+ table : & ' b mut InferenceTable < ' a > ,
899+ opaque_tys : FxHashMap < OpaqueTyId , Ty > ,
900+ }
901+
902+ impl < ' a , ' b > TypeVisitor < Interner > for OpaqueTyCollector < ' a , ' b > {
903+ type BreakTy = ( ) ;
904+
905+ fn as_dyn ( & mut self ) -> & mut dyn TypeVisitor < Interner , BreakTy = Self :: BreakTy > {
906+ self
907+ }
908+
909+ fn interner ( & self ) -> Interner {
910+ Interner
911+ }
912+
913+ fn visit_ty (
914+ & mut self ,
915+ ty : & chalk_ir:: Ty < Interner > ,
916+ outer_binder : DebruijnIndex ,
917+ ) -> std:: ops:: ControlFlow < Self :: BreakTy > {
918+ let ty = self . table . resolve_ty_shallow ( ty) ;
919+
920+ if let TyKind :: OpaqueType ( id, _) = ty. kind ( Interner ) {
921+ self . opaque_tys . insert ( * id, ty. clone ( ) ) ;
922+ }
923+
924+ ty. super_visit_with ( self , outer_binder)
925+ }
926+ }
927+
928+ // Early return if this is not happening inside the impl block
929+ let impl_id = if let Some ( impl_id) = self . resolver . impl_def ( ) {
930+ impl_id
931+ } else {
932+ return ;
933+ } ;
934+
935+ let assoc_tys: FxHashSet < _ > = self
936+ . db
937+ . impl_data ( impl_id)
938+ . items
939+ . iter ( )
940+ . filter_map ( |item| match item {
941+ AssocItemId :: TypeAliasId ( alias) => Some ( * alias) ,
942+ _ => None ,
943+ } )
944+ . collect ( ) ;
945+ if assoc_tys. is_empty ( ) {
946+ return ;
947+ }
948+
949+ let mut collector =
950+ OpaqueTyCollector { table : & mut self . table , opaque_tys : FxHashMap :: default ( ) } ;
951+ for ty in tys {
952+ ty. visit_with ( collector. as_dyn ( ) , DebruijnIndex :: INNERMOST ) ;
953+ }
954+ let atpit_coercion_table: FxHashMap < _ , _ > = collector
955+ . opaque_tys
956+ . into_iter ( )
957+ . filter_map ( |( opaque_ty_id, ty) | {
958+ if let ImplTraitId :: AssociatedTypeImplTrait ( alias_id, _) =
959+ self . db . lookup_intern_impl_trait_id ( opaque_ty_id. into ( ) )
960+ {
961+ if assoc_tys. contains ( & alias_id) {
962+ let atpits = self
963+ . db
964+ . type_alias_impl_traits ( alias_id)
965+ . expect ( "Marked as ATPIT but no impl traits!" ) ;
966+ let alias_placeholders = TyBuilder :: placeholder_subst ( self . db , alias_id) ;
967+ let ty = self . insert_inference_vars_for_impl_trait (
968+ ty,
969+ atpits,
970+ alias_placeholders,
971+ ) ;
972+ return Some ( ( opaque_ty_id, ty) ) ;
973+ }
974+ }
975+
976+ None
977+ } )
978+ . collect ( ) ;
979+
980+ if !atpit_coercion_table. is_empty ( ) {
981+ self . table . atpit_coercion_table = Some ( atpit_coercion_table) ;
982+ }
983+ }
984+
866985 fn infer_body ( & mut self ) {
867986 match self . return_coercion {
868987 Some ( _) => self . infer_return ( self . body . body_expr ) ,
0 commit comments