@@ -21,6 +21,7 @@ use rustc_hir::def_id::DefId;
2121use rustc_hir:: { ExprKind , Node , QPath } ;
2222use rustc_index:: vec:: IndexVec ;
2323use rustc_infer:: infer:: error_reporting:: { FailureCode , ObligationCauseExt } ;
24+ use rustc_infer:: infer:: type_variable:: { TypeVariableOrigin , TypeVariableOriginKind } ;
2425use rustc_infer:: infer:: InferOk ;
2526use rustc_infer:: infer:: TypeTrace ;
2627use rustc_middle:: ty:: adjustment:: AllowTwoPhase ;
@@ -29,7 +30,9 @@ use rustc_middle::ty::{self, DefIdTree, IsSuggestable, Ty};
2930use rustc_session:: Session ;
3031use rustc_span:: symbol:: Ident ;
3132use rustc_span:: { self , Span } ;
32- use rustc_trait_selection:: traits:: { self , ObligationCauseCode , StatementAsExpression } ;
33+ use rustc_trait_selection:: traits:: {
34+ self , ObligationCauseCode , SelectionContext , StatementAsExpression ,
35+ } ;
3336
3437use std:: iter;
3538use std:: slice;
@@ -393,41 +396,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
393396 }
394397
395398 if !call_appears_satisfied {
396- // Next, let's construct the error
397- let ( error_span, full_call_span, ctor_of) = match & call_expr. kind {
398- hir:: ExprKind :: Call (
399- hir:: Expr {
400- span,
401- kind :
402- hir:: ExprKind :: Path ( hir:: QPath :: Resolved (
403- _,
404- hir:: Path { res : Res :: Def ( DefKind :: Ctor ( of, _) , _) , .. } ,
405- ) ) ,
406- ..
407- } ,
408- _,
409- ) => ( call_span, * span, Some ( of) ) ,
410- hir:: ExprKind :: Call ( hir:: Expr { span, .. } , _) => ( call_span, * span, None ) ,
411- hir:: ExprKind :: MethodCall ( path_segment, _, span) => {
412- let ident_span = path_segment. ident . span ;
413- let ident_span = if let Some ( args) = path_segment. args {
414- ident_span. with_hi ( args. span_ext . hi ( ) )
415- } else {
416- ident_span
417- } ;
418- (
419- * span, ident_span, None , // methods are never ctors
420- )
421- }
422- k => span_bug ! ( call_span, "checking argument types on a non-call: `{:?}`" , k) ,
423- } ;
424- let args_span = error_span. trim_start ( full_call_span) . unwrap_or ( error_span) ;
425- let call_name = match ctor_of {
426- Some ( CtorOf :: Struct ) => "struct" ,
427- Some ( CtorOf :: Variant ) => "enum variant" ,
428- None => "function" ,
429- } ;
430-
431399 let compatibility_diagonal = IndexVec :: from_raw ( compatibility_diagonal) ;
432400 let provided_args = IndexVec :: from_iter ( provided_args. iter ( ) . take ( if c_variadic {
433401 minimum_input_count
@@ -451,13 +419,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
451419 compatibility_diagonal,
452420 formal_and_expected_inputs,
453421 provided_args,
454- full_call_span,
455- error_span,
456- args_span,
457- call_name,
458422 c_variadic,
459423 err_code,
460424 fn_def_id,
425+ call_span,
461426 call_expr,
462427 ) ;
463428 }
@@ -468,15 +433,47 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
468433 compatibility_diagonal : IndexVec < ProvidedIdx , Compatibility < ' tcx > > ,
469434 formal_and_expected_inputs : IndexVec < ExpectedIdx , ( Ty < ' tcx > , Ty < ' tcx > ) > ,
470435 provided_args : IndexVec < ProvidedIdx , & ' tcx hir:: Expr < ' tcx > > ,
471- full_call_span : Span ,
472- error_span : Span ,
473- args_span : Span ,
474- call_name : & str ,
475436 c_variadic : bool ,
476437 err_code : & str ,
477438 fn_def_id : Option < DefId > ,
439+ call_span : Span ,
478440 call_expr : & hir:: Expr < ' tcx > ,
479441 ) {
442+ // Next, let's construct the error
443+ let ( error_span, full_call_span, ctor_of) = match & call_expr. kind {
444+ hir:: ExprKind :: Call (
445+ hir:: Expr {
446+ span,
447+ kind :
448+ hir:: ExprKind :: Path ( hir:: QPath :: Resolved (
449+ _,
450+ hir:: Path { res : Res :: Def ( DefKind :: Ctor ( of, _) , _) , .. } ,
451+ ) ) ,
452+ ..
453+ } ,
454+ _,
455+ ) => ( call_span, * span, Some ( of) ) ,
456+ hir:: ExprKind :: Call ( hir:: Expr { span, .. } , _) => ( call_span, * span, None ) ,
457+ hir:: ExprKind :: MethodCall ( path_segment, _, span) => {
458+ let ident_span = path_segment. ident . span ;
459+ let ident_span = if let Some ( args) = path_segment. args {
460+ ident_span. with_hi ( args. span_ext . hi ( ) )
461+ } else {
462+ ident_span
463+ } ;
464+ (
465+ * span, ident_span, None , // methods are never ctors
466+ )
467+ }
468+ k => span_bug ! ( call_span, "checking argument types on a non-call: `{:?}`" , k) ,
469+ } ;
470+ let args_span = error_span. trim_start ( full_call_span) . unwrap_or ( error_span) ;
471+ let call_name = match ctor_of {
472+ Some ( CtorOf :: Struct ) => "struct" ,
473+ Some ( CtorOf :: Variant ) => "enum variant" ,
474+ None => "function" ,
475+ } ;
476+
480477 // Don't print if it has error types or is just plain `_`
481478 fn has_error_or_infer < ' tcx > ( tys : impl IntoIterator < Item = Ty < ' tcx > > ) -> bool {
482479 tys. into_iter ( ) . any ( |ty| ty. references_error ( ) || ty. is_ty_var ( ) )
@@ -1818,17 +1815,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18181815 fn label_fn_like (
18191816 & self ,
18201817 err : & mut rustc_errors:: DiagnosticBuilder < ' tcx , rustc_errors:: ErrorGuaranteed > ,
1821- def_id : Option < DefId > ,
1818+ callable_def_id : Option < DefId > ,
18221819 callee_ty : Option < Ty < ' tcx > > ,
18231820 ) {
1824- let Some ( mut def_id) = def_id else {
1821+ let Some ( mut def_id) = callable_def_id else {
18251822 return ;
18261823 } ;
18271824
18281825 if let Some ( assoc_item) = self . tcx . opt_associated_item ( def_id)
1829- && let trait_def_id = assoc_item. trait_item_def_id . unwrap_or_else ( || self . tcx . parent ( def_id) )
1826+ // Possibly points at either impl or trait item, so try to get it
1827+ // to point to trait item, then get the parent.
1828+ // This parent might be an impl in the case of an inherent function,
1829+ // but the next check will fail.
1830+ && let maybe_trait_item_def_id = assoc_item. trait_item_def_id . unwrap_or ( def_id)
1831+ && let maybe_trait_def_id = self . tcx . parent ( maybe_trait_item_def_id)
18301832 // Just an easy way to check "trait_def_id == Fn/FnMut/FnOnce"
1831- && ty:: ClosureKind :: from_def_id ( self . tcx , trait_def_id ) . is_some ( )
1833+ && let Some ( call_kind ) = ty:: ClosureKind :: from_def_id ( self . tcx , maybe_trait_def_id )
18321834 && let Some ( callee_ty) = callee_ty
18331835 {
18341836 let callee_ty = callee_ty. peel_refs ( ) ;
@@ -1853,7 +1855,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18531855 std:: iter:: zip ( instantiated. predicates , instantiated. spans )
18541856 {
18551857 if let ty:: PredicateKind :: Trait ( pred) = predicate. kind ( ) . skip_binder ( )
1856- && pred. self_ty ( ) == callee_ty
1858+ && pred. self_ty ( ) . peel_refs ( ) == callee_ty
18571859 && ty:: ClosureKind :: from_def_id ( self . tcx , pred. def_id ( ) ) . is_some ( )
18581860 {
18591861 err. span_note ( span, "callable defined here" ) ;
@@ -1862,11 +1864,46 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18621864 }
18631865 }
18641866 }
1865- ty:: Opaque ( new_def_id, _) | ty:: Closure ( new_def_id, _) | ty:: FnDef ( new_def_id, _) => {
1867+ ty:: Opaque ( new_def_id, _)
1868+ | ty:: Closure ( new_def_id, _)
1869+ | ty:: FnDef ( new_def_id, _) => {
18661870 def_id = new_def_id;
18671871 }
18681872 _ => {
1869- return ;
1873+ // Look for a user-provided impl of a `Fn` trait, and point to it.
1874+ let new_def_id = self . probe ( |_| {
1875+ let trait_ref = ty:: TraitRef :: new (
1876+ call_kind. to_def_id ( self . tcx ) ,
1877+ self . tcx . mk_substs ( [
1878+ ty:: GenericArg :: from ( callee_ty) ,
1879+ self . next_ty_var ( TypeVariableOrigin {
1880+ kind : TypeVariableOriginKind :: MiscVariable ,
1881+ span : rustc_span:: DUMMY_SP ,
1882+ } )
1883+ . into ( ) ,
1884+ ] . into_iter ( ) ) ,
1885+ ) ;
1886+ let obligation = traits:: Obligation :: new (
1887+ traits:: ObligationCause :: dummy ( ) ,
1888+ self . param_env ,
1889+ ty:: Binder :: dummy ( ty:: TraitPredicate {
1890+ trait_ref,
1891+ constness : ty:: BoundConstness :: NotConst ,
1892+ polarity : ty:: ImplPolarity :: Positive ,
1893+ } ) ,
1894+ ) ;
1895+ match SelectionContext :: new ( & self ) . select ( & obligation) {
1896+ Ok ( Some ( traits:: ImplSource :: UserDefined ( impl_source) ) ) => {
1897+ Some ( impl_source. impl_def_id )
1898+ }
1899+ _ => None
1900+ }
1901+ } ) ;
1902+ if let Some ( new_def_id) = new_def_id {
1903+ def_id = new_def_id;
1904+ } else {
1905+ return ;
1906+ }
18701907 }
18711908 }
18721909 }
@@ -1888,8 +1925,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18881925
18891926 let def_kind = self . tcx . def_kind ( def_id) ;
18901927 err. span_note ( spans, & format ! ( "{} defined here" , def_kind. descr( def_id) ) ) ;
1891- } else if let def_kind @ ( DefKind :: Closure | DefKind :: OpaqueTy ) = self . tcx . def_kind ( def_id )
1892- {
1928+ } else {
1929+ let def_kind = self . tcx . def_kind ( def_id ) ;
18931930 err. span_note (
18941931 self . tcx . def_span ( def_id) ,
18951932 & format ! ( "{} defined here" , def_kind. descr( def_id) ) ,
0 commit comments