1 extern crate proc_macro;
2 use proc_macro::TokenStream;
3 use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree};
4 use quote::quote;
5 use syn::parse::{Parse, ParseStream, Parser, Result};
6 use syn::visit_mut::VisitMut;
7
8 struct Scrub<'a> {
9 /// Whether the stream is a try stream.
10 is_try: bool,
11 /// The unit expression, `()`.
12 unit: Box<syn::Expr>,
13 has_yielded: bool,
14 crate_path: &'a TokenStream2,
15 }
16
parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)>17 fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)> {
18 let mut input = TokenStream2::from(input).into_iter();
19 let crate_path = match input.next().unwrap() {
20 TokenTree::Group(group) => group.stream(),
21 _ => panic!(),
22 };
23 let stmts = syn::Block::parse_within.parse2(replace_for_await(input))?;
24 Ok((crate_path, stmts))
25 }
26
27 impl<'a> Scrub<'a> {
new(is_try: bool, crate_path: &'a TokenStream2) -> Self28 fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self {
29 Self {
30 is_try,
31 unit: syn::parse_quote!(()),
32 has_yielded: false,
33 crate_path,
34 }
35 }
36 }
37
38 struct Partial<T>(T, TokenStream2);
39
40 impl<T: Parse> Parse for Partial<T> {
parse(input: ParseStream) -> Result<Self>41 fn parse(input: ParseStream) -> Result<Self> {
42 Ok(Partial(input.parse()?, input.parse()?))
43 }
44 }
45
visit_token_stream_impl( visitor: &mut Scrub<'_>, tokens: TokenStream2, modified: &mut bool, out: &mut TokenStream2, )46 fn visit_token_stream_impl(
47 visitor: &mut Scrub<'_>,
48 tokens: TokenStream2,
49 modified: &mut bool,
50 out: &mut TokenStream2,
51 ) {
52 use quote::ToTokens;
53 use quote::TokenStreamExt;
54
55 let mut tokens = tokens.into_iter().peekable();
56 while let Some(tt) = tokens.next() {
57 match tt {
58 TokenTree::Ident(i) if i == "yield" => {
59 let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect();
60 match syn::parse2(stream) {
61 Ok(Partial(yield_expr, rest)) => {
62 let mut expr = syn::Expr::Yield(yield_expr);
63 visitor.visit_expr_mut(&mut expr);
64 expr.to_tokens(out);
65 *modified = true;
66 tokens = rest.into_iter().peekable();
67 }
68 Err(e) => {
69 out.append_all(&mut e.to_compile_error().into_iter());
70 *modified = true;
71 return;
72 }
73 }
74 }
75 TokenTree::Ident(i) if i == "stream" || i == "try_stream" => {
76 out.append(TokenTree::Ident(i));
77 match tokens.peek() {
78 Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
79 out.extend(tokens.next()); // !
80 if let Some(TokenTree::Group(_)) = tokens.peek() {
81 out.extend(tokens.next()); // { .. } or [ .. ] or ( .. )
82 }
83 }
84 _ => {}
85 }
86 }
87 TokenTree::Group(group) => {
88 let mut content = group.stream();
89 *modified |= visitor.visit_token_stream(&mut content);
90 let mut new = Group::new(group.delimiter(), content);
91 new.set_span(group.span());
92 out.append(new);
93 }
94 other => out.append(other),
95 }
96 }
97 }
98
99 impl Scrub<'_> {
visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool100 fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool {
101 let (mut out, mut modified) = (TokenStream2::new(), false);
102 visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out);
103
104 if modified {
105 *tokens = out;
106 }
107
108 modified
109 }
110 }
111
112 impl VisitMut for Scrub<'_> {
visit_expr_mut(&mut self, i: &mut syn::Expr)113 fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
114 match i {
115 syn::Expr::Yield(yield_expr) => {
116 self.has_yielded = true;
117
118 let value_expr = yield_expr.expr.as_ref().unwrap_or(&self.unit);
119
120 // let ident = &self.yielder;
121
122 *i = if self.is_try {
123 syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await }
124 } else {
125 syn::parse_quote! { __yield_tx.send(#value_expr).await }
126 };
127 }
128 syn::Expr::Try(try_expr) => {
129 syn::visit_mut::visit_expr_try_mut(self, try_expr);
130 // let ident = &self.yielder;
131 let e = &try_expr.expr;
132
133 *i = syn::parse_quote! {
134 match #e {
135 ::core::result::Result::Ok(v) => v,
136 ::core::result::Result::Err(e) => {
137 __yield_tx.send(::core::result::Result::Err(e.into())).await;
138 return;
139 }
140 }
141 };
142 }
143 syn::Expr::Closure(_) | syn::Expr::Async(_) => {
144 // Don't transform inner closures or async blocks.
145 }
146 syn::Expr::ForLoop(expr) => {
147 syn::visit_mut::visit_expr_for_loop_mut(self, expr);
148 // TODO: Should we allow other attributes?
149 if expr.attrs.len() != 1 || !expr.attrs[0].path.is_ident("await") {
150 return;
151 }
152 let syn::ExprForLoop {
153 attrs,
154 label,
155 pat,
156 expr,
157 body,
158 ..
159 } = expr;
160
161 let attr = attrs.pop().unwrap();
162 if let Err(e) = syn::parse2::<syn::parse::Nothing>(attr.tokens) {
163 *i = syn::parse2(e.to_compile_error()).unwrap();
164 return;
165 }
166
167 let crate_path = self.crate_path;
168 *i = syn::parse_quote! {{
169 let mut __pinned = #expr;
170 let mut __pinned = unsafe {
171 ::core::pin::Pin::new_unchecked(&mut __pinned)
172 };
173 #label
174 loop {
175 let #pat = match #crate_path::reexport::next(&mut __pinned).await {
176 ::core::option::Option::Some(e) => e,
177 ::core::option::Option::None => break,
178 };
179 #body
180 }
181 }}
182 }
183 _ => syn::visit_mut::visit_expr_mut(self, i),
184 }
185 }
186
visit_macro_mut(&mut self, mac: &mut syn::Macro)187 fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
188 let mac_ident = mac.path.segments.last().map(|p| &p.ident);
189 if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream") {
190 return;
191 }
192
193 self.visit_token_stream(&mut mac.tokens);
194 }
195
visit_item_mut(&mut self, i: &mut syn::Item)196 fn visit_item_mut(&mut self, i: &mut syn::Item) {
197 // Recurse into macros but otherwise don't transform inner items.
198 if let syn::Item::Macro(i) = i {
199 self.visit_macro_mut(&mut i.mac);
200 }
201 }
202 }
203
204 /// The first token tree in the stream must be a group containing the path to the `async-stream`
205 /// crate.
206 #[proc_macro]
207 #[doc(hidden)]
stream_inner(input: TokenStream) -> TokenStream208 pub fn stream_inner(input: TokenStream) -> TokenStream {
209 let (crate_path, mut stmts) = match parse_input(input) {
210 Ok(x) => x,
211 Err(e) => return e.to_compile_error().into(),
212 };
213
214 let mut scrub = Scrub::new(false, &crate_path);
215
216 for mut stmt in &mut stmts {
217 scrub.visit_stmt_mut(&mut stmt);
218 }
219
220 let dummy_yield = if scrub.has_yielded {
221 None
222 } else {
223 Some(quote!(if false {
224 __yield_tx.send(()).await;
225 }))
226 };
227
228 quote!({
229 let (mut __yield_tx, __yield_rx) = #crate_path::yielder::pair();
230 #crate_path::AsyncStream::new(__yield_rx, async move {
231 #dummy_yield
232 #(#stmts)*
233 })
234 })
235 .into()
236 }
237
238 /// The first token tree in the stream must be a group containing the path to the `async-stream`
239 /// crate.
240 #[proc_macro]
241 #[doc(hidden)]
try_stream_inner(input: TokenStream) -> TokenStream242 pub fn try_stream_inner(input: TokenStream) -> TokenStream {
243 let (crate_path, mut stmts) = match parse_input(input) {
244 Ok(x) => x,
245 Err(e) => return e.to_compile_error().into(),
246 };
247
248 let mut scrub = Scrub::new(true, &crate_path);
249
250 for mut stmt in &mut stmts {
251 scrub.visit_stmt_mut(&mut stmt);
252 }
253
254 let dummy_yield = if scrub.has_yielded {
255 None
256 } else {
257 Some(quote!(if false {
258 __yield_tx.send(()).await;
259 }))
260 };
261
262 quote!({
263 let (mut __yield_tx, __yield_rx) = #crate_path::yielder::pair();
264 #crate_path::AsyncStream::new(__yield_rx, async move {
265 #dummy_yield
266 #(#stmts)*
267 })
268 })
269 .into()
270 }
271
272 /// Replace `for await` with `#[await] for`, which will be later transformed into a `next` loop.
replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2273 fn replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2 {
274 let mut input = input.into_iter().peekable();
275 let mut tokens = Vec::new();
276
277 while let Some(token) = input.next() {
278 match token {
279 TokenTree::Ident(ident) => {
280 match input.peek() {
281 Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => {
282 tokens.extend(quote!(#[#next]));
283 let _ = input.next();
284 }
285 _ => {}
286 }
287 tokens.push(ident.into());
288 }
289 TokenTree::Group(group) => {
290 let stream = replace_for_await(group.stream());
291 tokens.push(Group::new(group.delimiter(), stream).into());
292 }
293 _ => tokens.push(token),
294 }
295 }
296
297 tokens.into_iter().collect()
298 }
299