@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337337 ) ,
338338 ExprKind :: Try ( sub_expr) => self . lower_expr_try ( e. span , sub_expr) ,
339339
340- ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
340+ ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
341341 unreachable ! ( "already handled" )
342342 }
343343
@@ -871,6 +871,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
871871 /// }
872872 /// ```
873873 fn lower_expr_await ( & mut self , await_kw_span : Span , expr : & Expr ) -> hir:: ExprKind < ' hir > {
874+ let expr = self . arena . alloc ( self . lower_expr_mut ( expr) ) ;
875+ self . make_lowered_await ( await_kw_span, expr, FutureKind :: Future )
876+ }
877+
878+ /// Takes an expr that has already been lowered and generates a desugared await loop around it
879+ fn make_lowered_await (
880+ & mut self ,
881+ await_kw_span : Span ,
882+ expr : & ' hir hir:: Expr < ' hir > ,
883+ await_kind : FutureKind ,
884+ ) -> hir:: ExprKind < ' hir > {
874885 let full_span = expr. span . to ( await_kw_span) ;
875886
876887 let is_async_gen = match self . coroutine_kind {
@@ -884,13 +895,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
884895 }
885896 } ;
886897
887- let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, None ) ;
898+ let features = match await_kind {
899+ FutureKind :: Future => None ,
900+ FutureKind :: AsyncIterator => Some ( self . allow_for_await . clone ( ) ) ,
901+ } ;
902+ let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, features) ;
888903 let gen_future_span = self . mark_span_with_reason (
889904 DesugaringKind :: Await ,
890905 full_span,
891906 Some ( self . allow_gen_future . clone ( ) ) ,
892907 ) ;
893- let expr = self . lower_expr_mut ( expr) ;
894908 let expr_hir_id = expr. hir_id ;
895909
896910 // Note that the name of this binding must not be changed to something else because
@@ -930,11 +944,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
930944 hir:: LangItem :: GetContext ,
931945 arena_vec ! [ self ; task_context] ,
932946 ) ;
933- let call = self . expr_call_lang_item_fn (
934- span,
935- hir:: LangItem :: FuturePoll ,
936- arena_vec ! [ self ; new_unchecked, get_context] ,
937- ) ;
947+ let call = match await_kind {
948+ FutureKind :: Future => self . expr_call_lang_item_fn (
949+ span,
950+ hir:: LangItem :: FuturePoll ,
951+ arena_vec ! [ self ; new_unchecked, get_context] ,
952+ ) ,
953+ FutureKind :: AsyncIterator => self . expr_call_lang_item_fn (
954+ span,
955+ hir:: LangItem :: AsyncIteratorPollNext ,
956+ arena_vec ! [ self ; new_unchecked, get_context] ,
957+ ) ,
958+ } ;
938959 self . arena . alloc ( self . expr_unsafe ( call) )
939960 } ;
940961
@@ -1018,11 +1039,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
10181039 let awaitee_arm = self . arm ( awaitee_pat, loop_expr) ;
10191040
10201041 // `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1021- let into_future_expr = self . expr_call_lang_item_fn (
1022- span,
1023- hir:: LangItem :: IntoFutureIntoFuture ,
1024- arena_vec ! [ self ; expr] ,
1025- ) ;
1042+ let into_future_expr = match await_kind {
1043+ FutureKind :: Future => self . expr_call_lang_item_fn (
1044+ span,
1045+ hir:: LangItem :: IntoFutureIntoFuture ,
1046+ arena_vec ! [ self ; * expr] ,
1047+ ) ,
1048+ // Not needed for `for await` because we expect to have already called
1049+ // `IntoAsyncIterator::into_async_iter` on it.
1050+ FutureKind :: AsyncIterator => expr,
1051+ } ;
10261052
10271053 // match <into_future_expr> {
10281054 // mut __awaitee => loop { .. }
@@ -1670,7 +1696,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16701696 head : & Expr ,
16711697 body : & Block ,
16721698 opt_label : Option < Label > ,
1673- _loop_kind : ForLoopKind ,
1699+ loop_kind : ForLoopKind ,
16741700 ) -> hir:: Expr < ' hir > {
16751701 let head = self . lower_expr_mut ( head) ;
16761702 let pat = self . lower_pat ( pat) ;
@@ -1699,17 +1725,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
16991725 let ( iter_pat, iter_pat_nid) =
17001726 self . pat_ident_binding_mode ( head_span, iter, hir:: BindingAnnotation :: MUT ) ;
17011727
1702- // `match Iterator::next(&mut iter) { ... }`
17031728 let match_expr = {
17041729 let iter = self . expr_ident ( head_span, iter, iter_pat_nid) ;
1705- let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1706- let next_expr = self . expr_call_lang_item_fn (
1707- head_span,
1708- hir:: LangItem :: IteratorNext ,
1709- arena_vec ! [ self ; ref_mut_iter] ,
1710- ) ;
1730+ let next_expr = match loop_kind {
1731+ ForLoopKind :: For => {
1732+ // `Iterator::next(&mut iter)`
1733+ let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1734+ self . expr_call_lang_item_fn (
1735+ head_span,
1736+ hir:: LangItem :: IteratorNext ,
1737+ arena_vec ! [ self ; ref_mut_iter] ,
1738+ )
1739+ }
1740+ ForLoopKind :: ForAwait => {
1741+ // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1742+ // to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1743+ // calls to `poll_next`. In user code, this would probably be a call to
1744+ // `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1745+
1746+ // `&mut iter`
1747+ let iter = self . expr_mut_addr_of ( head_span, iter) ;
1748+ // `Pin::new_unchecked(...)`
1749+ let iter = self . arena . alloc ( self . expr_call_lang_item_fn_mut (
1750+ head_span,
1751+ hir:: LangItem :: PinNewUnchecked ,
1752+ arena_vec ! [ self ; iter] ,
1753+ ) ) ;
1754+ // `unsafe { ... }`
1755+ let iter = self . arena . alloc ( self . expr_unsafe ( iter) ) ;
1756+ let kind = self . make_lowered_await ( head_span, iter, FutureKind :: AsyncIterator ) ;
1757+ self . arena . alloc ( hir:: Expr { hir_id : self . next_id ( ) , kind, span : head_span } )
1758+ }
1759+ } ;
17111760 let arms = arena_vec ! [ self ; none_arm, some_arm] ;
17121761
1762+ // `match $next_expr { ... }`
17131763 self . expr_match ( head_span, next_expr, arms, hir:: MatchSource :: ForLoopDesugar )
17141764 } ;
17151765 let match_stmt = self . stmt_expr ( for_span, match_expr) ;
@@ -1729,13 +1779,24 @@ impl<'hir> LoweringContext<'_, 'hir> {
17291779 // `mut iter => { ... }`
17301780 let iter_arm = self . arm ( iter_pat, loop_expr) ;
17311781
1732- // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1733- let into_iter_expr = {
1734- self . expr_call_lang_item_fn (
1735- head_span,
1736- hir:: LangItem :: IntoIterIntoIter ,
1737- arena_vec ! [ self ; head] ,
1738- )
1782+ let into_iter_expr = match loop_kind {
1783+ ForLoopKind :: For => {
1784+ // `::std::iter::IntoIterator::into_iter(<head>)`
1785+ self . expr_call_lang_item_fn (
1786+ head_span,
1787+ hir:: LangItem :: IntoIterIntoIter ,
1788+ arena_vec ! [ self ; head] ,
1789+ )
1790+ }
1791+ ForLoopKind :: ForAwait => {
1792+ // `::core::async_iter::IntoAsyncIterator::into_async_iter(<head>)`
1793+ let iter = self . expr_call_lang_item_fn (
1794+ head_span,
1795+ hir:: LangItem :: IntoAsyncIterIntoIter ,
1796+ arena_vec ! [ self ; head] ,
1797+ ) ;
1798+ self . arena . alloc ( self . expr_mut_addr_of ( head_span, iter) )
1799+ }
17391800 } ;
17401801
17411802 let match_expr = self . arena . alloc ( self . expr_match (
@@ -2138,3 +2199,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
21382199 }
21392200 }
21402201}
2202+
2203+ /// Used by [`make_lowered_await`] to customize the desugaring based on what kind of future we are
2204+ /// awaiting.
2205+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
2206+ enum FutureKind {
2207+ /// We are awaiting a normal future
2208+ Future ,
2209+ /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2210+ /// a `for await` loop)
2211+ AsyncIterator ,
2212+ }
0 commit comments