@@ -17,7 +17,7 @@ use rustc_errors::{
1717 ErrorGuaranteed , MultiSpan , Style , SuggestionStyle ,
1818} ;
1919use rustc_hir as hir;
20- use rustc_hir:: def:: DefKind ;
20+ use rustc_hir:: def:: { DefKind , Res } ;
2121use rustc_hir:: def_id:: DefId ;
2222use rustc_hir:: intravisit:: Visitor ;
2323use rustc_hir:: is_range_literal;
@@ -36,7 +36,7 @@ use rustc_middle::ty::{
3636 TypeSuperFoldable , TypeVisitableExt , TypeckResults ,
3737} ;
3838use rustc_span:: def_id:: LocalDefId ;
39- use rustc_span:: symbol:: { sym, Ident , Symbol } ;
39+ use rustc_span:: symbol:: { kw , sym, Ident , Symbol } ;
4040use rustc_span:: { BytePos , DesugaringKind , ExpnKind , MacroKind , Span , DUMMY_SP } ;
4141use rustc_target:: spec:: abi;
4242use std:: borrow:: Cow ;
@@ -222,6 +222,15 @@ pub trait TypeErrCtxtExt<'tcx> {
222222 param_env : ty:: ParamEnv < ' tcx > ,
223223 ) -> DiagnosticBuilder < ' tcx , ErrorGuaranteed > ;
224224
225+ fn note_conflicting_fn_args (
226+ & self ,
227+ err : & mut Diagnostic ,
228+ cause : & ObligationCauseCode < ' tcx > ,
229+ expected : Ty < ' tcx > ,
230+ found : Ty < ' tcx > ,
231+ param_env : ty:: ParamEnv < ' tcx > ,
232+ ) ;
233+
225234 fn note_conflicting_closure_bounds (
226235 & self ,
227236 cause : & ObligationCauseCode < ' tcx > ,
@@ -1034,7 +1043,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
10341043 let hir:: ExprKind :: Path ( hir:: QPath :: Resolved ( None , path) ) = expr. kind else {
10351044 return ;
10361045 } ;
1037- let hir :: def :: Res :: Local ( hir_id) = path. res else {
1046+ let Res :: Local ( hir_id) = path. res else {
10381047 return ;
10391048 } ;
10401049 let Some ( hir:: Node :: Pat ( pat) ) = self . tcx . hir ( ) . find ( hir_id) else {
@@ -1618,7 +1627,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
16181627 }
16191628 }
16201629 if let hir:: ExprKind :: Path ( hir:: QPath :: Resolved ( None , path) ) = expr. kind
1621- && let hir :: def :: Res :: Local ( hir_id) = path. res
1630+ && let Res :: Local ( hir_id) = path. res
16221631 && let Some ( hir:: Node :: Pat ( binding) ) = self . tcx . hir ( ) . find ( hir_id)
16231632 && let Some ( hir:: Node :: Local ( local) ) = self . tcx . hir ( ) . find_parent ( binding. hir_id )
16241633 && let None = local. ty
@@ -2005,6 +2014,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
20052014 let signature_kind = format ! ( "{argument_kind} signature" ) ;
20062015 err. note_expected_found ( & signature_kind, expected_str, & signature_kind, found_str) ;
20072016
2017+ self . note_conflicting_fn_args ( & mut err, cause, expected, found, param_env) ;
20082018 self . note_conflicting_closure_bounds ( cause, & mut err) ;
20092019
20102020 if let Some ( found_node) = found_node {
@@ -2014,6 +2024,151 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
20142024 err
20152025 }
20162026
2027+ fn note_conflicting_fn_args (
2028+ & self ,
2029+ err : & mut Diagnostic ,
2030+ cause : & ObligationCauseCode < ' tcx > ,
2031+ expected : Ty < ' tcx > ,
2032+ found : Ty < ' tcx > ,
2033+ param_env : ty:: ParamEnv < ' tcx > ,
2034+ ) {
2035+ let ObligationCauseCode :: FunctionArgumentObligation { arg_hir_id, .. } = cause else {
2036+ return ;
2037+ } ;
2038+ let ty:: FnPtr ( expected) = expected. kind ( ) else {
2039+ return ;
2040+ } ;
2041+ let ty:: FnPtr ( found) = found. kind ( ) else {
2042+ return ;
2043+ } ;
2044+ let Some ( Node :: Expr ( arg) ) = self . tcx . hir ( ) . find ( * arg_hir_id) else {
2045+ return ;
2046+ } ;
2047+ let hir:: ExprKind :: Path ( path) = arg. kind else {
2048+ return ;
2049+ } ;
2050+ let expected_inputs = self . tcx . instantiate_bound_regions_with_erased ( * expected) . inputs ( ) ;
2051+ let found_inputs = self . tcx . instantiate_bound_regions_with_erased ( * found) . inputs ( ) ;
2052+ let both_tys = expected_inputs. iter ( ) . copied ( ) . zip ( found_inputs. iter ( ) . copied ( ) ) ;
2053+
2054+ let arg_expr = |infcx : & InferCtxt < ' tcx > , name, expected : Ty < ' tcx > , found : Ty < ' tcx > | {
2055+ let ( expected_ty, expected_refs) = get_deref_type_and_refs ( expected) ;
2056+ let ( found_ty, found_refs) = get_deref_type_and_refs ( found) ;
2057+
2058+ if infcx. can_eq ( param_env, found_ty, expected_ty) {
2059+ if found_refs. len ( ) == expected_refs. len ( )
2060+ && found_refs. iter ( ) . eq ( expected_refs. iter ( ) )
2061+ {
2062+ name
2063+ } else if found_refs. len ( ) > expected_refs. len ( ) {
2064+ let refs = & found_refs[ ..found_refs. len ( ) - expected_refs. len ( ) ] ;
2065+ if found_refs[ ..expected_refs. len ( ) ] . iter ( ) . eq ( expected_refs. iter ( ) ) {
2066+ format ! (
2067+ "{}{name}" ,
2068+ refs. iter( )
2069+ . map( |mutbl| format!( "&{}" , mutbl. prefix_str( ) ) )
2070+ . collect:: <Vec <_>>( )
2071+ . join( "" ) ,
2072+ )
2073+ } else {
2074+ // The refs have different mutability.
2075+ format ! (
2076+ "{}*{name}" ,
2077+ refs. iter( )
2078+ . map( |mutbl| format!( "&{}" , mutbl. prefix_str( ) ) )
2079+ . collect:: <Vec <_>>( )
2080+ . join( "" ) ,
2081+ )
2082+ }
2083+ } else if expected_refs. len ( ) > found_refs. len ( ) {
2084+ format ! (
2085+ "{}{name}" ,
2086+ ( 0 ..( expected_refs. len( ) - found_refs. len( ) ) )
2087+ . map( |_| "*" )
2088+ . collect:: <Vec <_>>( )
2089+ . join( "" ) ,
2090+ )
2091+ } else {
2092+ format ! (
2093+ "{}{name}" ,
2094+ found_refs
2095+ . iter( )
2096+ . map( |mutbl| format!( "&{}" , mutbl. prefix_str( ) ) )
2097+ . chain( found_refs. iter( ) . map( |_| "*" . to_string( ) ) )
2098+ . collect:: <Vec <_>>( )
2099+ . join( "" ) ,
2100+ )
2101+ }
2102+ } else {
2103+ format ! ( "/* {found} */" )
2104+ }
2105+ } ;
2106+ let args_have_same_underlying_type = both_tys. clone ( ) . all ( |( expected, found) | {
2107+ let ( expected_ty, _) = get_deref_type_and_refs ( expected) ;
2108+ let ( found_ty, _) = get_deref_type_and_refs ( found) ;
2109+ self . can_eq ( param_env, found_ty, expected_ty)
2110+ } ) ;
2111+ let ( closure_names, call_names) : ( Vec < _ > , Vec < _ > ) = if args_have_same_underlying_type
2112+ && !expected_inputs. is_empty ( )
2113+ && expected_inputs. len ( ) == found_inputs. len ( )
2114+ && let Some ( typeck) = & self . typeck_results
2115+ && let Res :: Def ( _, fn_def_id) = typeck. qpath_res ( & path, * arg_hir_id)
2116+ {
2117+ let closure: Vec < _ > = self
2118+ . tcx
2119+ . fn_arg_names ( fn_def_id)
2120+ . iter ( )
2121+ . enumerate ( )
2122+ . map ( |( i, ident) | {
2123+ if ident. name . is_empty ( ) || ident. name == kw:: SelfLower {
2124+ format ! ( "arg{i}" )
2125+ } else {
2126+ format ! ( "{ident}" )
2127+ }
2128+ } )
2129+ . collect ( ) ;
2130+ let args = closure
2131+ . iter ( )
2132+ . zip ( both_tys)
2133+ . map ( |( name, ( expected, found) ) | {
2134+ arg_expr ( self . infcx , name. to_owned ( ) , expected, found)
2135+ } )
2136+ . collect ( ) ;
2137+ ( closure, args)
2138+ } else {
2139+ let closure_args = expected_inputs
2140+ . iter ( )
2141+ . enumerate ( )
2142+ . map ( |( i, _) | format ! ( "arg{i}" ) )
2143+ . collect :: < Vec < _ > > ( ) ;
2144+ let call_args = both_tys
2145+ . enumerate ( )
2146+ . map ( |( i, ( expected, found) ) | {
2147+ arg_expr ( self . infcx , format ! ( "arg{i}" ) , expected, found)
2148+ } )
2149+ . collect :: < Vec < _ > > ( ) ;
2150+ ( closure_args, call_args)
2151+ } ;
2152+ let closure_names: Vec < _ > = closure_names
2153+ . into_iter ( )
2154+ . zip ( expected_inputs. iter ( ) )
2155+ . map ( |( name, ty) | {
2156+ format ! (
2157+ "{name}{}" ,
2158+ if ty. has_infer_types( ) { String :: new( ) } else { format!( ": {ty}" ) }
2159+ )
2160+ } )
2161+ . collect ( ) ;
2162+ err. multipart_suggestion (
2163+ format ! ( "consider wrapping the function in a closure" ) ,
2164+ vec ! [
2165+ ( arg. span. shrink_to_lo( ) , format!( "|{}| " , closure_names. join( ", " ) ) ) ,
2166+ ( arg. span. shrink_to_hi( ) , format!( "({})" , call_names. join( ", " ) ) ) ,
2167+ ] ,
2168+ Applicability :: MaybeIncorrect ,
2169+ ) ;
2170+ }
2171+
20172172 // Add a note if there are two `Fn`-family bounds that have conflicting argument
20182173 // requirements, which will always cause a closure to have a type error.
20192174 fn note_conflicting_closure_bounds (
@@ -3634,7 +3789,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
36343789 }
36353790 }
36363791 if let hir:: ExprKind :: Path ( hir:: QPath :: Resolved ( None , path) ) = expr. kind
3637- && let hir:: Path { res : hir :: def :: Res :: Local ( hir_id) , .. } = path
3792+ && let hir:: Path { res : Res :: Local ( hir_id) , .. } = path
36383793 && let Some ( hir:: Node :: Pat ( binding) ) = self . tcx . hir ( ) . find ( * hir_id)
36393794 && let parent_hir_id = self . tcx . hir ( ) . parent_id ( binding. hir_id )
36403795 && let Some ( hir:: Node :: Local ( local) ) = self . tcx . hir ( ) . find ( parent_hir_id)
@@ -3894,7 +4049,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
38944049 ) ;
38954050
38964051 if let hir:: ExprKind :: Path ( hir:: QPath :: Resolved ( None , path) ) = expr. kind
3897- && let hir:: Path { res : hir :: def :: Res :: Local ( hir_id) , .. } = path
4052+ && let hir:: Path { res : Res :: Local ( hir_id) , .. } = path
38984053 && let Some ( hir:: Node :: Pat ( binding) ) = self . tcx . hir ( ) . find ( * hir_id)
38994054 && let Some ( parent) = self . tcx . hir ( ) . find_parent ( binding. hir_id )
39004055 {
@@ -4349,17 +4504,6 @@ fn hint_missing_borrow<'tcx>(
43494504
43504505 let args = fn_decl. inputs . iter ( ) ;
43514506
4352- fn get_deref_type_and_refs ( mut ty : Ty < ' _ > ) -> ( Ty < ' _ > , Vec < hir:: Mutability > ) {
4353- let mut refs = vec ! [ ] ;
4354-
4355- while let ty:: Ref ( _, new_ty, mutbl) = ty. kind ( ) {
4356- ty = * new_ty;
4357- refs. push ( * mutbl) ;
4358- }
4359-
4360- ( ty, refs)
4361- }
4362-
43634507 let mut to_borrow = Vec :: new ( ) ;
43644508 let mut remove_borrow = Vec :: new ( ) ;
43654509
@@ -4519,7 +4663,7 @@ impl<'a, 'hir> hir::intravisit::Visitor<'hir> for ReplaceImplTraitVisitor<'a> {
45194663 fn visit_ty ( & mut self , t : & ' hir hir:: Ty < ' hir > ) {
45204664 if let hir:: TyKind :: Path ( hir:: QPath :: Resolved (
45214665 None ,
4522- hir:: Path { res : hir :: def :: Res :: Def ( _, segment_did) , .. } ,
4666+ hir:: Path { res : Res :: Def ( _, segment_did) , .. } ,
45234667 ) ) = t. kind
45244668 {
45254669 if self . param_did == * segment_did {
@@ -4652,3 +4796,14 @@ pub fn suggest_desugaring_async_fn_to_impl_future_in_trait<'tcx>(
46524796
46534797 Some ( sugg)
46544798}
4799+
4800+ fn get_deref_type_and_refs ( mut ty : Ty < ' _ > ) -> ( Ty < ' _ > , Vec < hir:: Mutability > ) {
4801+ let mut refs = vec ! [ ] ;
4802+
4803+ while let ty:: Ref ( _, new_ty, mutbl) = ty. kind ( ) {
4804+ ty = * new_ty;
4805+ refs. push ( * mutbl) ;
4806+ }
4807+
4808+ ( ty, refs)
4809+ }
0 commit comments