1 use crate::bound::{has_bound, InferredBound, Supertraits};
2 use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3 use crate::parse::Item;
4 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5 use crate::verbatim::VerbatimFn;
6 use proc_macro2::{Span, TokenStream};
7 use quote::{format_ident, quote, quote_spanned, ToTokens};
8 use std::collections::BTreeSet as Set;
9 use std::mem;
10 use syn::punctuated::Punctuated;
11 use syn::visit_mut::{self, VisitMut};
12 use syn::{
13 parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14 Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15 ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
16 };
17
18 impl ToTokens for Item {
to_tokens(&self, tokens: &mut TokenStream)19 fn to_tokens(&self, tokens: &mut TokenStream) {
20 match self {
21 Item::Trait(item) => item.to_tokens(tokens),
22 Item::Impl(item) => item.to_tokens(tokens),
23 }
24 }
25 }
26
27 #[derive(Clone, Copy)]
28 enum Context<'a> {
29 Trait {
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 associated_type_impl_traits: &'a Set<Ident>,
36 },
37 }
38
39 impl Context<'_> {
lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam>40 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41 let generics = match self {
42 Context::Trait { generics, .. } => generics,
43 Context::Impl { impl_generics, .. } => impl_generics,
44 };
45 generics.params.iter().filter_map(move |param| {
46 if let GenericParam::Lifetime(param) = param {
47 if used.contains(¶m.lifetime) {
48 return Some(param);
49 }
50 }
51 None
52 })
53 }
54 }
55
expand(input: &mut Item, is_local: bool)56 pub fn expand(input: &mut Item, is_local: bool) {
57 match input {
58 Item::Trait(input) => {
59 let context = Context::Trait {
60 generics: &input.generics,
61 supertraits: &input.supertraits,
62 };
63 for inner in &mut input.items {
64 if let TraitItem::Fn(method) = inner {
65 let sig = &mut method.sig;
66 if sig.asyncness.is_some() {
67 let block = &mut method.default;
68 let mut has_self = has_self_in_sig(sig);
69 method.attrs.push(parse_quote!(#[must_use]));
70 if let Some(block) = block {
71 has_self |= has_self_in_block(block);
72 transform_block(context, sig, block);
73 method.attrs.push(lint_suppress_with_body());
74 } else {
75 method.attrs.push(lint_suppress_without_body());
76 }
77 let has_default = method.default.is_some();
78 transform_sig(context, sig, has_self, has_default, is_local);
79 }
80 }
81 }
82 }
83 Item::Impl(input) => {
84 let mut associated_type_impl_traits = Set::new();
85 for inner in &input.items {
86 if let ImplItem::Type(assoc) = inner {
87 if let Type::ImplTrait(_) = assoc.ty {
88 associated_type_impl_traits.insert(assoc.ident.clone());
89 }
90 }
91 }
92
93 let context = Context::Impl {
94 impl_generics: &input.generics,
95 associated_type_impl_traits: &associated_type_impl_traits,
96 };
97 for inner in &mut input.items {
98 match inner {
99 ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100 let sig = &mut method.sig;
101 let block = &mut method.block;
102 let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103 transform_block(context, sig, block);
104 transform_sig(context, sig, has_self, false, is_local);
105 method.attrs.push(lint_suppress_with_body());
106 }
107 ImplItem::Verbatim(tokens) => {
108 let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109 Ok(method) if method.sig.asyncness.is_some() => method,
110 _ => continue,
111 };
112 let sig = &mut method.sig;
113 let has_self = has_self_in_sig(sig);
114 transform_sig(context, sig, has_self, false, is_local);
115 method.attrs.push(lint_suppress_with_body());
116 *tokens = quote!(#method);
117 }
118 _ => {}
119 }
120 }
121 }
122 }
123 }
124
lint_suppress_with_body() -> Attribute125 fn lint_suppress_with_body() -> Attribute {
126 parse_quote! {
127 #[allow(
128 clippy::async_yields_async,
129 clippy::diverging_sub_expression,
130 clippy::let_unit_value,
131 clippy::no_effect_underscore_binding,
132 clippy::shadow_same,
133 clippy::type_complexity,
134 clippy::type_repetition_in_bounds,
135 clippy::used_underscore_binding
136 )]
137 }
138 }
139
lint_suppress_without_body() -> Attribute140 fn lint_suppress_without_body() -> Attribute {
141 parse_quote! {
142 #[allow(
143 clippy::type_complexity,
144 clippy::type_repetition_in_bounds
145 )]
146 }
147 }
148
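// Input:
//     async fn f<T>(&self, x: &T) -> Ret;
//
// Output:
//     fn f<'life0, 'life1, 'async_trait, T>(
//         &'life0 self,
//         x: &'life1 T,
//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
//     where
//         T: 'async_trait,
//         'life0: 'async_trait,
//         'life1: 'async_trait,
//         Self: 'async_trait;
//
// For a trait method that has a default body, the `Self` predicate also
// names `Sync` (for `&self`), `Sync + Send` (for `self: Arc<Self>`) or `Send`
// (otherwise), unless the trait's supertraits already provide that bound.
// When `is_local` is set, all of these `Send`/`Sync` bounds are omitted.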
transform_sig( context: Context, sig: &mut Signature, has_self: bool, has_default: bool, is_local: bool, )162 fn transform_sig(
163 context: Context,
164 sig: &mut Signature,
165 has_self: bool,
166 has_default: bool,
167 is_local: bool,
168 ) {
169 let default_span = sig.asyncness.take().unwrap().span;
170 sig.fn_token.span = default_span;
171
172 let (ret_arrow, ret) = match &sig.output {
173 ReturnType::Default => (Token, quote_spanned!(default_span=> ())),
174 ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)),
175 };
176
177 let mut lifetimes = CollectLifetimes::new();
178 for arg in &mut sig.inputs {
179 match arg {
180 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
181 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
182 }
183 }
184
185 for param in &mut sig.generics.params {
186 match param {
187 GenericParam::Type(param) => {
188 let param_name = ¶m.ident;
189 let span = match param.colon_token.take() {
190 Some(colon_token) => colon_token.span,
191 None => param_name.span(),
192 };
193 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
194 where_clause_or_default(&mut sig.generics.where_clause)
195 .predicates
196 .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
197 }
198 GenericParam::Lifetime(param) => {
199 let param_name = ¶m.lifetime;
200 let span = match param.colon_token.take() {
201 Some(colon_token) => colon_token.span,
202 None => param_name.span(),
203 };
204 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
205 where_clause_or_default(&mut sig.generics.where_clause)
206 .predicates
207 .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
208 }
209 GenericParam::Const(_) => {}
210 }
211 }
212
213 for param in context.lifetimes(&lifetimes.explicit) {
214 let param = ¶m.lifetime;
215 let span = param.span();
216 where_clause_or_default(&mut sig.generics.where_clause)
217 .predicates
218 .push(parse_quote_spanned!(span=> #param: 'async_trait));
219 }
220
221 if sig.generics.lt_token.is_none() {
222 sig.generics.lt_token = Some(Token));
223 }
224 if sig.generics.gt_token.is_none() {
225 sig.generics.gt_token = Some(Token));
226 }
227
228 for elided in lifetimes.elided {
229 sig.generics.params.push(parse_quote!(#elided));
230 where_clause_or_default(&mut sig.generics.where_clause)
231 .predicates
232 .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
233 }
234
235 sig.generics
236 .params
237 .push(parse_quote_spanned!(default_span=> 'async_trait));
238
239 if has_self {
240 let bounds: &[InferredBound] = if let Some(receiver) = sig.receiver() {
241 match receiver.ty.as_ref() {
242 // self: &Self
243 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
244 // self: Arc<Self>
245 Type::Path(ty)
246 if {
247 let segment = ty.path.segments.last().unwrap();
248 segment.ident == "Arc"
249 && match &segment.arguments {
250 PathArguments::AngleBracketed(arguments) => {
251 arguments.args.len() == 1
252 && match &arguments.args[0] {
253 GenericArgument::Type(Type::Path(arg)) => {
254 arg.path.is_ident("Self")
255 }
256 _ => false,
257 }
258 }
259 _ => false,
260 }
261 } =>
262 {
263 &[InferredBound::Sync, InferredBound::Send]
264 }
265 _ => &[InferredBound::Send],
266 }
267 } else {
268 &[InferredBound::Send]
269 };
270
271 let bounds = bounds.iter().filter_map(|bound| {
272 let assume_bound = match context {
273 Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, bound),
274 Context::Impl { .. } => true,
275 };
276 if assume_bound || is_local {
277 None
278 } else {
279 Some(bound.spanned_path(default_span))
280 }
281 });
282
283 where_clause_or_default(&mut sig.generics.where_clause)
284 .predicates
285 .push(parse_quote_spanned! {default_span=>
286 Self: #(#bounds +)* 'async_trait
287 });
288 }
289
290 for (i, arg) in sig.inputs.iter_mut().enumerate() {
291 match arg {
292 FnArg::Receiver(receiver) => {
293 if receiver.reference.is_none() {
294 receiver.mutability = None;
295 }
296 }
297 FnArg::Typed(arg) => {
298 if match *arg.ty {
299 Type::Reference(_) => false,
300 _ => true,
301 } {
302 if let Pat::Ident(pat) = &mut *arg.pat {
303 pat.by_ref = None;
304 pat.mutability = None;
305 } else {
306 let positional = positional_arg(i, &arg.pat);
307 let m = mut_pat(&mut arg.pat);
308 arg.pat = parse_quote!(#m #positional);
309 }
310 }
311 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
312 }
313 }
314 }
315
316 let bounds = if is_local {
317 quote_spanned!(default_span=> 'async_trait)
318 } else {
319 quote_spanned!(default_span=> ::core::marker::Send + 'async_trait)
320 };
321 sig.output = parse_quote_spanned! {default_span=>
322 #ret_arrow ::core::pin::Pin<Box<
323 dyn ::core::future::Future<Output = #ret> + #bounds
324 >>
325 };
326 }
327
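// Input:
//     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
//         self + x + a + b
//     }
//
// Output (body only; `transform_sig` separately renames the tuple argument
// in the signature to `__arg2`, the index counting `self`):
//     Box::pin(async move {
//         if let Some(__ret) = None::<Ret> {
//             return __ret;
//         }
//         let __self = self;
//         let (a, b) = {
//             let __arg2 = __arg2;
//             __arg2
//         };
//         let __ret: Ret = { __self + x + a + b };
//         #[allow(unreachable_code)]
//         __ret
//     })
//
// `x` has a reference type, so it gets no rebinding; the `async move` block
// captures it directly.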
transform_block(context: Context, sig: &mut Signature, block: &mut Block)345 fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
346 let mut self_span = None;
347 let decls = sig
348 .inputs
349 .iter()
350 .enumerate()
351 .map(|(i, arg)| match arg {
352 FnArg::Receiver(Receiver {
353 self_token,
354 mutability,
355 ..
356 }) => {
357 let ident = Ident::new("__self", self_token.span);
358 self_span = Some(self_token.span);
359 quote!(let #mutability #ident = #self_token;)
360 }
361 FnArg::Typed(arg) => {
362 // If there is a #[cfg(...)] attribute that selectively enables
363 // the parameter, forward it to the variable.
364 //
365 // This is currently not applied to the `self` parameter.
366 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
367
368 if let Type::Reference(_) = *arg.ty {
369 quote!()
370 } else if let Pat::Ident(PatIdent {
371 ident, mutability, ..
372 }) = &*arg.pat
373 {
374 quote! {
375 #(#attrs)*
376 let #mutability #ident = #ident;
377 }
378 } else {
379 let pat = &arg.pat;
380 let ident = positional_arg(i, pat);
381 if let Pat::Wild(_) = **pat {
382 quote! {
383 #(#attrs)*
384 let #ident = #ident;
385 }
386 } else {
387 quote! {
388 #(#attrs)*
389 let #pat = {
390 let #ident = #ident;
391 #ident
392 };
393 }
394 }
395 }
396 }
397 })
398 .collect::<Vec<_>>();
399
400 if let Some(span) = self_span {
401 let mut replace_self = ReplaceSelf(span);
402 replace_self.visit_block_mut(block);
403 }
404
405 let stmts = &block.stmts;
406 let let_ret = match &mut sig.output {
407 ReturnType::Default => quote_spanned! {block.brace_token.span=>
408 #(#decls)*
409 let () = { #(#stmts)* };
410 },
411 ReturnType::Type(_, ret) => {
412 if contains_associated_type_impl_trait(context, ret) {
413 if decls.is_empty() {
414 quote!(#(#stmts)*)
415 } else {
416 quote!(#(#decls)* { #(#stmts)* })
417 }
418 } else {
419 quote_spanned! {block.brace_token.span=>
420 if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
421 return __ret;
422 }
423 #(#decls)*
424 let __ret: #ret = { #(#stmts)* };
425 #[allow(unreachable_code)]
426 __ret
427 }
428 }
429 }
430 };
431 let box_pin = quote_spanned!(block.brace_token.span=>
432 Box::pin(async move { #let_ret })
433 );
434 block.stmts = parse_quote!(#box_pin);
435 }
436
positional_arg(i: usize, pat: &Pat) -> Ident437 fn positional_arg(i: usize, pat: &Pat) -> Ident {
438 let span: Span = syn::spanned::Spanned::span(pat);
439 #[cfg(not(no_span_mixed_site))]
440 let span = span.resolved_at(Span::mixed_site());
441 format_ident!("__arg{}", i, span = span)
442 }
443
contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool444 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
445 struct AssociatedTypeImplTraits<'a> {
446 set: &'a Set<Ident>,
447 contains: bool,
448 }
449
450 impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
451 fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
452 if ty.qself.is_none()
453 && ty.path.segments.len() == 2
454 && ty.path.segments[0].ident == "Self"
455 && self.set.contains(&ty.path.segments[1].ident)
456 {
457 self.contains = true;
458 }
459 visit_mut::visit_type_path_mut(self, ty);
460 }
461 }
462
463 match context {
464 Context::Trait { .. } => false,
465 Context::Impl {
466 associated_type_impl_traits,
467 ..
468 } => {
469 let mut visit = AssociatedTypeImplTraits {
470 set: associated_type_impl_traits,
471 contains: false,
472 };
473 visit.visit_type_mut(ret);
474 visit.contains
475 }
476 }
477 }
478
where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause479 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
480 clause.get_or_insert_with(|| WhereClause {
481 where_token: Default::default(),
482 predicates: Punctuated::new(),
483 })
484 }
485