@@ -569,6 +569,18 @@ pub(super) fn thir_abstract_const<'tcx>(
569569 }
570570}
571571
572+ /// Tries to unify two abstract constants using structural equality.
573+ #[ instrument( skip( tcx) , level = "debug" ) ]
574+ pub ( super ) fn try_unify < ' tcx > (
575+ tcx : TyCtxt < ' tcx > ,
576+ a : AbstractConst < ' tcx > ,
577+ b : AbstractConst < ' tcx > ,
578+ param_env : ty:: ParamEnv < ' tcx > ,
579+ ) -> bool {
580+ let const_unify_ctxt = ConstUnifyCtxt :: new ( tcx, param_env) ;
581+ const_unify_ctxt. try_unify_inner ( a, b)
582+ }
583+
572584pub ( super ) fn try_unify_abstract_consts < ' tcx > (
573585 tcx : TyCtxt < ' tcx > ,
574586 ( a, b) : ( ty:: Unevaluated < ' tcx , ( ) > , ty:: Unevaluated < ' tcx , ( ) > ) ,
@@ -622,115 +634,119 @@ where
622634 recurse ( tcx, ct, & mut f)
623635}
624636
625- // Substitutes generics repeatedly to allow AbstractConsts to unify where a
626- // ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
627- // Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
628- #[ inline]
629- #[ instrument( skip( tcx) , level = "debug" ) ]
630- fn try_replace_substs_in_root < ' tcx > (
637+ pub ( super ) struct ConstUnifyCtxt < ' tcx > {
631638 tcx : TyCtxt < ' tcx > ,
632- mut abstr_const : AbstractConst < ' tcx > ,
633- ) -> Option < AbstractConst < ' tcx > > {
634- while let Node :: Leaf ( ct) = abstr_const. root ( tcx) {
635- match AbstractConst :: from_const ( tcx, ct) {
636- Ok ( Some ( act) ) => abstr_const = act,
637- Ok ( None ) => break ,
638- Err ( _) => return None ,
639- }
640- }
641-
642- Some ( abstr_const)
639+ param_env : ty:: ParamEnv < ' tcx > ,
643640}
644641
645- /// Tries to unify two abstract constants using structural equality.
646- #[ instrument( skip( tcx) , level = "debug" ) ]
647- pub ( super ) fn try_unify < ' tcx > (
648- tcx : TyCtxt < ' tcx > ,
649- a : AbstractConst < ' tcx > ,
650- b : AbstractConst < ' tcx > ,
651- param_env : ty:: ParamEnv < ' tcx > ,
652- ) -> bool {
653- let a = match try_replace_substs_in_root ( tcx, a) {
654- Some ( a) => a,
655- None => {
656- return true ;
642+ impl < ' tcx > ConstUnifyCtxt < ' tcx > {
643+ pub ( super ) fn new ( tcx : TyCtxt < ' tcx > , param_env : ty:: ParamEnv < ' tcx > ) -> Self {
644+ ConstUnifyCtxt { tcx, param_env }
645+ }
646+
647+ // Substitutes generics repeatedly to allow AbstractConsts to unify where a
648+ // ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
649+ // Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
650+ #[ inline]
651+ #[ instrument( skip( self ) , level = "debug" ) ]
652+ pub ( super ) fn try_replace_substs_in_root (
653+ & self ,
654+ mut abstr_const : AbstractConst < ' tcx > ,
655+ ) -> Option < AbstractConst < ' tcx > > {
656+ while let Node :: Leaf ( ct) = abstr_const. root ( self . tcx ) {
657+ match AbstractConst :: from_const ( self . tcx , ct) {
658+ Ok ( Some ( act) ) => abstr_const = act,
659+ Ok ( None ) => break ,
660+ Err ( _) => return None ,
661+ }
657662 }
658- } ;
659663
660- let b = match try_replace_substs_in_root ( tcx, b) {
661- Some ( b) => b,
662- None => {
664+ Some ( abstr_const)
665+ }
666+
667+ /// Tries to unify two abstract constants using structural equality.
668+ #[ instrument( skip( self ) , level = "debug" ) ]
669+ fn try_unify_inner ( & self , a : AbstractConst < ' tcx > , b : AbstractConst < ' tcx > ) -> bool {
670+ let a = if let Some ( a) = self . try_replace_substs_in_root ( a) {
671+ a
672+ } else {
663673 return true ;
664- }
665- } ;
674+ } ;
666675
667- let a_root = a. root ( tcx) ;
668- let b_root = b. root ( tcx) ;
669- debug ! ( ?a_root, ?b_root) ;
676+ let b = if let Some ( b) = self . try_replace_substs_in_root ( b) {
677+ b
678+ } else {
679+ return true ;
680+ } ;
670681
671- match ( a_root, b_root) {
672- ( Node :: Leaf ( a_ct) , Node :: Leaf ( b_ct) ) => {
673- let a_ct = a_ct. eval ( tcx, param_env) ;
674- debug ! ( "a_ct evaluated: {:?}" , a_ct) ;
675- let b_ct = b_ct. eval ( tcx, param_env) ;
676- debug ! ( "b_ct evaluated: {:?}" , b_ct) ;
682+ let a_root = a. root ( self . tcx ) ;
683+ let b_root = b. root ( self . tcx ) ;
684+ debug ! ( ?a_root, ?b_root) ;
677685
678- if a_ct. ty ( ) != b_ct. ty ( ) {
679- return false ;
680- }
686+ match ( a_root, b_root) {
687+ ( Node :: Leaf ( a_ct) , Node :: Leaf ( b_ct) ) => {
688+ let a_ct = a_ct. eval ( self . tcx , self . param_env ) ;
689+ debug ! ( "a_ct evaluated: {:?}" , a_ct) ;
690+ let b_ct = b_ct. eval ( self . tcx , self . param_env ) ;
691+ debug ! ( "b_ct evaluated: {:?}" , b_ct) ;
681692
682- match ( a_ct. val ( ) , b_ct. val ( ) ) {
683- // We can just unify errors with everything to reduce the amount of
684- // emitted errors here.
685- ( ty:: ConstKind :: Error ( _) , _) | ( _, ty:: ConstKind :: Error ( _) ) => true ,
686- ( ty:: ConstKind :: Param ( a_param) , ty:: ConstKind :: Param ( b_param) ) => {
687- a_param == b_param
693+ if a_ct. ty ( ) != b_ct. ty ( ) {
694+ return false ;
688695 }
689- ( ty:: ConstKind :: Value ( a_val) , ty:: ConstKind :: Value ( b_val) ) => a_val == b_val,
690- // If we have `fn a<const N: usize>() -> [u8; N + 1]` and `fn b<const M: usize>() -> [u8; 1 + M]`
691- // we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This
692- // means that we only allow inference variables if they are equal.
693- ( ty:: ConstKind :: Infer ( a_val) , ty:: ConstKind :: Infer ( b_val) ) => a_val == b_val,
694- // We expand generic anonymous constants at the start of this function, so this
695- // branch should only be taking when dealing with associated constants, at
696- // which point directly comparing them seems like the desired behavior.
697- //
698- // FIXME(generic_const_exprs): This isn't actually the case.
699- // We also take this branch for concrete anonymous constants and
700- // expand generic anonymous constants with concrete substs.
701- ( ty:: ConstKind :: Unevaluated ( a_uv) , ty:: ConstKind :: Unevaluated ( b_uv) ) => {
702- a_uv == b_uv
696+
697+ match ( a_ct. val ( ) , b_ct. val ( ) ) {
698+ // We can just unify errors with everything to reduce the amount of
699+ // emitted errors here.
700+ ( ty:: ConstKind :: Error ( _) , _) | ( _, ty:: ConstKind :: Error ( _) ) => true ,
701+ ( ty:: ConstKind :: Param ( a_param) , ty:: ConstKind :: Param ( b_param) ) => {
702+ a_param == b_param
703+ }
704+ ( ty:: ConstKind :: Value ( a_val) , ty:: ConstKind :: Value ( b_val) ) => a_val == b_val,
705+ // If we have `fn a<const N: usize>() -> [u8; N + 1]` and `fn b<const M: usize>() -> [u8; 1 + M]`
706+ // we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This
707+ // means that we only allow inference variables if they are equal.
708+ ( ty:: ConstKind :: Infer ( a_val) , ty:: ConstKind :: Infer ( b_val) ) => a_val == b_val,
709+ // We expand generic anonymous constants at the start of this function, so this
710+ // branch should only be taking when dealing with associated constants, at
711+ // which point directly comparing them seems like the desired behavior.
712+ //
713+ // FIXME(generic_const_exprs): This isn't actually the case.
714+ // We also take this branch for concrete anonymous constants and
715+ // expand generic anonymous constants with concrete substs.
716+ ( ty:: ConstKind :: Unevaluated ( a_uv) , ty:: ConstKind :: Unevaluated ( b_uv) ) => {
717+ a_uv == b_uv
718+ }
719+ // FIXME(generic_const_exprs): We may want to either actually try
720+ // to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like
721+ // this, for now we just return false here.
722+ _ => false ,
703723 }
704- // FIXME(generic_const_exprs): We may want to either actually try
705- // to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like
706- // this, for now we just return false here.
707- _ => false ,
708724 }
725+ ( Node :: Binop ( a_op, al, ar) , Node :: Binop ( b_op, bl, br) ) if a_op == b_op => {
726+ self . try_unify_inner ( a. subtree ( al) , b. subtree ( bl) )
727+ && self . try_unify_inner ( a. subtree ( ar) , b. subtree ( br) )
728+ }
729+ ( Node :: UnaryOp ( a_op, av) , Node :: UnaryOp ( b_op, bv) ) if a_op == b_op => {
730+ self . try_unify_inner ( a. subtree ( av) , b. subtree ( bv) )
731+ }
732+ ( Node :: FunctionCall ( a_f, a_args) , Node :: FunctionCall ( b_f, b_args) )
733+ if a_args. len ( ) == b_args. len ( ) =>
734+ {
735+ self . try_unify_inner ( a. subtree ( a_f) , b. subtree ( b_f) )
736+ && iter:: zip ( a_args, b_args)
737+ . all ( |( & an, & bn) | self . try_unify_inner ( a. subtree ( an) , b. subtree ( bn) ) )
738+ }
739+ ( Node :: Cast ( a_kind, a_operand, a_ty) , Node :: Cast ( b_kind, b_operand, b_ty) )
740+ if ( a_ty == b_ty) && ( a_kind == b_kind) =>
741+ {
742+ self . try_unify_inner ( a. subtree ( a_operand) , b. subtree ( b_operand) )
743+ }
744+ // use this over `_ => false` to make adding variants to `Node` less error prone
745+ ( Node :: Cast ( ..) , _)
746+ | ( Node :: FunctionCall ( ..) , _)
747+ | ( Node :: UnaryOp ( ..) , _)
748+ | ( Node :: Binop ( ..) , _)
749+ | ( Node :: Leaf ( ..) , _) => false ,
709750 }
710- ( Node :: Binop ( a_op, al, ar) , Node :: Binop ( b_op, bl, br) ) if a_op == b_op => {
711- try_unify ( tcx, a. subtree ( al) , b. subtree ( bl) , param_env)
712- && try_unify ( tcx, a. subtree ( ar) , b. subtree ( br) , param_env)
713- }
714- ( Node :: UnaryOp ( a_op, av) , Node :: UnaryOp ( b_op, bv) ) if a_op == b_op => {
715- try_unify ( tcx, a. subtree ( av) , b. subtree ( bv) , param_env)
716- }
717- ( Node :: FunctionCall ( a_f, a_args) , Node :: FunctionCall ( b_f, b_args) )
718- if a_args. len ( ) == b_args. len ( ) =>
719- {
720- try_unify ( tcx, a. subtree ( a_f) , b. subtree ( b_f) , param_env)
721- && iter:: zip ( a_args, b_args)
722- . all ( |( & an, & bn) | try_unify ( tcx, a. subtree ( an) , b. subtree ( bn) , param_env) )
723- }
724- ( Node :: Cast ( a_kind, a_operand, a_ty) , Node :: Cast ( b_kind, b_operand, b_ty) )
725- if ( a_ty == b_ty) && ( a_kind == b_kind) =>
726- {
727- try_unify ( tcx, a. subtree ( a_operand) , b. subtree ( b_operand) , param_env)
728- }
729- // use this over `_ => false` to make adding variants to `Node` less error prone
730- ( Node :: Cast ( ..) , _)
731- | ( Node :: FunctionCall ( ..) , _)
732- | ( Node :: UnaryOp ( ..) , _)
733- | ( Node :: Binop ( ..) , _)
734- | ( Node :: Leaf ( ..) , _) => false ,
735751 }
736752}
0 commit comments