1 use proc_macro::TokenStream;
2 use proc_macro2::Span;
3 use quote::{quote, quote_spanned, ToTokens};
4 use syn::parse::Parser;
5
6 // syn::AttributeArgs does not implement syn::Parse
7 type AttributeArgs = syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>;
8
/// The runtime flavors selectable through `flavor = "..."`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RuntimeFlavor {
    /// `flavor = "current_thread"`: built with `Builder::new_current_thread()`.
    CurrentThread,
    /// `flavor = "multi_thread"`: built with `Builder::new_multi_thread()`.
    Threaded,
}

impl RuntimeFlavor {
    /// Parses the string value of the `flavor` argument.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for unknown names. Legacy or commonly mistyped
    /// names (`single_thread`, `single_threaded`, `basic_scheduler`,
    /// `threaded_scheduler`) get a message pointing at the current spelling.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        match s {
            "current_thread" => Ok(RuntimeFlavor::CurrentThread),
            "multi_thread" => Ok(RuntimeFlavor::Threaded),
            // `single_threaded` is also recognized as a legacy spelling by `build_config`,
            // so give it the same hint rather than the generic "no such flavor" error.
            "single_thread" | "single_threaded" => {
                Err("The single threaded runtime flavor is called `current_thread`.".to_string())
            }
            "basic_scheduler" => Err(
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`."
                    .to_string(),
            ),
            "threaded_scheduler" => Err(
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`."
                    .to_string(),
            ),
            _ => Err(format!(
                "No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.",
                s
            )),
        }
    }
}
27
28 struct FinalConfig {
29 flavor: RuntimeFlavor,
30 worker_threads: Option<usize>,
31 start_paused: Option<bool>,
32 }
33
34 /// Config used in case of the attribute not being able to build a valid config
35 const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
36 flavor: RuntimeFlavor::CurrentThread,
37 worker_threads: None,
38 start_paused: None,
39 };
40
41 struct Configuration {
42 rt_multi_thread_available: bool,
43 default_flavor: RuntimeFlavor,
44 flavor: Option<RuntimeFlavor>,
45 worker_threads: Option<(usize, Span)>,
46 start_paused: Option<(bool, Span)>,
47 is_test: bool,
48 }
49
50 impl Configuration {
new(is_test: bool, rt_multi_thread: bool) -> Self51 fn new(is_test: bool, rt_multi_thread: bool) -> Self {
52 Configuration {
53 rt_multi_thread_available: rt_multi_thread,
54 default_flavor: match is_test {
55 true => RuntimeFlavor::CurrentThread,
56 false => RuntimeFlavor::Threaded,
57 },
58 flavor: None,
59 worker_threads: None,
60 start_paused: None,
61 is_test,
62 }
63 }
64
set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error>65 fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
66 if self.flavor.is_some() {
67 return Err(syn::Error::new(span, "`flavor` set multiple times."));
68 }
69
70 let runtime_str = parse_string(runtime, span, "flavor")?;
71 let runtime =
72 RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
73 self.flavor = Some(runtime);
74 Ok(())
75 }
76
set_worker_threads( &mut self, worker_threads: syn::Lit, span: Span, ) -> Result<(), syn::Error>77 fn set_worker_threads(
78 &mut self,
79 worker_threads: syn::Lit,
80 span: Span,
81 ) -> Result<(), syn::Error> {
82 if self.worker_threads.is_some() {
83 return Err(syn::Error::new(
84 span,
85 "`worker_threads` set multiple times.",
86 ));
87 }
88
89 let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
90 if worker_threads == 0 {
91 return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
92 }
93 self.worker_threads = Some((worker_threads, span));
94 Ok(())
95 }
96
set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error>97 fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
98 if self.start_paused.is_some() {
99 return Err(syn::Error::new(span, "`start_paused` set multiple times."));
100 }
101
102 let start_paused = parse_bool(start_paused, span, "start_paused")?;
103 self.start_paused = Some((start_paused, span));
104 Ok(())
105 }
106
macro_name(&self) -> &'static str107 fn macro_name(&self) -> &'static str {
108 if self.is_test {
109 "tokio::test"
110 } else {
111 "tokio::main"
112 }
113 }
114
build(&self) -> Result<FinalConfig, syn::Error>115 fn build(&self) -> Result<FinalConfig, syn::Error> {
116 let flavor = self.flavor.unwrap_or(self.default_flavor);
117 use RuntimeFlavor::*;
118
119 let worker_threads = match (flavor, self.worker_threads) {
120 (CurrentThread, Some((_, worker_threads_span))) => {
121 let msg = format!(
122 "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
123 self.macro_name(),
124 );
125 return Err(syn::Error::new(worker_threads_span, msg));
126 }
127 (CurrentThread, None) => None,
128 (Threaded, worker_threads) if self.rt_multi_thread_available => {
129 worker_threads.map(|(val, _span)| val)
130 }
131 (Threaded, _) => {
132 let msg = if self.flavor.is_none() {
133 "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
134 } else {
135 "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
136 };
137 return Err(syn::Error::new(Span::call_site(), msg));
138 }
139 };
140
141 let start_paused = match (flavor, self.start_paused) {
142 (Threaded, Some((_, start_paused_span))) => {
143 let msg = format!(
144 "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
145 self.macro_name(),
146 );
147 return Err(syn::Error::new(start_paused_span, msg));
148 }
149 (CurrentThread, Some((start_paused, _))) => Some(start_paused),
150 (_, None) => None,
151 };
152
153 Ok(FinalConfig {
154 flavor,
155 worker_threads,
156 start_paused,
157 })
158 }
159 }
160
parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error>161 fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
162 match int {
163 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
164 Ok(value) => Ok(value),
165 Err(e) => Err(syn::Error::new(
166 span,
167 format!("Failed to parse value of `{}` as integer: {}", field, e),
168 )),
169 },
170 _ => Err(syn::Error::new(
171 span,
172 format!("Failed to parse value of `{}` as integer.", field),
173 )),
174 }
175 }
176
parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error>177 fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
178 match int {
179 syn::Lit::Str(s) => Ok(s.value()),
180 syn::Lit::Verbatim(s) => Ok(s.to_string()),
181 _ => Err(syn::Error::new(
182 span,
183 format!("Failed to parse value of `{}` as string.", field),
184 )),
185 }
186 }
187
parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error>188 fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
189 match bool {
190 syn::Lit::Bool(b) => Ok(b.value),
191 _ => Err(syn::Error::new(
192 span,
193 format!("Failed to parse value of `{}` as bool.", field),
194 )),
195 }
196 }
197
build_config( input: syn::ItemFn, args: AttributeArgs, is_test: bool, rt_multi_thread: bool, ) -> Result<FinalConfig, syn::Error>198 fn build_config(
199 input: syn::ItemFn,
200 args: AttributeArgs,
201 is_test: bool,
202 rt_multi_thread: bool,
203 ) -> Result<FinalConfig, syn::Error> {
204 if input.sig.asyncness.is_none() {
205 let msg = "the `async` keyword is missing from the function declaration";
206 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
207 }
208
209 let mut config = Configuration::new(is_test, rt_multi_thread);
210 let macro_name = config.macro_name();
211
212 for arg in args {
213 match arg {
214 syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => {
215 let ident = namevalue
216 .path
217 .get_ident()
218 .ok_or_else(|| {
219 syn::Error::new_spanned(&namevalue, "Must have specified ident")
220 })?
221 .to_string()
222 .to_lowercase();
223 match ident.as_str() {
224 "worker_threads" => {
225 config.set_worker_threads(
226 namevalue.lit.clone(),
227 syn::spanned::Spanned::span(&namevalue.lit),
228 )?;
229 }
230 "flavor" => {
231 config.set_flavor(
232 namevalue.lit.clone(),
233 syn::spanned::Spanned::span(&namevalue.lit),
234 )?;
235 }
236 "start_paused" => {
237 config.set_start_paused(
238 namevalue.lit.clone(),
239 syn::spanned::Spanned::span(&namevalue.lit),
240 )?;
241 }
242 "core_threads" => {
243 let msg = "Attribute `core_threads` is renamed to `worker_threads`";
244 return Err(syn::Error::new_spanned(namevalue, msg));
245 }
246 name => {
247 let msg = format!(
248 "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`",
249 name,
250 );
251 return Err(syn::Error::new_spanned(namevalue, msg));
252 }
253 }
254 }
255 syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
256 let name = path
257 .get_ident()
258 .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
259 .to_string()
260 .to_lowercase();
261 let msg = match name.as_str() {
262 "threaded_scheduler" | "multi_thread" => {
263 format!(
264 "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].",
265 macro_name
266 )
267 }
268 "basic_scheduler" | "current_thread" | "single_threaded" => {
269 format!(
270 "Set the runtime flavor with #[{}(flavor = \"current_thread\")].",
271 macro_name
272 )
273 }
274 "flavor" | "worker_threads" | "start_paused" => {
275 format!("The `{}` attribute requires an argument.", name)
276 }
277 name => {
278 format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", name)
279 }
280 };
281 return Err(syn::Error::new_spanned(path, msg));
282 }
283 other => {
284 return Err(syn::Error::new_spanned(
285 other,
286 "Unknown attribute inside the macro",
287 ));
288 }
289 }
290 }
291
292 config.build()
293 }
294
parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream295 fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
296 input.sig.asyncness = None;
297
298 // If type mismatch occurs, the current rustc points to the last statement.
299 let (last_stmt_start_span, last_stmt_end_span) = {
300 let mut last_stmt = input
301 .block
302 .stmts
303 .last()
304 .map(ToTokens::into_token_stream)
305 .unwrap_or_default()
306 .into_iter();
307 // `Span` on stable Rust has a limitation that only points to the first
308 // token, not the whole tokens. We can work around this limitation by
309 // using the first/last span of the tokens like
310 // `syn::Error::new_spanned` does.
311 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
312 let end = last_stmt.last().map_or(start, |t| t.span());
313 (start, end)
314 };
315
316 let mut rt = match config.flavor {
317 RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
318 tokio::runtime::Builder::new_current_thread()
319 },
320 RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
321 tokio::runtime::Builder::new_multi_thread()
322 },
323 };
324 if let Some(v) = config.worker_threads {
325 rt = quote! { #rt.worker_threads(#v) };
326 }
327 if let Some(v) = config.start_paused {
328 rt = quote! { #rt.start_paused(#v) };
329 }
330
331 let header = if is_test {
332 quote! {
333 #[::core::prelude::v1::test]
334 }
335 } else {
336 quote! {}
337 };
338
339 let body = &input.block;
340 let brace_token = input.block.brace_token;
341 let (tail_return, tail_semicolon) = match body.stmts.last() {
342 Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }),
343 Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => {
344 match &input.sig.output {
345 syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) =>
346 {
347 (quote! {}, quote! { ; }) // unit
348 }
349 syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit
350 syn::ReturnType::Type(..) => (quote! {}, quote! {}), // ! or another
351 }
352 }
353 _ => (quote! {}, quote! {}),
354 };
355 input.block = syn::parse2(quote_spanned! {last_stmt_end_span=>
356 {
357 let body = async #body;
358 #[allow(clippy::expect_used)]
359 #tail_return #rt
360 .enable_all()
361 .build()
362 .expect("Failed building the Runtime")
363 .block_on(body)#tail_semicolon
364 }
365 })
366 .expect("Parsing failure");
367 input.block.brace_token = brace_token;
368
369 let result = quote! {
370 #header
371 #input
372 };
373
374 result.into()
375 }
376
token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream377 fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
378 tokens.extend(TokenStream::from(error.into_compile_error()));
379 tokens
380 }
381
382 #[cfg(not(test))] // Work around for rust-lang/rust#62127
main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream383 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
384 // If any of the steps for this macro fail, we still want to expand to an item that is as close
385 // to the expected output as possible. This helps out IDEs such that completions and other
386 // related features keep working.
387 let input: syn::ItemFn = match syn::parse(item.clone()) {
388 Ok(it) => it,
389 Err(e) => return token_stream_with_error(item, e),
390 };
391
392 let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
393 let msg = "the main function cannot accept arguments";
394 Err(syn::Error::new_spanned(&input.sig.ident, msg))
395 } else {
396 AttributeArgs::parse_terminated
397 .parse(args)
398 .and_then(|args| build_config(input.clone(), args, false, rt_multi_thread))
399 };
400
401 match config {
402 Ok(config) => parse_knobs(input, false, config),
403 Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
404 }
405 }
406
test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream407 pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
408 // If any of the steps for this macro fail, we still want to expand to an item that is as close
409 // to the expected output as possible. This helps out IDEs such that completions and other
410 // related features keep working.
411 let input: syn::ItemFn = match syn::parse(item.clone()) {
412 Ok(it) => it,
413 Err(e) => return token_stream_with_error(item, e),
414 };
415 let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) {
416 let msg = "second test attribute is supplied";
417 Err(syn::Error::new_spanned(&attr, msg))
418 } else {
419 AttributeArgs::parse_terminated
420 .parse(args)
421 .and_then(|args| build_config(input.clone(), args, true, rt_multi_thread))
422 };
423
424 match config {
425 Ok(config) => parse_knobs(input, true, config),
426 Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
427 }
428 }
429