@@ -11,7 +11,7 @@ use std::ops::ControlFlow;
1111use hir:: LangItem ;
1212use hir:: def_id:: DefId ;
1313use rustc_data_structures:: fx:: { FxHashSet , FxIndexSet } ;
14- use rustc_hir as hir;
14+ use rustc_hir:: { self as hir, CoroutineDesugaring , CoroutineKind } ;
1515use rustc_infer:: traits:: { Obligation , PolyTraitObligation , SelectionError } ;
1616use rustc_middle:: ty:: fast_reject:: DeepRejectCtxt ;
1717use rustc_middle:: ty:: { self , Ty , TypeVisitableExt , TypingMode } ;
@@ -125,11 +125,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
125125 self . assemble_async_iterator_candidates ( obligation, & mut candidates) ;
126126 } else if tcx. is_lang_item ( def_id, LangItem :: AsyncFnKindHelper ) {
127127 self . assemble_async_fn_kind_helper_candidates ( obligation, & mut candidates) ;
128+ } else if tcx. is_lang_item ( def_id, LangItem :: AsyncFn )
129+ || tcx. is_lang_item ( def_id, LangItem :: AsyncFnOnce )
130+ || tcx. is_lang_item ( def_id, LangItem :: AsyncFnMut )
131+ {
132+ self . assemble_async_closure_candidates ( obligation, & mut candidates) ;
128133 }
129134
130135 // FIXME: Put these into `else if` blocks above, since they're built-in.
131136 self . assemble_closure_candidates ( obligation, & mut candidates) ;
132- self . assemble_async_closure_candidates ( obligation, & mut candidates) ;
133137 self . assemble_fn_pointer_candidates ( obligation, & mut candidates) ;
134138
135139 self . assemble_candidates_from_impls ( obligation, & mut candidates) ;
@@ -425,6 +429,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
425429 }
426430 }
427431
432+ #[ instrument( level = "debug" , skip( self , candidates) ) ]
428433 fn assemble_async_closure_candidates (
429434 & mut self ,
430435 obligation : & PolyTraitObligation < ' tcx > ,
@@ -436,15 +441,30 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
436441 return ;
437442 } ;
438443
444+ debug ! ( "self_ty = {:?}" , obligation. self_ty( ) . skip_binder( ) . kind( ) ) ;
439445 match * obligation. self_ty ( ) . skip_binder ( ) . kind ( ) {
440- ty:: CoroutineClosure ( _ , args) => {
446+ ty:: CoroutineClosure ( def_id , args) => {
441447 if let Some ( closure_kind) =
442448 args. as_coroutine_closure ( ) . kind_ty ( ) . to_opt_closure_kind ( )
443449 && !closure_kind. extends ( goal_kind)
444450 {
445451 return ;
446452 }
447- candidates. vec . push ( AsyncClosureCandidate ) ;
453+
454+ // Make sure this is actually an async closure.
455+ let Some ( coroutine_kind) =
456+ self . tcx ( ) . coroutine_kind ( self . tcx ( ) . coroutine_for_closure ( def_id) )
457+ else {
458+ bug ! ( "coroutine with no kind" ) ;
459+ } ;
460+
461+ debug ! ( ?coroutine_kind) ;
462+ match coroutine_kind {
463+ CoroutineKind :: Desugared ( CoroutineDesugaring :: Async , _) => {
464+ candidates. vec . push ( AsyncClosureCandidate ) ;
465+ }
466+ _ => ( ) ,
467+ }
448468 }
449469 // Closures and fn pointers implement `AsyncFn*` if their return types
450470 // implement `Future`, which is checked later.
0 commit comments