• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 extern crate proc_macro;
2 use proc_macro::TokenStream;
3 use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree};
4 use quote::quote;
5 use syn::parse::{Parse, ParseStream, Parser, Result};
6 use syn::visit_mut::VisitMut;
7 
8 struct Scrub<'a> {
9     /// Whether the stream is a try stream.
10     is_try: bool,
11     /// The unit expression, `()`.
12     unit: Box<syn::Expr>,
13     has_yielded: bool,
14     crate_path: &'a TokenStream2,
15 }
16 
parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)>17 fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)> {
18     let mut input = TokenStream2::from(input).into_iter();
19     let crate_path = match input.next().unwrap() {
20         TokenTree::Group(group) => group.stream(),
21         _ => panic!(),
22     };
23     let stmts = syn::Block::parse_within.parse2(replace_for_await(input))?;
24     Ok((crate_path, stmts))
25 }
26 
27 impl<'a> Scrub<'a> {
new(is_try: bool, crate_path: &'a TokenStream2) -> Self28     fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self {
29         Self {
30             is_try,
31             unit: syn::parse_quote!(()),
32             has_yielded: false,
33             crate_path,
34         }
35     }
36 }
37 
38 struct Partial<T>(T, TokenStream2);
39 
40 impl<T: Parse> Parse for Partial<T> {
parse(input: ParseStream) -> Result<Self>41     fn parse(input: ParseStream) -> Result<Self> {
42         Ok(Partial(input.parse()?, input.parse()?))
43     }
44 }
45 
visit_token_stream_impl( visitor: &mut Scrub<'_>, tokens: TokenStream2, modified: &mut bool, out: &mut TokenStream2, )46 fn visit_token_stream_impl(
47     visitor: &mut Scrub<'_>,
48     tokens: TokenStream2,
49     modified: &mut bool,
50     out: &mut TokenStream2,
51 ) {
52     use quote::ToTokens;
53     use quote::TokenStreamExt;
54 
55     let mut tokens = tokens.into_iter().peekable();
56     while let Some(tt) = tokens.next() {
57         match tt {
58             TokenTree::Ident(i) if i == "yield" => {
59                 let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect();
60                 match syn::parse2(stream) {
61                     Ok(Partial(yield_expr, rest)) => {
62                         let mut expr = syn::Expr::Yield(yield_expr);
63                         visitor.visit_expr_mut(&mut expr);
64                         expr.to_tokens(out);
65                         *modified = true;
66                         tokens = rest.into_iter().peekable();
67                     }
68                     Err(e) => {
69                         out.append_all(&mut e.to_compile_error().into_iter());
70                         *modified = true;
71                         return;
72                     }
73                 }
74             }
75             TokenTree::Ident(i) if i == "stream" || i == "try_stream" => {
76                 out.append(TokenTree::Ident(i));
77                 match tokens.peek() {
78                     Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
79                         out.extend(tokens.next()); // !
80                         if let Some(TokenTree::Group(_)) = tokens.peek() {
81                             out.extend(tokens.next()); // { .. } or [ .. ] or ( .. )
82                         }
83                     }
84                     _ => {}
85                 }
86             }
87             TokenTree::Group(group) => {
88                 let mut content = group.stream();
89                 *modified |= visitor.visit_token_stream(&mut content);
90                 let mut new = Group::new(group.delimiter(), content);
91                 new.set_span(group.span());
92                 out.append(new);
93             }
94             other => out.append(other),
95         }
96     }
97 }
98 
99 impl Scrub<'_> {
visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool100     fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool {
101         let (mut out, mut modified) = (TokenStream2::new(), false);
102         visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out);
103 
104         if modified {
105             *tokens = out;
106         }
107 
108         modified
109     }
110 }
111 
112 impl VisitMut for Scrub<'_> {
visit_expr_mut(&mut self, i: &mut syn::Expr)113     fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
114         match i {
115             syn::Expr::Yield(yield_expr) => {
116                 self.has_yielded = true;
117 
118                 let value_expr = yield_expr.expr.as_ref().unwrap_or(&self.unit);
119 
120                 // let ident = &self.yielder;
121 
122                 *i = if self.is_try {
123                     syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await }
124                 } else {
125                     syn::parse_quote! { __yield_tx.send(#value_expr).await }
126                 };
127             }
128             syn::Expr::Try(try_expr) => {
129                 syn::visit_mut::visit_expr_try_mut(self, try_expr);
130                 // let ident = &self.yielder;
131                 let e = &try_expr.expr;
132 
133                 *i = syn::parse_quote! {
134                     match #e {
135                         ::core::result::Result::Ok(v) => v,
136                         ::core::result::Result::Err(e) => {
137                             __yield_tx.send(::core::result::Result::Err(e.into())).await;
138                             return;
139                         }
140                     }
141                 };
142             }
143             syn::Expr::Closure(_) | syn::Expr::Async(_) => {
144                 // Don't transform inner closures or async blocks.
145             }
146             syn::Expr::ForLoop(expr) => {
147                 syn::visit_mut::visit_expr_for_loop_mut(self, expr);
148                 // TODO: Should we allow other attributes?
149                 if expr.attrs.len() != 1 || !expr.attrs[0].path.is_ident("await") {
150                     return;
151                 }
152                 let syn::ExprForLoop {
153                     attrs,
154                     label,
155                     pat,
156                     expr,
157                     body,
158                     ..
159                 } = expr;
160 
161                 let attr = attrs.pop().unwrap();
162                 if let Err(e) = syn::parse2::<syn::parse::Nothing>(attr.tokens) {
163                     *i = syn::parse2(e.to_compile_error()).unwrap();
164                     return;
165                 }
166 
167                 let crate_path = self.crate_path;
168                 *i = syn::parse_quote! {{
169                     let mut __pinned = #expr;
170                     let mut __pinned = unsafe {
171                         ::core::pin::Pin::new_unchecked(&mut __pinned)
172                     };
173                     #label
174                     loop {
175                         let #pat = match #crate_path::reexport::next(&mut __pinned).await {
176                             ::core::option::Option::Some(e) => e,
177                             ::core::option::Option::None => break,
178                         };
179                         #body
180                     }
181                 }}
182             }
183             _ => syn::visit_mut::visit_expr_mut(self, i),
184         }
185     }
186 
visit_macro_mut(&mut self, mac: &mut syn::Macro)187     fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
188         let mac_ident = mac.path.segments.last().map(|p| &p.ident);
189         if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream") {
190             return;
191         }
192 
193         self.visit_token_stream(&mut mac.tokens);
194     }
195 
visit_item_mut(&mut self, i: &mut syn::Item)196     fn visit_item_mut(&mut self, i: &mut syn::Item) {
197         // Recurse into macros but otherwise don't transform inner items.
198         if let syn::Item::Macro(i) = i {
199             self.visit_macro_mut(&mut i.mac);
200         }
201     }
202 }
203 
204 /// The first token tree in the stream must be a group containing the path to the `async-stream`
205 /// crate.
206 #[proc_macro]
207 #[doc(hidden)]
stream_inner(input: TokenStream) -> TokenStream208 pub fn stream_inner(input: TokenStream) -> TokenStream {
209     let (crate_path, mut stmts) = match parse_input(input) {
210         Ok(x) => x,
211         Err(e) => return e.to_compile_error().into(),
212     };
213 
214     let mut scrub = Scrub::new(false, &crate_path);
215 
216     for mut stmt in &mut stmts {
217         scrub.visit_stmt_mut(&mut stmt);
218     }
219 
220     let dummy_yield = if scrub.has_yielded {
221         None
222     } else {
223         Some(quote!(if false {
224             __yield_tx.send(()).await;
225         }))
226     };
227 
228     quote!({
229         let (mut __yield_tx, __yield_rx) = #crate_path::yielder::pair();
230         #crate_path::AsyncStream::new(__yield_rx, async move {
231             #dummy_yield
232             #(#stmts)*
233         })
234     })
235     .into()
236 }
237 
238 /// The first token tree in the stream must be a group containing the path to the `async-stream`
239 /// crate.
240 #[proc_macro]
241 #[doc(hidden)]
try_stream_inner(input: TokenStream) -> TokenStream242 pub fn try_stream_inner(input: TokenStream) -> TokenStream {
243     let (crate_path, mut stmts) = match parse_input(input) {
244         Ok(x) => x,
245         Err(e) => return e.to_compile_error().into(),
246     };
247 
248     let mut scrub = Scrub::new(true, &crate_path);
249 
250     for mut stmt in &mut stmts {
251         scrub.visit_stmt_mut(&mut stmt);
252     }
253 
254     let dummy_yield = if scrub.has_yielded {
255         None
256     } else {
257         Some(quote!(if false {
258             __yield_tx.send(()).await;
259         }))
260     };
261 
262     quote!({
263         let (mut __yield_tx, __yield_rx) = #crate_path::yielder::pair();
264         #crate_path::AsyncStream::new(__yield_rx, async move {
265             #dummy_yield
266             #(#stmts)*
267         })
268     })
269     .into()
270 }
271 
272 /// Replace `for await` with `#[await] for`, which will be later transformed into a `next` loop.
replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2273 fn replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2 {
274     let mut input = input.into_iter().peekable();
275     let mut tokens = Vec::new();
276 
277     while let Some(token) = input.next() {
278         match token {
279             TokenTree::Ident(ident) => {
280                 match input.peek() {
281                     Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => {
282                         tokens.extend(quote!(#[#next]));
283                         let _ = input.next();
284                     }
285                     _ => {}
286                 }
287                 tokens.push(ident.into());
288             }
289             TokenTree::Group(group) => {
290                 let stream = replace_for_await(group.stream());
291                 tokens.push(Group::new(group.delimiter(), stream).into());
292             }
293             _ => tokens.push(token),
294         }
295     }
296 
297     tokens.into_iter().collect()
298 }
299