• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro::TokenStream;
2 use proc_macro2::Span;
3 use quote::{quote, quote_spanned, ToTokens};
4 use syn::parse::Parser;
5 
6 // syn::AttributeArgs does not implement syn::Parse
7 type AttributeArgs = syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>;
8 
/// The kinds of Tokio runtime the attribute macros can build.
#[derive(Clone, Copy, PartialEq)]
enum RuntimeFlavor {
    /// Selected by `flavor = "current_thread"`.
    CurrentThread,
    /// Selected by `flavor = "multi_thread"`.
    Threaded,
}

impl RuntimeFlavor {
    /// Parses the string value of the `flavor` attribute argument.
    ///
    /// Returns the matching flavor. Several legacy or misspelled names are rejected
    /// with a message pointing at the current name; any other string yields a
    /// generic "no such flavor" message listing the valid choices.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        // Valid names and unknown names return directly from their arms; the
        // remaining arms supply a fixed hint for a recognised legacy spelling.
        let hint = match s {
            "current_thread" => return Ok(RuntimeFlavor::CurrentThread),
            "multi_thread" => return Ok(RuntimeFlavor::Threaded),
            "single_thread" => "The single threaded runtime flavor is called `current_thread`.",
            "basic_scheduler" => {
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`."
            }
            "threaded_scheduler" => {
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`."
            }
            unknown => {
                return Err(format!(
                    "No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.",
                    unknown
                ))
            }
        };
        Err(hint.to_string())
    }
}
27 
28 struct FinalConfig {
29     flavor: RuntimeFlavor,
30     worker_threads: Option<usize>,
31     start_paused: Option<bool>,
32 }
33 
34 /// Config used in case of the attribute not being able to build a valid config
35 const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
36     flavor: RuntimeFlavor::CurrentThread,
37     worker_threads: None,
38     start_paused: None,
39 };
40 
41 struct Configuration {
42     rt_multi_thread_available: bool,
43     default_flavor: RuntimeFlavor,
44     flavor: Option<RuntimeFlavor>,
45     worker_threads: Option<(usize, Span)>,
46     start_paused: Option<(bool, Span)>,
47     is_test: bool,
48 }
49 
50 impl Configuration {
new(is_test: bool, rt_multi_thread: bool) -> Self51     fn new(is_test: bool, rt_multi_thread: bool) -> Self {
52         Configuration {
53             rt_multi_thread_available: rt_multi_thread,
54             default_flavor: match is_test {
55                 true => RuntimeFlavor::CurrentThread,
56                 false => RuntimeFlavor::Threaded,
57             },
58             flavor: None,
59             worker_threads: None,
60             start_paused: None,
61             is_test,
62         }
63     }
64 
set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error>65     fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
66         if self.flavor.is_some() {
67             return Err(syn::Error::new(span, "`flavor` set multiple times."));
68         }
69 
70         let runtime_str = parse_string(runtime, span, "flavor")?;
71         let runtime =
72             RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
73         self.flavor = Some(runtime);
74         Ok(())
75     }
76 
set_worker_threads( &mut self, worker_threads: syn::Lit, span: Span, ) -> Result<(), syn::Error>77     fn set_worker_threads(
78         &mut self,
79         worker_threads: syn::Lit,
80         span: Span,
81     ) -> Result<(), syn::Error> {
82         if self.worker_threads.is_some() {
83             return Err(syn::Error::new(
84                 span,
85                 "`worker_threads` set multiple times.",
86             ));
87         }
88 
89         let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
90         if worker_threads == 0 {
91             return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
92         }
93         self.worker_threads = Some((worker_threads, span));
94         Ok(())
95     }
96 
set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error>97     fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
98         if self.start_paused.is_some() {
99             return Err(syn::Error::new(span, "`start_paused` set multiple times."));
100         }
101 
102         let start_paused = parse_bool(start_paused, span, "start_paused")?;
103         self.start_paused = Some((start_paused, span));
104         Ok(())
105     }
106 
macro_name(&self) -> &'static str107     fn macro_name(&self) -> &'static str {
108         if self.is_test {
109             "tokio::test"
110         } else {
111             "tokio::main"
112         }
113     }
114 
build(&self) -> Result<FinalConfig, syn::Error>115     fn build(&self) -> Result<FinalConfig, syn::Error> {
116         let flavor = self.flavor.unwrap_or(self.default_flavor);
117         use RuntimeFlavor::*;
118 
119         let worker_threads = match (flavor, self.worker_threads) {
120             (CurrentThread, Some((_, worker_threads_span))) => {
121                 let msg = format!(
122                     "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
123                     self.macro_name(),
124                 );
125                 return Err(syn::Error::new(worker_threads_span, msg));
126             }
127             (CurrentThread, None) => None,
128             (Threaded, worker_threads) if self.rt_multi_thread_available => {
129                 worker_threads.map(|(val, _span)| val)
130             }
131             (Threaded, _) => {
132                 let msg = if self.flavor.is_none() {
133                     "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
134                 } else {
135                     "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
136                 };
137                 return Err(syn::Error::new(Span::call_site(), msg));
138             }
139         };
140 
141         let start_paused = match (flavor, self.start_paused) {
142             (Threaded, Some((_, start_paused_span))) => {
143                 let msg = format!(
144                     "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
145                     self.macro_name(),
146                 );
147                 return Err(syn::Error::new(start_paused_span, msg));
148             }
149             (CurrentThread, Some((start_paused, _))) => Some(start_paused),
150             (_, None) => None,
151         };
152 
153         Ok(FinalConfig {
154             flavor,
155             worker_threads,
156             start_paused,
157         })
158     }
159 }
160 
parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error>161 fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
162     match int {
163         syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
164             Ok(value) => Ok(value),
165             Err(e) => Err(syn::Error::new(
166                 span,
167                 format!("Failed to parse value of `{}` as integer: {}", field, e),
168             )),
169         },
170         _ => Err(syn::Error::new(
171             span,
172             format!("Failed to parse value of `{}` as integer.", field),
173         )),
174     }
175 }
176 
parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error>177 fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
178     match int {
179         syn::Lit::Str(s) => Ok(s.value()),
180         syn::Lit::Verbatim(s) => Ok(s.to_string()),
181         _ => Err(syn::Error::new(
182             span,
183             format!("Failed to parse value of `{}` as string.", field),
184         )),
185     }
186 }
187 
parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error>188 fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
189     match bool {
190         syn::Lit::Bool(b) => Ok(b.value),
191         _ => Err(syn::Error::new(
192             span,
193             format!("Failed to parse value of `{}` as bool.", field),
194         )),
195     }
196 }
197 
build_config( input: syn::ItemFn, args: AttributeArgs, is_test: bool, rt_multi_thread: bool, ) -> Result<FinalConfig, syn::Error>198 fn build_config(
199     input: syn::ItemFn,
200     args: AttributeArgs,
201     is_test: bool,
202     rt_multi_thread: bool,
203 ) -> Result<FinalConfig, syn::Error> {
204     if input.sig.asyncness.is_none() {
205         let msg = "the `async` keyword is missing from the function declaration";
206         return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
207     }
208 
209     let mut config = Configuration::new(is_test, rt_multi_thread);
210     let macro_name = config.macro_name();
211 
212     for arg in args {
213         match arg {
214             syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => {
215                 let ident = namevalue
216                     .path
217                     .get_ident()
218                     .ok_or_else(|| {
219                         syn::Error::new_spanned(&namevalue, "Must have specified ident")
220                     })?
221                     .to_string()
222                     .to_lowercase();
223                 match ident.as_str() {
224                     "worker_threads" => {
225                         config.set_worker_threads(
226                             namevalue.lit.clone(),
227                             syn::spanned::Spanned::span(&namevalue.lit),
228                         )?;
229                     }
230                     "flavor" => {
231                         config.set_flavor(
232                             namevalue.lit.clone(),
233                             syn::spanned::Spanned::span(&namevalue.lit),
234                         )?;
235                     }
236                     "start_paused" => {
237                         config.set_start_paused(
238                             namevalue.lit.clone(),
239                             syn::spanned::Spanned::span(&namevalue.lit),
240                         )?;
241                     }
242                     "core_threads" => {
243                         let msg = "Attribute `core_threads` is renamed to `worker_threads`";
244                         return Err(syn::Error::new_spanned(namevalue, msg));
245                     }
246                     name => {
247                         let msg = format!(
248                             "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`",
249                             name,
250                         );
251                         return Err(syn::Error::new_spanned(namevalue, msg));
252                     }
253                 }
254             }
255             syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
256                 let name = path
257                     .get_ident()
258                     .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
259                     .to_string()
260                     .to_lowercase();
261                 let msg = match name.as_str() {
262                     "threaded_scheduler" | "multi_thread" => {
263                         format!(
264                             "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].",
265                             macro_name
266                         )
267                     }
268                     "basic_scheduler" | "current_thread" | "single_threaded" => {
269                         format!(
270                             "Set the runtime flavor with #[{}(flavor = \"current_thread\")].",
271                             macro_name
272                         )
273                     }
274                     "flavor" | "worker_threads" | "start_paused" => {
275                         format!("The `{}` attribute requires an argument.", name)
276                     }
277                     name => {
278                         format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", name)
279                     }
280                 };
281                 return Err(syn::Error::new_spanned(path, msg));
282             }
283             other => {
284                 return Err(syn::Error::new_spanned(
285                     other,
286                     "Unknown attribute inside the macro",
287                 ));
288             }
289         }
290     }
291 
292     config.build()
293 }
294 
parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream295 fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
296     input.sig.asyncness = None;
297 
298     // If type mismatch occurs, the current rustc points to the last statement.
299     let (last_stmt_start_span, last_stmt_end_span) = {
300         let mut last_stmt = input
301             .block
302             .stmts
303             .last()
304             .map(ToTokens::into_token_stream)
305             .unwrap_or_default()
306             .into_iter();
307         // `Span` on stable Rust has a limitation that only points to the first
308         // token, not the whole tokens. We can work around this limitation by
309         // using the first/last span of the tokens like
310         // `syn::Error::new_spanned` does.
311         let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
312         let end = last_stmt.last().map_or(start, |t| t.span());
313         (start, end)
314     };
315 
316     let mut rt = match config.flavor {
317         RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
318             tokio::runtime::Builder::new_current_thread()
319         },
320         RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
321             tokio::runtime::Builder::new_multi_thread()
322         },
323     };
324     if let Some(v) = config.worker_threads {
325         rt = quote! { #rt.worker_threads(#v) };
326     }
327     if let Some(v) = config.start_paused {
328         rt = quote! { #rt.start_paused(#v) };
329     }
330 
331     let header = if is_test {
332         quote! {
333             #[::core::prelude::v1::test]
334         }
335     } else {
336         quote! {}
337     };
338 
339     let body = &input.block;
340     let brace_token = input.block.brace_token;
341     let (tail_return, tail_semicolon) = match body.stmts.last() {
342         Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }),
343         Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => {
344             match &input.sig.output {
345                 syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) =>
346                 {
347                     (quote! {}, quote! { ; }) // unit
348                 }
349                 syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit
350                 syn::ReturnType::Type(..) => (quote! {}, quote! {}),   // ! or another
351             }
352         }
353         _ => (quote! {}, quote! {}),
354     };
355     input.block = syn::parse2(quote_spanned! {last_stmt_end_span=>
356         {
357             let body = async #body;
358             #[allow(clippy::expect_used)]
359             #tail_return #rt
360                 .enable_all()
361                 .build()
362                 .expect("Failed building the Runtime")
363                 .block_on(body)#tail_semicolon
364         }
365     })
366     .expect("Parsing failure");
367     input.block.brace_token = brace_token;
368 
369     let result = quote! {
370         #header
371         #input
372     };
373 
374     result.into()
375 }
376 
token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream377 fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
378     tokens.extend(TokenStream::from(error.into_compile_error()));
379     tokens
380 }
381 
382 #[cfg(not(test))] // Work around for rust-lang/rust#62127
main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream383 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
384     // If any of the steps for this macro fail, we still want to expand to an item that is as close
385     // to the expected output as possible. This helps out IDEs such that completions and other
386     // related features keep working.
387     let input: syn::ItemFn = match syn::parse(item.clone()) {
388         Ok(it) => it,
389         Err(e) => return token_stream_with_error(item, e),
390     };
391 
392     let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
393         let msg = "the main function cannot accept arguments";
394         Err(syn::Error::new_spanned(&input.sig.ident, msg))
395     } else {
396         AttributeArgs::parse_terminated
397             .parse(args)
398             .and_then(|args| build_config(input.clone(), args, false, rt_multi_thread))
399     };
400 
401     match config {
402         Ok(config) => parse_knobs(input, false, config),
403         Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
404     }
405 }
406 
test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream407 pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
408     // If any of the steps for this macro fail, we still want to expand to an item that is as close
409     // to the expected output as possible. This helps out IDEs such that completions and other
410     // related features keep working.
411     let input: syn::ItemFn = match syn::parse(item.clone()) {
412         Ok(it) => it,
413         Err(e) => return token_stream_with_error(item, e),
414     };
415     let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) {
416         let msg = "second test attribute is supplied";
417         Err(syn::Error::new_spanned(&attr, msg))
418     } else {
419         AttributeArgs::parse_terminated
420             .parse(args)
421             .and_then(|args| build_config(input.clone(), args, true, rt_multi_thread))
422     };
423 
424     match config {
425         Ok(config) => parse_knobs(input, true, config),
426         Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
427     }
428 }
429