• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::bound::{has_bound, InferredBound, Supertraits};
2 use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3 use crate::parse::Item;
4 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5 use crate::verbatim::VerbatimFn;
6 use proc_macro2::{Span, TokenStream};
7 use quote::{format_ident, quote, quote_spanned, ToTokens};
8 use std::collections::BTreeSet as Set;
9 use std::mem;
10 use syn::punctuated::Punctuated;
11 use syn::visit_mut::{self, VisitMut};
12 use syn::{
13     parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14     Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15     ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
16 };
17 
18 impl ToTokens for Item {
to_tokens(&self, tokens: &mut TokenStream)19     fn to_tokens(&self, tokens: &mut TokenStream) {
20         match self {
21             Item::Trait(item) => item.to_tokens(tokens),
22             Item::Impl(item) => item.to_tokens(tokens),
23         }
24     }
25 }
26 
27 #[derive(Clone, Copy)]
28 enum Context<'a> {
29     Trait {
30         generics: &'a Generics,
31         supertraits: &'a Supertraits,
32     },
33     Impl {
34         impl_generics: &'a Generics,
35         associated_type_impl_traits: &'a Set<Ident>,
36     },
37 }
38 
39 impl Context<'_> {
lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam>40     fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41         let generics = match self {
42             Context::Trait { generics, .. } => generics,
43             Context::Impl { impl_generics, .. } => impl_generics,
44         };
45         generics.params.iter().filter_map(move |param| {
46             if let GenericParam::Lifetime(param) = param {
47                 if used.contains(&param.lifetime) {
48                     return Some(param);
49                 }
50             }
51             None
52         })
53     }
54 }
55 
expand(input: &mut Item, is_local: bool)56 pub fn expand(input: &mut Item, is_local: bool) {
57     match input {
58         Item::Trait(input) => {
59             let context = Context::Trait {
60                 generics: &input.generics,
61                 supertraits: &input.supertraits,
62             };
63             for inner in &mut input.items {
64                 if let TraitItem::Fn(method) = inner {
65                     let sig = &mut method.sig;
66                     if sig.asyncness.is_some() {
67                         let block = &mut method.default;
68                         let mut has_self = has_self_in_sig(sig);
69                         method.attrs.push(parse_quote!(#[must_use]));
70                         if let Some(block) = block {
71                             has_self |= has_self_in_block(block);
72                             transform_block(context, sig, block);
73                             method.attrs.push(lint_suppress_with_body());
74                         } else {
75                             method.attrs.push(lint_suppress_without_body());
76                         }
77                         let has_default = method.default.is_some();
78                         transform_sig(context, sig, has_self, has_default, is_local);
79                     }
80                 }
81             }
82         }
83         Item::Impl(input) => {
84             let mut associated_type_impl_traits = Set::new();
85             for inner in &input.items {
86                 if let ImplItem::Type(assoc) = inner {
87                     if let Type::ImplTrait(_) = assoc.ty {
88                         associated_type_impl_traits.insert(assoc.ident.clone());
89                     }
90                 }
91             }
92 
93             let context = Context::Impl {
94                 impl_generics: &input.generics,
95                 associated_type_impl_traits: &associated_type_impl_traits,
96             };
97             for inner in &mut input.items {
98                 match inner {
99                     ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100                         let sig = &mut method.sig;
101                         let block = &mut method.block;
102                         let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103                         transform_block(context, sig, block);
104                         transform_sig(context, sig, has_self, false, is_local);
105                         method.attrs.push(lint_suppress_with_body());
106                     }
107                     ImplItem::Verbatim(tokens) => {
108                         let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109                             Ok(method) if method.sig.asyncness.is_some() => method,
110                             _ => continue,
111                         };
112                         let sig = &mut method.sig;
113                         let has_self = has_self_in_sig(sig);
114                         transform_sig(context, sig, has_self, false, is_local);
115                         method.attrs.push(lint_suppress_with_body());
116                         *tokens = quote!(#method);
117                     }
118                     _ => {}
119                 }
120             }
121         }
122     }
123 }
124 
lint_suppress_with_body() -> Attribute125 fn lint_suppress_with_body() -> Attribute {
126     parse_quote! {
127         #[allow(
128             clippy::async_yields_async,
129             clippy::diverging_sub_expression,
130             clippy::let_unit_value,
131             clippy::no_effect_underscore_binding,
132             clippy::shadow_same,
133             clippy::type_complexity,
134             clippy::type_repetition_in_bounds,
135             clippy::used_underscore_binding
136         )]
137     }
138 }
139 
lint_suppress_without_body() -> Attribute140 fn lint_suppress_without_body() -> Attribute {
141     parse_quote! {
142         #[allow(
143             clippy::type_complexity,
144             clippy::type_repetition_in_bounds
145         )]
146     }
147 }
148 
149 // Input:
150 //     async fn f<T>(&self, x: &T) -> Ret;
151 //
152 // Output:
153 //     fn f<'life0, 'life1, 'async_trait, T>(
154 //         &'life0 self,
155 //         x: &'life1 T,
156 //     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
157 //     where
158 //         'life0: 'async_trait,
159 //         'life1: 'async_trait,
160 //         T: 'async_trait,
161 //         Self: Sync + 'async_trait;
transform_sig( context: Context, sig: &mut Signature, has_self: bool, has_default: bool, is_local: bool, )162 fn transform_sig(
163     context: Context,
164     sig: &mut Signature,
165     has_self: bool,
166     has_default: bool,
167     is_local: bool,
168 ) {
169     let default_span = sig.asyncness.take().unwrap().span;
170     sig.fn_token.span = default_span;
171 
172     let (ret_arrow, ret) = match &sig.output {
173         ReturnType::Default => (Token![->](default_span), quote_spanned!(default_span=> ())),
174         ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)),
175     };
176 
177     let mut lifetimes = CollectLifetimes::new();
178     for arg in &mut sig.inputs {
179         match arg {
180             FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
181             FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
182         }
183     }
184 
185     for param in &mut sig.generics.params {
186         match param {
187             GenericParam::Type(param) => {
188                 let param_name = &param.ident;
189                 let span = match param.colon_token.take() {
190                     Some(colon_token) => colon_token.span,
191                     None => param_name.span(),
192                 };
193                 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
194                 where_clause_or_default(&mut sig.generics.where_clause)
195                     .predicates
196                     .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
197             }
198             GenericParam::Lifetime(param) => {
199                 let param_name = &param.lifetime;
200                 let span = match param.colon_token.take() {
201                     Some(colon_token) => colon_token.span,
202                     None => param_name.span(),
203                 };
204                 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
205                 where_clause_or_default(&mut sig.generics.where_clause)
206                     .predicates
207                     .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
208             }
209             GenericParam::Const(_) => {}
210         }
211     }
212 
213     for param in context.lifetimes(&lifetimes.explicit) {
214         let param = &param.lifetime;
215         let span = param.span();
216         where_clause_or_default(&mut sig.generics.where_clause)
217             .predicates
218             .push(parse_quote_spanned!(span=> #param: 'async_trait));
219     }
220 
221     if sig.generics.lt_token.is_none() {
222         sig.generics.lt_token = Some(Token![<](sig.ident.span()));
223     }
224     if sig.generics.gt_token.is_none() {
225         sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
226     }
227 
228     for elided in lifetimes.elided {
229         sig.generics.params.push(parse_quote!(#elided));
230         where_clause_or_default(&mut sig.generics.where_clause)
231             .predicates
232             .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
233     }
234 
235     sig.generics
236         .params
237         .push(parse_quote_spanned!(default_span=> 'async_trait));
238 
239     if has_self {
240         let bounds: &[InferredBound] = if let Some(receiver) = sig.receiver() {
241             match receiver.ty.as_ref() {
242                 // self: &Self
243                 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
244                 // self: Arc<Self>
245                 Type::Path(ty)
246                     if {
247                         let segment = ty.path.segments.last().unwrap();
248                         segment.ident == "Arc"
249                             && match &segment.arguments {
250                                 PathArguments::AngleBracketed(arguments) => {
251                                     arguments.args.len() == 1
252                                         && match &arguments.args[0] {
253                                             GenericArgument::Type(Type::Path(arg)) => {
254                                                 arg.path.is_ident("Self")
255                                             }
256                                             _ => false,
257                                         }
258                                 }
259                                 _ => false,
260                             }
261                     } =>
262                 {
263                     &[InferredBound::Sync, InferredBound::Send]
264                 }
265                 _ => &[InferredBound::Send],
266             }
267         } else {
268             &[InferredBound::Send]
269         };
270 
271         let bounds = bounds.iter().filter_map(|bound| {
272             let assume_bound = match context {
273                 Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, bound),
274                 Context::Impl { .. } => true,
275             };
276             if assume_bound || is_local {
277                 None
278             } else {
279                 Some(bound.spanned_path(default_span))
280             }
281         });
282 
283         where_clause_or_default(&mut sig.generics.where_clause)
284             .predicates
285             .push(parse_quote_spanned! {default_span=>
286                 Self: #(#bounds +)* 'async_trait
287             });
288     }
289 
290     for (i, arg) in sig.inputs.iter_mut().enumerate() {
291         match arg {
292             FnArg::Receiver(receiver) => {
293                 if receiver.reference.is_none() {
294                     receiver.mutability = None;
295                 }
296             }
297             FnArg::Typed(arg) => {
298                 if match *arg.ty {
299                     Type::Reference(_) => false,
300                     _ => true,
301                 } {
302                     if let Pat::Ident(pat) = &mut *arg.pat {
303                         pat.by_ref = None;
304                         pat.mutability = None;
305                     } else {
306                         let positional = positional_arg(i, &arg.pat);
307                         let m = mut_pat(&mut arg.pat);
308                         arg.pat = parse_quote!(#m #positional);
309                     }
310                 }
311                 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
312             }
313         }
314     }
315 
316     let bounds = if is_local {
317         quote_spanned!(default_span=> 'async_trait)
318     } else {
319         quote_spanned!(default_span=> ::core::marker::Send + 'async_trait)
320     };
321     sig.output = parse_quote_spanned! {default_span=>
322         #ret_arrow ::core::pin::Pin<Box<
323             dyn ::core::future::Future<Output = #ret> + #bounds
324         >>
325     };
326 }
327 
328 // Input:
329 //     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
330 //         self + x + a + b
331 //     }
332 //
333 // Output:
334 //     Box::pin(async move {
335 //         let ___ret: Ret = {
336 //             let __self = self;
337 //             let x = x;
338 //             let (a, b) = __arg1;
339 //
340 //             __self + x + a + b
341 //         };
342 //
343 //         ___ret
344 //     })
transform_block(context: Context, sig: &mut Signature, block: &mut Block)345 fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
346     let mut self_span = None;
347     let decls = sig
348         .inputs
349         .iter()
350         .enumerate()
351         .map(|(i, arg)| match arg {
352             FnArg::Receiver(Receiver {
353                 self_token,
354                 mutability,
355                 ..
356             }) => {
357                 let ident = Ident::new("__self", self_token.span);
358                 self_span = Some(self_token.span);
359                 quote!(let #mutability #ident = #self_token;)
360             }
361             FnArg::Typed(arg) => {
362                 // If there is a #[cfg(...)] attribute that selectively enables
363                 // the parameter, forward it to the variable.
364                 //
365                 // This is currently not applied to the `self` parameter.
366                 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
367 
368                 if let Type::Reference(_) = *arg.ty {
369                     quote!()
370                 } else if let Pat::Ident(PatIdent {
371                     ident, mutability, ..
372                 }) = &*arg.pat
373                 {
374                     quote! {
375                         #(#attrs)*
376                         let #mutability #ident = #ident;
377                     }
378                 } else {
379                     let pat = &arg.pat;
380                     let ident = positional_arg(i, pat);
381                     if let Pat::Wild(_) = **pat {
382                         quote! {
383                             #(#attrs)*
384                             let #ident = #ident;
385                         }
386                     } else {
387                         quote! {
388                             #(#attrs)*
389                             let #pat = {
390                                 let #ident = #ident;
391                                 #ident
392                             };
393                         }
394                     }
395                 }
396             }
397         })
398         .collect::<Vec<_>>();
399 
400     if let Some(span) = self_span {
401         let mut replace_self = ReplaceSelf(span);
402         replace_self.visit_block_mut(block);
403     }
404 
405     let stmts = &block.stmts;
406     let let_ret = match &mut sig.output {
407         ReturnType::Default => quote_spanned! {block.brace_token.span=>
408             #(#decls)*
409             let () = { #(#stmts)* };
410         },
411         ReturnType::Type(_, ret) => {
412             if contains_associated_type_impl_trait(context, ret) {
413                 if decls.is_empty() {
414                     quote!(#(#stmts)*)
415                 } else {
416                     quote!(#(#decls)* { #(#stmts)* })
417                 }
418             } else {
419                 quote_spanned! {block.brace_token.span=>
420                     if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
421                         return __ret;
422                     }
423                     #(#decls)*
424                     let __ret: #ret = { #(#stmts)* };
425                     #[allow(unreachable_code)]
426                     __ret
427                 }
428             }
429         }
430     };
431     let box_pin = quote_spanned!(block.brace_token.span=>
432         Box::pin(async move { #let_ret })
433     );
434     block.stmts = parse_quote!(#box_pin);
435 }
436 
positional_arg(i: usize, pat: &Pat) -> Ident437 fn positional_arg(i: usize, pat: &Pat) -> Ident {
438     let span: Span = syn::spanned::Spanned::span(pat);
439     #[cfg(not(no_span_mixed_site))]
440     let span = span.resolved_at(Span::mixed_site());
441     format_ident!("__arg{}", i, span = span)
442 }
443 
contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool444 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
445     struct AssociatedTypeImplTraits<'a> {
446         set: &'a Set<Ident>,
447         contains: bool,
448     }
449 
450     impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
451         fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
452             if ty.qself.is_none()
453                 && ty.path.segments.len() == 2
454                 && ty.path.segments[0].ident == "Self"
455                 && self.set.contains(&ty.path.segments[1].ident)
456             {
457                 self.contains = true;
458             }
459             visit_mut::visit_type_path_mut(self, ty);
460         }
461     }
462 
463     match context {
464         Context::Trait { .. } => false,
465         Context::Impl {
466             associated_type_impl_traits,
467             ..
468         } => {
469             let mut visit = AssociatedTypeImplTraits {
470                 set: associated_type_impl_traits,
471                 contains: false,
472             };
473             visit.visit_type_mut(ret);
474             visit.contains
475         }
476     }
477 }
478 
where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause479 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
480     clause.get_or_insert_with(|| WhereClause {
481         where_token: Default::default(),
482         predicates: Punctuated::new(),
483     })
484 }
485