11extern crate proc_macro;
22
33use proc_macro:: { TokenStream , TokenTree } ;
4- use proc_macro2:: Span ;
4+ use proc_macro2:: { Group , Span , TokenStream as TokenStream2 , TokenTree as TokenTree2 } ;
55use quote:: quote;
66use syn:: visit_mut:: VisitMut ;
77
@@ -19,7 +19,7 @@ struct AsyncStreamEnumHack {
1919}
2020
2121impl AsyncStreamEnumHack {
22- fn parse ( input : TokenStream ) -> Self {
22+ fn parse ( input : TokenStream ) -> syn :: Result < Self > {
2323 macro_rules! n {
2424 ( $i: ident) => {
2525 $i. next( ) . unwrap( )
@@ -44,14 +44,15 @@ impl AsyncStreamEnumHack {
4444 n ! ( braces) ; // !
4545
4646 let inner = n ! ( braces) ;
47- let syn:: Block { stmts, .. } = syn:: parse ( inner. clone ( ) . into ( ) ) . unwrap ( ) ;
47+ let inner = replace_for_await ( TokenStream2 :: from ( TokenStream :: from ( inner) ) ) ;
48+ let syn:: Block { stmts, .. } = syn:: parse2 ( inner. clone ( ) ) ?;
4849
4950 let macro_ident = syn:: Ident :: new (
5051 & format ! ( "stream_{}" , count_bangs( inner. into( ) ) ) ,
5152 Span :: call_site ( ) ,
5253 ) ;
5354
54- AsyncStreamEnumHack { stmts, macro_ident }
55+ Ok ( AsyncStreamEnumHack { stmts, macro_ident } )
5556 }
5657}
5758
@@ -100,6 +101,42 @@ impl VisitMut for Scrub {
100101 syn:: visit_mut:: visit_expr_mut ( self , i) ;
101102 self . is_xforming = prev;
102103 }
104+ syn:: Expr :: ForLoop ( expr) => {
105+ syn:: visit_mut:: visit_expr_for_loop_mut ( self , expr) ;
106+ // TODO: Should we allow other attributes?
107+ if expr. attrs . len ( ) != 1 || !expr. attrs [ 0 ] . path . is_ident ( "await" ) {
108+ return ;
109+ }
110+ let syn:: ExprForLoop {
111+ attrs,
112+ label,
113+ pat,
114+ expr,
115+ body,
116+ ..
117+ } = expr;
118+
119+ let attr = attrs. pop ( ) . unwrap ( ) ;
120+ if let Err ( e) = syn:: parse2 :: < syn:: parse:: Nothing > ( attr. tokens ) {
121+ * i = syn:: parse2 ( e. to_compile_error ( ) ) . unwrap ( ) ;
122+ return ;
123+ }
124+
125+ * i = syn:: parse_quote! { {
126+ let mut __pinned = #expr;
127+ let mut __pinned = unsafe {
128+ :: async_stream:: reexport:: Pin :: new_unchecked( & mut __pinned)
129+ } ;
130+ #label
131+ loop {
132+ let #pat = match :: async_stream:: reexport:: next( & mut __pinned) . await {
133+ :: async_stream:: reexport:: Some ( e) => e,
134+ :: async_stream:: reexport:: None => break ,
135+ } ;
136+ #body
137+ }
138+ } }
139+ }
103140 _ => syn:: visit_mut:: visit_expr_mut ( self , i) ,
104141 }
105142 }
@@ -117,7 +154,10 @@ pub fn async_stream_impl(input: TokenStream) -> TokenStream {
117154 let AsyncStreamEnumHack {
118155 macro_ident,
119156 mut stmts,
120- } = AsyncStreamEnumHack :: parse ( input) ;
157+ } = match AsyncStreamEnumHack :: parse ( input) {
158+ Ok ( x) => x,
159+ Err ( e) => return e. to_compile_error ( ) . into ( ) ,
160+ } ;
121161
122162 let mut scrub = Scrub {
123163 is_xforming : true ,
@@ -156,7 +196,10 @@ pub fn async_try_stream_impl(input: TokenStream) -> TokenStream {
156196 let AsyncStreamEnumHack {
157197 macro_ident,
158198 mut stmts,
159- } = AsyncStreamEnumHack :: parse ( input) ;
199+ } = match AsyncStreamEnumHack :: parse ( input) {
200+ Ok ( x) => x,
201+ Err ( e) => return e. to_compile_error ( ) . into ( ) ,
202+ } ;
160203
161204 let mut scrub = Scrub {
162205 is_xforming : true ,
@@ -209,3 +252,30 @@ fn count_bangs(input: TokenStream) -> usize {
209252
210253 count
211254}
255+
256+ fn replace_for_await ( input : TokenStream2 ) -> TokenStream2 {
257+ let mut input = input. into_iter ( ) . peekable ( ) ;
258+ let mut tokens = Vec :: new ( ) ;
259+
260+ while let Some ( token) = input. next ( ) {
261+ match token {
262+ TokenTree2 :: Ident ( ident) => {
263+ match input. peek ( ) {
264+ Some ( TokenTree2 :: Ident ( next) ) if ident == "for" && next == "await" => {
265+ tokens. extend ( quote ! ( #[ #next] ) ) ;
266+ let _ = input. next ( ) ;
267+ }
268+ _ => { }
269+ }
270+ tokens. push ( ident. into ( ) ) ;
271+ }
272+ TokenTree2 :: Group ( group) => {
273+ let stream = replace_for_await ( group. stream ( ) ) ;
274+ tokens. push ( Group :: new ( group. delimiter ( ) , stream) . into ( ) ) ;
275+ }
276+ _ => tokens. push ( token) ,
277+ }
278+ }
279+
280+ tokens. into_iter ( ) . collect ( )
281+ }
0 commit comments