• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::bound::{has_bound, InferredBound, Supertraits};
2 use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3 use crate::parse::Item;
4 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5 use crate::verbatim::VerbatimFn;
6 use proc_macro2::{Span, TokenStream};
7 use quote::{format_ident, quote, quote_spanned, ToTokens};
8 use std::collections::BTreeSet as Set;
9 use std::mem;
10 use syn::punctuated::Punctuated;
11 use syn::visit_mut::{self, VisitMut};
12 use syn::{
13     parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14     Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15     ReturnType, Signature, Token, TraitItem, Type, TypeInfer, TypePath, WhereClause,
16 };
17 
18 impl ToTokens for Item {
to_tokens(&self, tokens: &mut TokenStream)19     fn to_tokens(&self, tokens: &mut TokenStream) {
20         match self {
21             Item::Trait(item) => item.to_tokens(tokens),
22             Item::Impl(item) => item.to_tokens(tokens),
23         }
24     }
25 }
26 
27 #[derive(Clone, Copy)]
28 enum Context<'a> {
29     Trait {
30         generics: &'a Generics,
31         supertraits: &'a Supertraits,
32     },
33     Impl {
34         impl_generics: &'a Generics,
35         associated_type_impl_traits: &'a Set<Ident>,
36     },
37 }
38 
39 impl Context<'_> {
lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam>40     fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41         let generics = match self {
42             Context::Trait { generics, .. } => generics,
43             Context::Impl { impl_generics, .. } => impl_generics,
44         };
45         generics.params.iter().filter_map(move |param| {
46             if let GenericParam::Lifetime(param) = param {
47                 if used.contains(&param.lifetime) {
48                     return Some(param);
49                 }
50             }
51             None
52         })
53     }
54 }
55 
expand(input: &mut Item, is_local: bool)56 pub fn expand(input: &mut Item, is_local: bool) {
57     match input {
58         Item::Trait(input) => {
59             let context = Context::Trait {
60                 generics: &input.generics,
61                 supertraits: &input.supertraits,
62             };
63             for inner in &mut input.items {
64                 if let TraitItem::Fn(method) = inner {
65                     let sig = &mut method.sig;
66                     if sig.asyncness.is_some() {
67                         let block = &mut method.default;
68                         let mut has_self = has_self_in_sig(sig);
69                         method.attrs.push(parse_quote!(#[must_use]));
70                         if let Some(block) = block {
71                             has_self |= has_self_in_block(block);
72                             transform_block(context, sig, block);
73                             method.attrs.push(lint_suppress_with_body());
74                         } else {
75                             method.attrs.push(lint_suppress_without_body());
76                         }
77                         let has_default = method.default.is_some();
78                         transform_sig(context, sig, has_self, has_default, is_local);
79                     }
80                 }
81             }
82         }
83         Item::Impl(input) => {
84             let mut associated_type_impl_traits = Set::new();
85             for inner in &input.items {
86                 if let ImplItem::Type(assoc) = inner {
87                     if let Type::ImplTrait(_) = assoc.ty {
88                         associated_type_impl_traits.insert(assoc.ident.clone());
89                     }
90                 }
91             }
92 
93             let context = Context::Impl {
94                 impl_generics: &input.generics,
95                 associated_type_impl_traits: &associated_type_impl_traits,
96             };
97             for inner in &mut input.items {
98                 match inner {
99                     ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100                         let sig = &mut method.sig;
101                         let block = &mut method.block;
102                         let has_self = has_self_in_sig(sig);
103                         transform_block(context, sig, block);
104                         transform_sig(context, sig, has_self, false, is_local);
105                         method.attrs.push(lint_suppress_with_body());
106                     }
107                     ImplItem::Verbatim(tokens) => {
108                         let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109                             Ok(method) if method.sig.asyncness.is_some() => method,
110                             _ => continue,
111                         };
112                         let sig = &mut method.sig;
113                         let has_self = has_self_in_sig(sig);
114                         transform_sig(context, sig, has_self, false, is_local);
115                         method.attrs.push(lint_suppress_with_body());
116                         *tokens = quote!(#method);
117                     }
118                     _ => {}
119                 }
120             }
121         }
122     }
123 }
124 
lint_suppress_with_body() -> Attribute125 fn lint_suppress_with_body() -> Attribute {
126     parse_quote! {
127         #[allow(
128             elided_named_lifetimes,
129             clippy::async_yields_async,
130             clippy::diverging_sub_expression,
131             clippy::let_unit_value,
132             clippy::needless_arbitrary_self_type,
133             clippy::no_effect_underscore_binding,
134             clippy::shadow_same,
135             clippy::type_complexity,
136             clippy::type_repetition_in_bounds,
137             clippy::used_underscore_binding
138         )]
139     }
140 }
141 
lint_suppress_without_body() -> Attribute142 fn lint_suppress_without_body() -> Attribute {
143     parse_quote! {
144         #[allow(
145             elided_named_lifetimes,
146             clippy::type_complexity,
147             clippy::type_repetition_in_bounds
148         )]
149     }
150 }
151 
152 // Input:
153 //     async fn f<T>(&self, x: &T) -> Ret;
154 //
155 // Output:
156 //     fn f<'life0, 'life1, 'async_trait, T>(
157 //         &'life0 self,
158 //         x: &'life1 T,
159 //     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
160 //     where
161 //         'life0: 'async_trait,
162 //         'life1: 'async_trait,
163 //         T: 'async_trait,
164 //         Self: Sync + 'async_trait;
transform_sig( context: Context, sig: &mut Signature, has_self: bool, has_default: bool, is_local: bool, )165 fn transform_sig(
166     context: Context,
167     sig: &mut Signature,
168     has_self: bool,
169     has_default: bool,
170     is_local: bool,
171 ) {
172     sig.fn_token.span = sig.asyncness.take().unwrap().span;
173 
174     let (ret_arrow, ret) = match &sig.output {
175         ReturnType::Default => (quote!(->), quote!(())),
176         ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)),
177     };
178 
179     let mut lifetimes = CollectLifetimes::new();
180     for arg in &mut sig.inputs {
181         match arg {
182             FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183             FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184         }
185     }
186 
187     for param in &mut sig.generics.params {
188         match param {
189             GenericParam::Type(param) => {
190                 let param_name = &param.ident;
191                 let span = match param.colon_token.take() {
192                     Some(colon_token) => colon_token.span,
193                     None => param_name.span(),
194                 };
195                 if param.attrs.is_empty() {
196                     let bounds = mem::take(&mut param.bounds);
197                     where_clause_or_default(&mut sig.generics.where_clause)
198                         .predicates
199                         .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
200                 } else {
201                     param.bounds.push(parse_quote!('async_trait));
202                 }
203             }
204             GenericParam::Lifetime(param) => {
205                 let param_name = &param.lifetime;
206                 let span = match param.colon_token.take() {
207                     Some(colon_token) => colon_token.span,
208                     None => param_name.span(),
209                 };
210                 if param.attrs.is_empty() {
211                     let bounds = mem::take(&mut param.bounds);
212                     where_clause_or_default(&mut sig.generics.where_clause)
213                         .predicates
214                         .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
215                 } else {
216                     param.bounds.push(parse_quote!('async_trait));
217                 }
218             }
219             GenericParam::Const(_) => {}
220         }
221     }
222 
223     for param in context.lifetimes(&lifetimes.explicit) {
224         let param = &param.lifetime;
225         let span = param.span();
226         where_clause_or_default(&mut sig.generics.where_clause)
227             .predicates
228             .push(parse_quote_spanned!(span=> #param: 'async_trait));
229     }
230 
231     if sig.generics.lt_token.is_none() {
232         sig.generics.lt_token = Some(Token![<](sig.ident.span()));
233     }
234     if sig.generics.gt_token.is_none() {
235         sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
236     }
237 
238     for elided in lifetimes.elided {
239         sig.generics.params.push(parse_quote!(#elided));
240         where_clause_or_default(&mut sig.generics.where_clause)
241             .predicates
242             .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
243     }
244 
245     sig.generics.params.push(parse_quote!('async_trait));
246 
247     if has_self {
248         let bounds: &[InferredBound] = if is_local {
249             &[]
250         } else if let Some(receiver) = sig.receiver() {
251             match receiver.ty.as_ref() {
252                 // self: &Self
253                 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
254                 // self: Arc<Self>
255                 Type::Path(ty)
256                     if {
257                         let segment = ty.path.segments.last().unwrap();
258                         segment.ident == "Arc"
259                             && match &segment.arguments {
260                                 PathArguments::AngleBracketed(arguments) => {
261                                     arguments.args.len() == 1
262                                         && match &arguments.args[0] {
263                                             GenericArgument::Type(Type::Path(arg)) => {
264                                                 arg.path.is_ident("Self")
265                                             }
266                                             _ => false,
267                                         }
268                                 }
269                                 _ => false,
270                             }
271                     } =>
272                 {
273                     &[InferredBound::Sync, InferredBound::Send]
274                 }
275                 _ => &[InferredBound::Send],
276             }
277         } else {
278             &[InferredBound::Send]
279         };
280 
281         let bounds = bounds.iter().filter(|bound| match context {
282             Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
283             Context::Impl { .. } => false,
284         });
285 
286         where_clause_or_default(&mut sig.generics.where_clause)
287             .predicates
288             .push(parse_quote! {
289                 Self: #(#bounds +)* 'async_trait
290             });
291     }
292 
293     for (i, arg) in sig.inputs.iter_mut().enumerate() {
294         match arg {
295             FnArg::Receiver(receiver) => {
296                 if receiver.reference.is_none() {
297                     receiver.mutability = None;
298                 }
299             }
300             FnArg::Typed(arg) => {
301                 if match *arg.ty {
302                     Type::Reference(_) => false,
303                     _ => true,
304                 } {
305                     if let Pat::Ident(pat) = &mut *arg.pat {
306                         pat.by_ref = None;
307                         pat.mutability = None;
308                     } else {
309                         let positional = positional_arg(i, &arg.pat);
310                         let m = mut_pat(&mut arg.pat);
311                         arg.pat = parse_quote!(#m #positional);
312                     }
313                 }
314                 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
315             }
316         }
317     }
318 
319     let bounds = if is_local {
320         quote!('async_trait)
321     } else {
322         quote!(::core::marker::Send + 'async_trait)
323     };
324     sig.output = parse_quote! {
325         #ret_arrow ::core::pin::Pin<Box<
326             dyn ::core::future::Future<Output = #ret> + #bounds
327         >>
328     };
329 }
330 
331 // Input:
332 //     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
333 //         self + x + a + b
334 //     }
335 //
336 // Output:
337 //     Box::pin(async move {
338 //         let ___ret: Ret = {
339 //             let __self = self;
340 //             let x = x;
341 //             let (a, b) = __arg1;
342 //
343 //             __self + x + a + b
344 //         };
345 //
346 //         ___ret
347 //     })
transform_block(context: Context, sig: &mut Signature, block: &mut Block)348 fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
349     let mut replace_self = false;
350     let decls = sig
351         .inputs
352         .iter()
353         .enumerate()
354         .map(|(i, arg)| match arg {
355             FnArg::Receiver(Receiver {
356                 self_token,
357                 mutability,
358                 ..
359             }) => {
360                 replace_self = true;
361                 let ident = Ident::new("__self", self_token.span);
362                 quote!(let #mutability #ident = #self_token;)
363             }
364             FnArg::Typed(arg) => {
365                 // If there is a #[cfg(...)] attribute that selectively enables
366                 // the parameter, forward it to the variable.
367                 //
368                 // This is currently not applied to the `self` parameter.
369                 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
370 
371                 if let Type::Reference(_) = *arg.ty {
372                     quote!()
373                 } else if let Pat::Ident(PatIdent {
374                     ident, mutability, ..
375                 }) = &*arg.pat
376                 {
377                     quote! {
378                         #(#attrs)*
379                         let #mutability #ident = #ident;
380                     }
381                 } else {
382                     let pat = &arg.pat;
383                     let ident = positional_arg(i, pat);
384                     if let Pat::Wild(_) = **pat {
385                         quote! {
386                             #(#attrs)*
387                             let #ident = #ident;
388                         }
389                     } else {
390                         quote! {
391                             #(#attrs)*
392                             let #pat = {
393                                 let #ident = #ident;
394                                 #ident
395                             };
396                         }
397                     }
398                 }
399             }
400         })
401         .collect::<Vec<_>>();
402 
403     if replace_self {
404         ReplaceSelf.visit_block_mut(block);
405     }
406 
407     let stmts = &block.stmts;
408     let let_ret = match &mut sig.output {
409         ReturnType::Default => quote_spanned! {block.brace_token.span=>
410             #(#decls)*
411             let () = { #(#stmts)* };
412         },
413         ReturnType::Type(_, ret) => {
414             if contains_associated_type_impl_trait(context, ret) {
415                 if decls.is_empty() {
416                     quote!(#(#stmts)*)
417                 } else {
418                     quote!(#(#decls)* { #(#stmts)* })
419                 }
420             } else {
421                 let mut ret = ret.clone();
422                 replace_impl_trait_with_infer(&mut ret);
423                 quote! {
424                     if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
425                         #[allow(unreachable_code)]
426                         return __ret;
427                     }
428                     #(#decls)*
429                     let __ret: #ret = { #(#stmts)* };
430                     #[allow(unreachable_code)]
431                     __ret
432                 }
433             }
434         }
435     };
436     let box_pin = quote_spanned!(block.brace_token.span=>
437         Box::pin(async move { #let_ret })
438     );
439     block.stmts = parse_quote!(#box_pin);
440 }
441 
positional_arg(i: usize, pat: &Pat) -> Ident442 fn positional_arg(i: usize, pat: &Pat) -> Ident {
443     let span = syn::spanned::Spanned::span(pat).resolved_at(Span::mixed_site());
444     format_ident!("__arg{}", i, span = span)
445 }
446 
contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool447 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
448     struct AssociatedTypeImplTraits<'a> {
449         set: &'a Set<Ident>,
450         contains: bool,
451     }
452 
453     impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
454         fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
455             if ty.qself.is_none()
456                 && ty.path.segments.len() == 2
457                 && ty.path.segments[0].ident == "Self"
458                 && self.set.contains(&ty.path.segments[1].ident)
459             {
460                 self.contains = true;
461             }
462             visit_mut::visit_type_path_mut(self, ty);
463         }
464     }
465 
466     match context {
467         Context::Trait { .. } => false,
468         Context::Impl {
469             associated_type_impl_traits,
470             ..
471         } => {
472             let mut visit = AssociatedTypeImplTraits {
473                 set: associated_type_impl_traits,
474                 contains: false,
475             };
476             visit.visit_type_mut(ret);
477             visit.contains
478         }
479     }
480 }
481 
where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause482 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
483     clause.get_or_insert_with(|| WhereClause {
484         where_token: Default::default(),
485         predicates: Punctuated::new(),
486     })
487 }
488 
replace_impl_trait_with_infer(ty: &mut Type)489 fn replace_impl_trait_with_infer(ty: &mut Type) {
490     struct ReplaceImplTraitWithInfer;
491 
492     impl VisitMut for ReplaceImplTraitWithInfer {
493         fn visit_type_mut(&mut self, ty: &mut Type) {
494             if let Type::ImplTrait(impl_trait) = ty {
495                 *ty = Type::Infer(TypeInfer {
496                     underscore_token: Token![_](impl_trait.impl_token.span),
497                 });
498             }
499             visit_mut::visit_type_mut(self, ty);
500         }
501     }
502 
503     ReplaceImplTraitWithInfer.visit_type_mut(ty);
504 }
505