@@ -4,15 +4,18 @@ use clippy_utils::ty::{implements_trait, implements_trait_with_env, is_copy};
44use clippy_utils:: { is_lint_allowed, match_def_path} ;
55use if_chain:: if_chain;
66use rustc_errors:: Applicability ;
7+ use rustc_hir:: def_id:: DefId ;
78use rustc_hir:: intravisit:: { walk_expr, walk_fn, walk_item, FnKind , Visitor } ;
89use rustc_hir:: {
9- self as hir, BlockCheckMode , BodyId , Expr , ExprKind , FnDecl , HirId , Impl , Item , ItemKind , UnsafeSource , Unsafety ,
10+ self as hir, BlockCheckMode , BodyId , Constness , Expr , ExprKind , FnDecl , HirId , Impl , Item , ItemKind , UnsafeSource ,
11+ Unsafety ,
1012} ;
1113use rustc_lint:: { LateContext , LateLintPass } ;
1214use rustc_middle:: hir:: nested_filter;
13- use rustc_middle:: ty :: subst :: GenericArg ;
15+ use rustc_middle:: traits :: Reveal ;
1416use rustc_middle:: ty:: {
15- self , BoundConstness , ImplPolarity , ParamEnv , PredicateKind , TraitPredicate , TraitRef , Ty , Visibility ,
17+ self , Binder , BoundConstness , GenericParamDefKind , ImplPolarity , ParamEnv , PredicateKind , TraitPredicate , TraitRef ,
18+ Ty , TyCtxt , Visibility ,
1619} ;
1720use rustc_session:: { declare_lint_pass, declare_tool_lint} ;
1821use rustc_span:: source_map:: Span ;
@@ -463,49 +466,16 @@ fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_r
463466 if let ty:: Adt ( adt, substs) = ty. kind( ) ;
464467 if cx. tcx. visibility( adt. did( ) ) == Visibility :: Public ;
465468 if let Some ( eq_trait_def_id) = cx. tcx. get_diagnostic_item( sym:: Eq ) ;
466- if let Some ( peq_trait_def_id) = cx. tcx. get_diagnostic_item( sym:: PartialEq ) ;
467469 if let Some ( def_id) = trait_ref. trait_def_id( ) ;
468470 if cx. tcx. is_diagnostic_item( sym:: PartialEq , def_id) ;
469- // New `ParamEnv` replacing `T: PartialEq` with `T: Eq`
470- let param_env = ParamEnv :: new(
471- cx. tcx. mk_predicates( cx. param_env. caller_bounds( ) . iter( ) . map( |p| {
472- let kind = p. kind( ) ;
473- match kind. skip_binder( ) {
474- PredicateKind :: Trait ( p)
475- if p. trait_ref. def_id == peq_trait_def_id
476- && p. trait_ref. substs. get( 0 ) == p. trait_ref. substs. get( 1 )
477- && matches!( p. trait_ref. self_ty( ) . kind( ) , ty:: Param ( _) )
478- && p. constness == BoundConstness :: NotConst
479- && p. polarity == ImplPolarity :: Positive =>
480- {
481- cx. tcx. mk_predicate( kind. rebind( PredicateKind :: Trait ( TraitPredicate {
482- trait_ref: TraitRef :: new(
483- eq_trait_def_id,
484- cx. tcx. mk_substs( [ GenericArg :: from( p. trait_ref. self_ty( ) ) ] . into_iter( ) ) ,
485- ) ,
486- constness: BoundConstness :: NotConst ,
487- polarity: ImplPolarity :: Positive ,
488- } ) ) )
489- } ,
490- _ => p,
491- }
492- } ) ) ,
493- cx. param_env. reveal( ) ,
494- cx. param_env. constness( ) ,
495- ) ;
496- if !implements_trait_with_env( cx. tcx, param_env, ty, eq_trait_def_id, substs) ;
471+ let param_env = param_env_for_derived_eq( cx. tcx, adt. did( ) , eq_trait_def_id) ;
472+ if !implements_trait_with_env( cx. tcx, param_env, ty, eq_trait_def_id, & [ ] ) ;
473+ // If all of our fields implement `Eq`, we can implement `Eq` too
474+ if adt
475+ . all_fields( )
476+ . map( |f| f. ty( cx. tcx, substs) )
477+ . all( |ty| implements_trait_with_env( cx. tcx, param_env, ty, eq_trait_def_id, & [ ] ) ) ;
497478 then {
498- // If all of our fields implement `Eq`, we can implement `Eq` too
499- for variant in adt. variants( ) {
500- for field in & variant. fields {
501- let ty = field. ty( cx. tcx, substs) ;
502-
503- if !implements_trait( cx, ty, eq_trait_def_id, substs) {
504- return ;
505- }
506- }
507- }
508-
509479 span_lint_and_sugg(
510480 cx,
511481 DERIVE_PARTIAL_EQ_WITHOUT_EQ ,
@@ -518,3 +488,41 @@ fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_r
518488 }
519489 }
520490}
491+
492+ /// Creates the `ParamEnv` used for the give type's derived `Eq` impl.
493+ fn param_env_for_derived_eq ( tcx : TyCtxt < ' _ > , did : DefId , eq_trait_id : DefId ) -> ParamEnv < ' _ > {
494+ // Initial map from generic index to param def.
495+ // Vec<(param_def, needs_eq)>
496+ let mut params = tcx
497+ . generics_of ( did)
498+ . params
499+ . iter ( )
500+ . map ( |p| ( p, matches ! ( p. kind, GenericParamDefKind :: Type { .. } ) ) )
501+ . collect :: < Vec < _ > > ( ) ;
502+
503+ let ty_predicates = tcx. predicates_of ( did) . predicates ;
504+ for ( p, _) in ty_predicates {
505+ if let PredicateKind :: Trait ( p) = p. kind ( ) . skip_binder ( )
506+ && p. trait_ref . def_id == eq_trait_id
507+ && let ty:: Param ( self_ty) = p. trait_ref . self_ty ( ) . kind ( )
508+ && p. constness == BoundConstness :: NotConst
509+ {
510+ // Flag types which already have an `Eq` bound.
511+ params[ self_ty. index as usize ] . 1 = false ;
512+ }
513+ }
514+
515+ ParamEnv :: new (
516+ tcx. mk_predicates ( ty_predicates. iter ( ) . map ( |& ( p, _) | p) . chain (
517+ params. iter ( ) . filter ( |& & ( _, needs_eq) | needs_eq) . map ( |& ( param, _) | {
518+ tcx. mk_predicate ( Binder :: dummy ( PredicateKind :: Trait ( TraitPredicate {
519+ trait_ref : TraitRef :: new ( eq_trait_id, tcx. mk_substs ( [ tcx. mk_param_from_def ( param) ] . into_iter ( ) ) ) ,
520+ constness : BoundConstness :: NotConst ,
521+ polarity : ImplPolarity :: Positive ,
522+ } ) ) )
523+ } ) ,
524+ ) ) ,
525+ Reveal :: UserFacing ,
526+ Constness :: NotConst ,
527+ )
528+ }
0 commit comments