• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::{Span, TokenStream, TokenTree};
2 use quote::{quote, quote_spanned, ToTokens};
3 use syn::parse::{Parse, ParseStream, Parser};
4 use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
5 
6 // syn::AttributeArgs does not implement syn::Parse
7 type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
8 
/// Runtime flavor selectable through the `flavor = "..."` macro argument.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RuntimeFlavor {
    /// Selected by `flavor = "current_thread"`.
    CurrentThread,
    /// Selected by `flavor = "multi_thread"`.
    Threaded,
}

impl RuntimeFlavor {
    /// Parses the string given to the `flavor` option.
    ///
    /// Returns the matching flavor for `"current_thread"` and
    /// `"multi_thread"`. Any other input yields an `Err` holding a
    /// human-readable message; outdated or commonly mistaken names get a
    /// message that points at the correct spelling.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        match s {
            "current_thread" => Ok(RuntimeFlavor::CurrentThread),
            "multi_thread" => Ok(RuntimeFlavor::Threaded),
            // Both spellings get the same hint so that this stays consistent
            // with the bare-path handling of `single_threaded` in `build_config`.
            "single_thread" | "single_threaded" => Err("The single threaded runtime flavor is called `current_thread`.".to_string()),
            "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string()),
            "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string()),
            _ => Err(format!("No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.", s)),
        }
    }
}
27 
28 struct FinalConfig {
29     flavor: RuntimeFlavor,
30     worker_threads: Option<usize>,
31     start_paused: Option<bool>,
32     crate_name: Option<Path>,
33 }
34 
35 /// Config used in case of the attribute not being able to build a valid config
36 const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
37     flavor: RuntimeFlavor::CurrentThread,
38     worker_threads: None,
39     start_paused: None,
40     crate_name: None,
41 };
42 
43 struct Configuration {
44     rt_multi_thread_available: bool,
45     default_flavor: RuntimeFlavor,
46     flavor: Option<RuntimeFlavor>,
47     worker_threads: Option<(usize, Span)>,
48     start_paused: Option<(bool, Span)>,
49     is_test: bool,
50     crate_name: Option<Path>,
51 }
52 
53 impl Configuration {
new(is_test: bool, rt_multi_thread: bool) -> Self54     fn new(is_test: bool, rt_multi_thread: bool) -> Self {
55         Configuration {
56             rt_multi_thread_available: rt_multi_thread,
57             default_flavor: match is_test {
58                 true => RuntimeFlavor::CurrentThread,
59                 false => RuntimeFlavor::Threaded,
60             },
61             flavor: None,
62             worker_threads: None,
63             start_paused: None,
64             is_test,
65             crate_name: None,
66         }
67     }
68 
set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error>69     fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
70         if self.flavor.is_some() {
71             return Err(syn::Error::new(span, "`flavor` set multiple times."));
72         }
73 
74         let runtime_str = parse_string(runtime, span, "flavor")?;
75         let runtime =
76             RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
77         self.flavor = Some(runtime);
78         Ok(())
79     }
80 
set_worker_threads( &mut self, worker_threads: syn::Lit, span: Span, ) -> Result<(), syn::Error>81     fn set_worker_threads(
82         &mut self,
83         worker_threads: syn::Lit,
84         span: Span,
85     ) -> Result<(), syn::Error> {
86         if self.worker_threads.is_some() {
87             return Err(syn::Error::new(
88                 span,
89                 "`worker_threads` set multiple times.",
90             ));
91         }
92 
93         let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
94         if worker_threads == 0 {
95             return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
96         }
97         self.worker_threads = Some((worker_threads, span));
98         Ok(())
99     }
100 
set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error>101     fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
102         if self.start_paused.is_some() {
103             return Err(syn::Error::new(span, "`start_paused` set multiple times."));
104         }
105 
106         let start_paused = parse_bool(start_paused, span, "start_paused")?;
107         self.start_paused = Some((start_paused, span));
108         Ok(())
109     }
110 
set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error>111     fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
112         if self.crate_name.is_some() {
113             return Err(syn::Error::new(span, "`crate` set multiple times."));
114         }
115         let name_path = parse_path(name, span, "crate")?;
116         self.crate_name = Some(name_path);
117         Ok(())
118     }
119 
macro_name(&self) -> &'static str120     fn macro_name(&self) -> &'static str {
121         if self.is_test {
122             "tokio::test"
123         } else {
124             "tokio::main"
125         }
126     }
127 
build(&self) -> Result<FinalConfig, syn::Error>128     fn build(&self) -> Result<FinalConfig, syn::Error> {
129         let flavor = self.flavor.unwrap_or(self.default_flavor);
130         use RuntimeFlavor::*;
131 
132         let worker_threads = match (flavor, self.worker_threads) {
133             (CurrentThread, Some((_, worker_threads_span))) => {
134                 let msg = format!(
135                     "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
136                     self.macro_name(),
137                 );
138                 return Err(syn::Error::new(worker_threads_span, msg));
139             }
140             (CurrentThread, None) => None,
141             (Threaded, worker_threads) if self.rt_multi_thread_available => {
142                 worker_threads.map(|(val, _span)| val)
143             }
144             (Threaded, _) => {
145                 let msg = if self.flavor.is_none() {
146                     "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
147                 } else {
148                     "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
149                 };
150                 return Err(syn::Error::new(Span::call_site(), msg));
151             }
152         };
153 
154         let start_paused = match (flavor, self.start_paused) {
155             (Threaded, Some((_, start_paused_span))) => {
156                 let msg = format!(
157                     "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
158                     self.macro_name(),
159                 );
160                 return Err(syn::Error::new(start_paused_span, msg));
161             }
162             (CurrentThread, Some((start_paused, _))) => Some(start_paused),
163             (_, None) => None,
164         };
165 
166         Ok(FinalConfig {
167             crate_name: self.crate_name.clone(),
168             flavor,
169             worker_threads,
170             start_paused,
171         })
172     }
173 }
174 
parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error>175 fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
176     match int {
177         syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
178             Ok(value) => Ok(value),
179             Err(e) => Err(syn::Error::new(
180                 span,
181                 format!("Failed to parse value of `{}` as integer: {}", field, e),
182             )),
183         },
184         _ => Err(syn::Error::new(
185             span,
186             format!("Failed to parse value of `{}` as integer.", field),
187         )),
188     }
189 }
190 
parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error>191 fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
192     match int {
193         syn::Lit::Str(s) => Ok(s.value()),
194         syn::Lit::Verbatim(s) => Ok(s.to_string()),
195         _ => Err(syn::Error::new(
196             span,
197             format!("Failed to parse value of `{}` as string.", field),
198         )),
199     }
200 }
201 
parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error>202 fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
203     match lit {
204         syn::Lit::Str(s) => {
205             let err = syn::Error::new(
206                 span,
207                 format!(
208                     "Failed to parse value of `{}` as path: \"{}\"",
209                     field,
210                     s.value()
211                 ),
212             );
213             s.parse::<syn::Path>().map_err(|_| err.clone())
214         }
215         _ => Err(syn::Error::new(
216             span,
217             format!("Failed to parse value of `{}` as path.", field),
218         )),
219     }
220 }
221 
parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error>222 fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
223     match bool {
224         syn::Lit::Bool(b) => Ok(b.value),
225         _ => Err(syn::Error::new(
226             span,
227             format!("Failed to parse value of `{}` as bool.", field),
228         )),
229     }
230 }
231 
build_config( input: &ItemFn, args: AttributeArgs, is_test: bool, rt_multi_thread: bool, ) -> Result<FinalConfig, syn::Error>232 fn build_config(
233     input: &ItemFn,
234     args: AttributeArgs,
235     is_test: bool,
236     rt_multi_thread: bool,
237 ) -> Result<FinalConfig, syn::Error> {
238     if input.sig.asyncness.is_none() {
239         let msg = "the `async` keyword is missing from the function declaration";
240         return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
241     }
242 
243     let mut config = Configuration::new(is_test, rt_multi_thread);
244     let macro_name = config.macro_name();
245 
246     for arg in args {
247         match arg {
248             syn::Meta::NameValue(namevalue) => {
249                 let ident = namevalue
250                     .path
251                     .get_ident()
252                     .ok_or_else(|| {
253                         syn::Error::new_spanned(&namevalue, "Must have specified ident")
254                     })?
255                     .to_string()
256                     .to_lowercase();
257                 let lit = match &namevalue.value {
258                     syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
259                     expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")),
260                 };
261                 match ident.as_str() {
262                     "worker_threads" => {
263                         config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?;
264                     }
265                     "flavor" => {
266                         config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?;
267                     }
268                     "start_paused" => {
269                         config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?;
270                     }
271                     "core_threads" => {
272                         let msg = "Attribute `core_threads` is renamed to `worker_threads`";
273                         return Err(syn::Error::new_spanned(namevalue, msg));
274                     }
275                     "crate" => {
276                         config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?;
277                     }
278                     name => {
279                         let msg = format!(
280                             "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`",
281                             name,
282                         );
283                         return Err(syn::Error::new_spanned(namevalue, msg));
284                     }
285                 }
286             }
287             syn::Meta::Path(path) => {
288                 let name = path
289                     .get_ident()
290                     .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
291                     .to_string()
292                     .to_lowercase();
293                 let msg = match name.as_str() {
294                     "threaded_scheduler" | "multi_thread" => {
295                         format!(
296                             "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].",
297                             macro_name
298                         )
299                     }
300                     "basic_scheduler" | "current_thread" | "single_threaded" => {
301                         format!(
302                             "Set the runtime flavor with #[{}(flavor = \"current_thread\")].",
303                             macro_name
304                         )
305                     }
306                     "flavor" | "worker_threads" | "start_paused" => {
307                         format!("The `{}` attribute requires an argument.", name)
308                     }
309                     name => {
310                         format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name)
311                     }
312                 };
313                 return Err(syn::Error::new_spanned(path, msg));
314             }
315             other => {
316                 return Err(syn::Error::new_spanned(
317                     other,
318                     "Unknown attribute inside the macro",
319                 ));
320             }
321         }
322     }
323 
324     config.build()
325 }
326 
parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream327 fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
328     input.sig.asyncness = None;
329 
330     // If type mismatch occurs, the current rustc points to the last statement.
331     let (last_stmt_start_span, last_stmt_end_span) = {
332         let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
333 
334         // `Span` on stable Rust has a limitation that only points to the first
335         // token, not the whole tokens. We can work around this limitation by
336         // using the first/last span of the tokens like
337         // `syn::Error::new_spanned` does.
338         let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
339         let end = last_stmt.last().map_or(start, |t| t.span());
340         (start, end)
341     };
342 
343     let crate_path = config
344         .crate_name
345         .map(ToTokens::into_token_stream)
346         .unwrap_or_else(|| Ident::new("tokio", last_stmt_start_span).into_token_stream());
347 
348     let mut rt = match config.flavor {
349         RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
350             #crate_path::runtime::Builder::new_current_thread()
351         },
352         RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
353             #crate_path::runtime::Builder::new_multi_thread()
354         },
355     };
356     if let Some(v) = config.worker_threads {
357         rt = quote! { #rt.worker_threads(#v) };
358     }
359     if let Some(v) = config.start_paused {
360         rt = quote! { #rt.start_paused(#v) };
361     }
362 
363     let header = if is_test {
364         quote! {
365             #[::core::prelude::v1::test]
366         }
367     } else {
368         quote! {}
369     };
370 
371     let body_ident = quote! { body };
372     let last_block = quote_spanned! {last_stmt_end_span=>
373         #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
374         {
375             return #rt
376                 .enable_all()
377                 .build()
378                 .expect("Failed building the Runtime")
379                 .block_on(#body_ident);
380         }
381     };
382 
383     let body = input.body();
384 
385     // For test functions pin the body to the stack and use `Pin<&mut dyn
386     // Future>` to reduce the amount of `Runtime::block_on` (and related
387     // functions) copies we generate during compilation due to the generic
388     // parameter `F` (the future to block on). This could have an impact on
389     // performance, but because it's only for testing it's unlikely to be very
390     // large.
391     //
392     // We don't do this for the main function as it should only be used once so
393     // there will be no benefit.
394     let body = if is_test {
395         let output_type = match &input.sig.output {
396             // For functions with no return value syn doesn't print anything,
397             // but that doesn't work as `Output` for our boxed `Future`, so
398             // default to `()` (the same type as the function output).
399             syn::ReturnType::Default => quote! { () },
400             syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
401         };
402         quote! {
403             let body = async #body;
404             #crate_path::pin!(body);
405             let body: ::std::pin::Pin<&mut dyn ::std::future::Future<Output = #output_type>> = body;
406         }
407     } else {
408         quote! {
409             let body = async #body;
410         }
411     };
412 
413     input.into_tokens(header, body, last_block)
414 }
415 
token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream416 fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
417     tokens.extend(error.into_compile_error());
418     tokens
419 }
420 
421 #[cfg(not(test))] // Work around for rust-lang/rust#62127
main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream422 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
423     // If any of the steps for this macro fail, we still want to expand to an item that is as close
424     // to the expected output as possible. This helps out IDEs such that completions and other
425     // related features keep working.
426     let input: ItemFn = match syn::parse2(item.clone()) {
427         Ok(it) => it,
428         Err(e) => return token_stream_with_error(item, e),
429     };
430 
431     let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
432         let msg = "the main function cannot accept arguments";
433         Err(syn::Error::new_spanned(&input.sig.ident, msg))
434     } else {
435         AttributeArgs::parse_terminated
436             .parse2(args)
437             .and_then(|args| build_config(&input, args, false, rt_multi_thread))
438     };
439 
440     match config {
441         Ok(config) => parse_knobs(input, false, config),
442         Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
443     }
444 }
445 
test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream446 pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
447     // If any of the steps for this macro fail, we still want to expand to an item that is as close
448     // to the expected output as possible. This helps out IDEs such that completions and other
449     // related features keep working.
450     let input: ItemFn = match syn::parse2(item.clone()) {
451         Ok(it) => it,
452         Err(e) => return token_stream_with_error(item, e),
453     };
454     let config = if let Some(attr) = input.attrs().find(|attr| attr.meta.path().is_ident("test")) {
455         let msg = "second test attribute is supplied";
456         Err(syn::Error::new_spanned(attr, msg))
457     } else {
458         AttributeArgs::parse_terminated
459             .parse2(args)
460             .and_then(|args| build_config(&input, args, true, rt_multi_thread))
461     };
462 
463     match config {
464         Ok(config) => parse_knobs(input, true, config),
465         Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
466     }
467 }
468 
469 struct ItemFn {
470     outer_attrs: Vec<Attribute>,
471     vis: Visibility,
472     sig: Signature,
473     brace_token: syn::token::Brace,
474     inner_attrs: Vec<Attribute>,
475     stmts: Vec<proc_macro2::TokenStream>,
476 }
477 
478 impl ItemFn {
479     /// Access all attributes of the function item.
attrs(&self) -> impl Iterator<Item = &Attribute>480     fn attrs(&self) -> impl Iterator<Item = &Attribute> {
481         self.outer_attrs.iter().chain(self.inner_attrs.iter())
482     }
483 
484     /// Get the body of the function item in a manner so that it can be
485     /// conveniently used with the `quote!` macro.
body(&self) -> Body<'_>486     fn body(&self) -> Body<'_> {
487         Body {
488             brace_token: self.brace_token,
489             stmts: &self.stmts,
490         }
491     }
492 
493     /// Convert our local function item into a token stream.
into_tokens( self, header: proc_macro2::TokenStream, body: proc_macro2::TokenStream, last_block: proc_macro2::TokenStream, ) -> TokenStream494     fn into_tokens(
495         self,
496         header: proc_macro2::TokenStream,
497         body: proc_macro2::TokenStream,
498         last_block: proc_macro2::TokenStream,
499     ) -> TokenStream {
500         let mut tokens = proc_macro2::TokenStream::new();
501         header.to_tokens(&mut tokens);
502 
503         // Outer attributes are simply streamed as-is.
504         for attr in self.outer_attrs {
505             attr.to_tokens(&mut tokens);
506         }
507 
508         // Inner attributes require extra care, since they're not supported on
509         // blocks (which is what we're expanded into) we instead lift them
510         // outside of the function. This matches the behaviour of `syn`.
511         for mut attr in self.inner_attrs {
512             attr.style = syn::AttrStyle::Outer;
513             attr.to_tokens(&mut tokens);
514         }
515 
516         self.vis.to_tokens(&mut tokens);
517         self.sig.to_tokens(&mut tokens);
518 
519         self.brace_token.surround(&mut tokens, |tokens| {
520             body.to_tokens(tokens);
521             last_block.to_tokens(tokens);
522         });
523 
524         tokens
525     }
526 }
527 
528 impl Parse for ItemFn {
529     #[inline]
parse(input: ParseStream<'_>) -> syn::Result<Self>530     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
531         // This parse implementation has been largely lifted from `syn`, with
532         // the exception of:
533         // * We don't have access to the plumbing necessary to parse inner
534         //   attributes in-place.
535         // * We do our own statements parsing to avoid recursively parsing
536         //   entire statements and only look for the parts we're interested in.
537 
538         let outer_attrs = input.call(Attribute::parse_outer)?;
539         let vis: Visibility = input.parse()?;
540         let sig: Signature = input.parse()?;
541 
542         let content;
543         let brace_token = braced!(content in input);
544         let inner_attrs = Attribute::parse_inner(&content)?;
545 
546         let mut buf = proc_macro2::TokenStream::new();
547         let mut stmts = Vec::new();
548 
549         while !content.is_empty() {
550             if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
551                 semi.to_tokens(&mut buf);
552                 stmts.push(buf);
553                 buf = proc_macro2::TokenStream::new();
554                 continue;
555             }
556 
557             // Parse a single token tree and extend our current buffer with it.
558             // This avoids parsing the entire content of the sub-tree.
559             buf.extend([content.parse::<TokenTree>()?]);
560         }
561 
562         if !buf.is_empty() {
563             stmts.push(buf);
564         }
565 
566         Ok(Self {
567             outer_attrs,
568             vis,
569             sig,
570             brace_token,
571             inner_attrs,
572             stmts,
573         })
574     }
575 }
576 
577 struct Body<'a> {
578     brace_token: syn::token::Brace,
579     // Statements, with terminating `;`.
580     stmts: &'a [TokenStream],
581 }
582 
583 impl ToTokens for Body<'_> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)584     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
585         self.brace_token.surround(tokens, |tokens| {
586             for stmt in self.stmts {
587                 stmt.to_tokens(tokens);
588             }
589         })
590     }
591 }
592