@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337337 ) ,
338338 ExprKind :: Try ( sub_expr) => self . lower_expr_try ( e. span , sub_expr) ,
339339
340- ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
340+ ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
341341 unreachable ! ( "already handled" )
342342 }
343343
@@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
874874 /// }
875875 /// ```
876876 fn lower_expr_await ( & mut self , await_kw_span : Span , expr : & Expr ) -> hir:: ExprKind < ' hir > {
877+ let expr = self . arena . alloc ( self . lower_expr_mut ( expr) ) ;
878+ self . make_lowered_await ( await_kw_span, expr, FutureKind :: Future )
879+ }
880+
881+ /// Takes an expr that has already been lowered and generates a desugared await loop around it
882+ fn make_lowered_await (
883+ & mut self ,
884+ await_kw_span : Span ,
885+ expr : & ' hir hir:: Expr < ' hir > ,
886+ await_kind : FutureKind ,
887+ ) -> hir:: ExprKind < ' hir > {
877888 let full_span = expr. span . to ( await_kw_span) ;
878889
879890 let is_async_gen = match self . coroutine_kind {
@@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
887898 }
888899 } ;
889900
890- let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, None ) ;
901+ let features = match await_kind {
902+ FutureKind :: Future => None ,
903+ FutureKind :: AsyncIterator => Some ( self . allow_for_await . clone ( ) ) ,
904+ } ;
905+ let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, features) ;
891906 let gen_future_span = self . mark_span_with_reason (
892907 DesugaringKind :: Await ,
893908 full_span,
894909 Some ( self . allow_gen_future . clone ( ) ) ,
895910 ) ;
896- let expr = self . lower_expr_mut ( expr) ;
897911 let expr_hir_id = expr. hir_id ;
898912
899913 // Note that the name of this binding must not be changed to something else because
@@ -933,11 +947,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
933947 hir:: LangItem :: GetContext ,
934948 arena_vec ! [ self ; task_context] ,
935949 ) ;
936- let call = self . expr_call_lang_item_fn (
937- span,
938- hir:: LangItem :: FuturePoll ,
939- arena_vec ! [ self ; new_unchecked, get_context] ,
940- ) ;
950+ let call = match await_kind {
951+ FutureKind :: Future => self . expr_call_lang_item_fn (
952+ span,
953+ hir:: LangItem :: FuturePoll ,
954+ arena_vec ! [ self ; new_unchecked, get_context] ,
955+ ) ,
956+ FutureKind :: AsyncIterator => self . expr_call_lang_item_fn (
957+ span,
958+ hir:: LangItem :: AsyncIteratorPollNext ,
959+ arena_vec ! [ self ; new_unchecked, get_context] ,
960+ ) ,
961+ } ;
941962 self . arena . alloc ( self . expr_unsafe ( call) )
942963 } ;
943964
@@ -1021,11 +1042,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
10211042 let awaitee_arm = self . arm ( awaitee_pat, loop_expr) ;
10221043
10231044 // `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1024- let into_future_expr = self . expr_call_lang_item_fn (
1025- span,
1026- hir:: LangItem :: IntoFutureIntoFuture ,
1027- arena_vec ! [ self ; expr] ,
1028- ) ;
1045+ let into_future_expr = match await_kind {
1046+ FutureKind :: Future => self . expr_call_lang_item_fn (
1047+ span,
1048+ hir:: LangItem :: IntoFutureIntoFuture ,
1049+ arena_vec ! [ self ; * expr] ,
1050+ ) ,
1051+ // Not needed for `for await` because we expect to have already called
1052+ // `IntoAsyncIterator::into_async_iter` on it.
1053+ FutureKind :: AsyncIterator => expr,
1054+ } ;
10291055
10301056 // match <into_future_expr> {
10311057 // mut __awaitee => loop { .. }
@@ -1673,7 +1699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16731699 head : & Expr ,
16741700 body : & Block ,
16751701 opt_label : Option < Label > ,
1676- _loop_kind : ForLoopKind ,
1702+ loop_kind : ForLoopKind ,
16771703 ) -> hir:: Expr < ' hir > {
16781704 let head = self . lower_expr_mut ( head) ;
16791705 let pat = self . lower_pat ( pat) ;
@@ -1702,17 +1728,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
17021728 let ( iter_pat, iter_pat_nid) =
17031729 self . pat_ident_binding_mode ( head_span, iter, hir:: BindingAnnotation :: MUT ) ;
17041730
1705- // `match Iterator::next(&mut iter) { ... }`
17061731 let match_expr = {
17071732 let iter = self . expr_ident ( head_span, iter, iter_pat_nid) ;
1708- let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1709- let next_expr = self . expr_call_lang_item_fn (
1710- head_span,
1711- hir:: LangItem :: IteratorNext ,
1712- arena_vec ! [ self ; ref_mut_iter] ,
1713- ) ;
1733+ let next_expr = match loop_kind {
1734+ ForLoopKind :: For => {
1735+ // `Iterator::next(&mut iter)`
1736+ let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1737+ self . expr_call_lang_item_fn (
1738+ head_span,
1739+ hir:: LangItem :: IteratorNext ,
1740+ arena_vec ! [ self ; ref_mut_iter] ,
1741+ )
1742+ }
1743+ ForLoopKind :: ForAwait => {
1744+ // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1745+ // to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1746+ // calls to `poll_next`. In user code, this would probably be a call to
1747+ // `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1748+
1749+ // `&mut iter`
1750+ let iter = self . expr_mut_addr_of ( head_span, iter) ;
1751+ // `Pin::new_unchecked(...)`
1752+ let iter = self . arena . alloc ( self . expr_call_lang_item_fn_mut (
1753+ head_span,
1754+ hir:: LangItem :: PinNewUnchecked ,
1755+ arena_vec ! [ self ; iter] ,
1756+ ) ) ;
1757+ // `unsafe { ... }`
1758+ let iter = self . arena . alloc ( self . expr_unsafe ( iter) ) ;
1759+ let kind = self . make_lowered_await ( head_span, iter, FutureKind :: AsyncIterator ) ;
1760+ self . arena . alloc ( hir:: Expr { hir_id : self . next_id ( ) , kind, span : head_span } )
1761+ }
1762+ } ;
17141763 let arms = arena_vec ! [ self ; none_arm, some_arm] ;
17151764
1765+ // `match $next_expr { ... }`
17161766 self . expr_match ( head_span, next_expr, arms, hir:: MatchSource :: ForLoopDesugar )
17171767 } ;
17181768 let match_stmt = self . stmt_expr ( for_span, match_expr) ;
@@ -1732,13 +1782,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
17321782 // `mut iter => { ... }`
17331783 let iter_arm = self . arm ( iter_pat, loop_expr) ;
17341784
1735- // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1736- let into_iter_expr = {
1737- self . expr_call_lang_item_fn (
1738- head_span,
1739- hir:: LangItem :: IntoIterIntoIter ,
1740- arena_vec ! [ self ; head] ,
1741- )
1785+ let into_iter_expr = match loop_kind {
1786+ ForLoopKind :: For => {
1787+ // `::std::iter::IntoIterator::into_iter(<head>)`
1788+ self . expr_call_lang_item_fn (
1789+ head_span,
1790+ hir:: LangItem :: IntoIterIntoIter ,
1791+ arena_vec ! [ self ; head] ,
1792+ )
1793+ }
1794+ ForLoopKind :: ForAwait => self . arena . alloc ( head) ,
17421795 } ;
17431796
17441797 let match_expr = self . arena . alloc ( self . expr_match (
@@ -2141,3 +2194,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
21412194 }
21422195 }
21432196}
2197+
2198+ /// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
2199+ /// of future we are awaiting.
2200+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
2201+ enum FutureKind {
2202+ /// We are awaiting a normal future
2203+ Future ,
2204+ /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2205+ /// a `for await` loop)
2206+ AsyncIterator ,
2207+ }
0 commit comments