• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::{iter::FromIterator, mem};
2 
3 use proc_macro2::{Group, Spacing, Span, TokenStream, TokenTree};
4 use quote::{quote, quote_spanned, ToTokens};
5 use syn::{
6     parse::{Parse, ParseBuffer, ParseStream},
7     parse_quote,
8     punctuated::Punctuated,
9     token,
10     visit_mut::{self, VisitMut},
11     Attribute, ExprPath, ExprStruct, Generics, Ident, Item, Lifetime, LifetimeDef, Macro, PatPath,
12     PatStruct, PatTupleStruct, Path, PathArguments, PredicateType, QSelf, Result, Token, Type,
13     TypeParamBound, TypePath, Variant, Visibility, WherePredicate,
14 };
15 
16 pub(crate) type Variants = Punctuated<Variant, Token![,]>;
17 
/// Builds a `syn::Error` spanned over the tokens of the given node.
///
/// The first form takes a ready-made message (anything implementing `Display`);
/// the second form forwards everything after the node to `format!`.
macro_rules! format_err {
    ($node:expr, $message:expr $(,)?) => {
        syn::Error::new_spanned(
            &$node as &dyn quote::ToTokens,
            &$message as &dyn std::fmt::Display,
        )
    };
    ($node:expr, $($fmt_args:tt)*) => {
        format_err!($node, format!($($fmt_args)*))
    };
}
26 
/// Returns early from the enclosing function with `Err(format_err!(...))`,
/// accepting exactly the same arguments as `format_err!`.
macro_rules! bail {
    ($($args:tt)*) => {
        return Err(format_err!($($args)*))
    };
}
32 
/// Like `syn::parse_quote!`, but generates the tokens with `quote_spanned!` so
/// they carry the given span. Panics with the parse error's message if the
/// generated tokens do not parse as the inferred type.
macro_rules! parse_quote_spanned {
    ($span:expr => $($body:tt)*) => {
        syn::parse2(quote::quote_spanned!($span => $($body)*))
            .unwrap_or_else(|err| panic!("{}", err))
    };
}
38 
39 /// Determines the lifetime names. Ensure it doesn't overlap with any existing
40 /// lifetime names.
determine_lifetime_name(lifetime_name: &mut String, generics: &mut Generics)41 pub(crate) fn determine_lifetime_name(lifetime_name: &mut String, generics: &mut Generics) {
42     struct CollectLifetimes(Vec<String>);
43 
44     impl VisitMut for CollectLifetimes {
45         fn visit_lifetime_def_mut(&mut self, def: &mut LifetimeDef) {
46             self.0.push(def.lifetime.to_string());
47         }
48     }
49 
50     debug_assert!(lifetime_name.starts_with('\''));
51 
52     let mut lifetimes = CollectLifetimes(Vec::new());
53     lifetimes.visit_generics_mut(generics);
54 
55     while lifetimes.0.iter().any(|name| name.starts_with(&**lifetime_name)) {
56         lifetime_name.push('_');
57     }
58 }
59 
60 /// Like `insert_lifetime`, but also generates a bound of the form
61 /// `OriginalType<A, B>: 'lifetime`. Used when generating the definition
62 /// of a projection type
insert_lifetime_and_bound( generics: &mut Generics, lifetime: Lifetime, orig_generics: &Generics, orig_ident: &Ident, ) -> WherePredicate63 pub(crate) fn insert_lifetime_and_bound(
64     generics: &mut Generics,
65     lifetime: Lifetime,
66     orig_generics: &Generics,
67     orig_ident: &Ident,
68 ) -> WherePredicate {
69     insert_lifetime(generics, lifetime.clone());
70 
71     let orig_type: Type = parse_quote!(#orig_ident #orig_generics);
72     let mut punct = Punctuated::new();
73     punct.push(TypeParamBound::Lifetime(lifetime));
74 
75     WherePredicate::Type(PredicateType {
76         lifetimes: None,
77         bounded_ty: orig_type,
78         colon_token: <Token![:]>::default(),
79         bounds: punct,
80     })
81 }
82 
83 /// Inserts a `lifetime` at position `0` of `generics.params`.
insert_lifetime(generics: &mut Generics, lifetime: Lifetime)84 pub(crate) fn insert_lifetime(generics: &mut Generics, lifetime: Lifetime) {
85     generics.lt_token.get_or_insert_with(<Token![<]>::default);
86     generics.gt_token.get_or_insert_with(<Token![>]>::default);
87     generics.params.insert(0, LifetimeDef::new(lifetime).into());
88 }
89 
90 /// Determines the visibility of the projected types and projection methods.
91 ///
92 /// If given visibility is `pub`, returned visibility is `pub(crate)`.
93 /// Otherwise, returned visibility is the same as given visibility.
determine_visibility(vis: &Visibility) -> Visibility94 pub(crate) fn determine_visibility(vis: &Visibility) -> Visibility {
95     if let Visibility::Public(token) = vis {
96         parse_quote_spanned!(token.pub_token.span => pub(crate))
97     } else {
98         vis.clone()
99     }
100 }
101 
102 /// Checks if `tokens` is an empty `TokenStream`.
103 ///
104 /// This is almost equivalent to `syn::parse2::<Nothing>()`, but produces
105 /// a better error message and does not require ownership of `tokens`.
parse_as_empty(tokens: &TokenStream) -> Result<()>106 pub(crate) fn parse_as_empty(tokens: &TokenStream) -> Result<()> {
107     if tokens.is_empty() {
108         Ok(())
109     } else {
110         bail!(tokens, "unexpected token: `{}`", tokens)
111     }
112 }
113 
respan<T>(node: &T, span: Span) -> T where T: ToTokens + Parse,114 pub(crate) fn respan<T>(node: &T, span: Span) -> T
115 where
116     T: ToTokens + Parse,
117 {
118     let tokens = node.to_token_stream();
119     let respanned = respan_tokens(tokens, span);
120     syn::parse2(respanned).unwrap()
121 }
122 
respan_tokens(tokens: TokenStream, span: Span) -> TokenStream123 fn respan_tokens(tokens: TokenStream, span: Span) -> TokenStream {
124     tokens
125         .into_iter()
126         .map(|mut token| {
127             token.set_span(span);
128             token
129         })
130         .collect()
131 }
132 
133 // =================================================================================================
134 // extension traits
135 
136 pub(crate) trait SliceExt {
position_exact(&self, ident: &str) -> Result<Option<usize>>137     fn position_exact(&self, ident: &str) -> Result<Option<usize>>;
find(&self, ident: &str) -> Option<&Attribute>138     fn find(&self, ident: &str) -> Option<&Attribute>;
139 }
140 
141 impl SliceExt for [Attribute] {
142     /// # Errors
143     ///
144     /// - There are multiple specified attributes.
145     /// - The `Attribute::tokens` field of the specified attribute is not empty.
position_exact(&self, ident: &str) -> Result<Option<usize>>146     fn position_exact(&self, ident: &str) -> Result<Option<usize>> {
147         self.iter()
148             .try_fold((0, None), |(i, mut prev), attr| {
149                 if attr.path.is_ident(ident) {
150                     if prev.replace(i).is_some() {
151                         bail!(attr, "duplicate #[{}] attribute", ident);
152                     }
153                     parse_as_empty(&attr.tokens)?;
154                 }
155                 Ok((i + 1, prev))
156             })
157             .map(|(_, pos)| pos)
158     }
159 
find(&self, ident: &str) -> Option<&Attribute>160     fn find(&self, ident: &str) -> Option<&Attribute> {
161         self.iter().position(|attr| attr.path.is_ident(ident)).map(|i| &self[i])
162     }
163 }
164 
165 pub(crate) trait ParseBufferExt<'a> {
parenthesized(self) -> Result<ParseBuffer<'a>>166     fn parenthesized(self) -> Result<ParseBuffer<'a>>;
167 }
168 
169 impl<'a> ParseBufferExt<'a> for ParseStream<'a> {
parenthesized(self) -> Result<ParseBuffer<'a>>170     fn parenthesized(self) -> Result<ParseBuffer<'a>> {
171         let content;
172         let _: token::Paren = syn::parenthesized!(content in self);
173         Ok(content)
174     }
175 }
176 
177 impl<'a> ParseBufferExt<'a> for ParseBuffer<'a> {
parenthesized(self) -> Result<ParseBuffer<'a>>178     fn parenthesized(self) -> Result<ParseBuffer<'a>> {
179         let content;
180         let _: token::Paren = syn::parenthesized!(content in self);
181         Ok(content)
182     }
183 }
184 
185 // =================================================================================================
186 // visitors
187 
// Replace `self`/`Self` with `__self`/`self_ty`.
// Based on:
// - https://github.com/dtolnay/async-trait/blob/0.1.35/src/receiver.rs
// - https://github.com/dtolnay/async-trait/commit/6029cbf375c562ca98fa5748e9d950a8ff93b0e7

/// Syntax-tree visitor that rewrites receiver references in a function body:
/// the `self` identifier becomes `__self`, and the `Self` type is replaced by
/// the wrapped `TypePath` (referred to below as the "self type").
pub(crate) struct ReplaceReceiver<'a>(pub(crate) &'a TypePath);
194 
impl ReplaceReceiver<'_> {
    /// Returns a copy of the self type whose top-level tokens are re-spanned to
    /// `span`, so generated code points at the `Self` token it replaces.
    fn self_ty(&self, span: Span) -> TypePath {
        respan(self.0, span)
    }

    /// Rewrites a path that starts with a bare `Self` segment.
    ///
    /// - `Self` on its own is delegated to `self_to_expr_path`.
    /// - `Self::Rest` becomes the qualified path `<SelfTy>::Rest`: `qself` is set
    ///   to the self type and the leading `Self` segment is removed.
    ///
    /// Paths with a leading `::`, or whose first segment is not `Self` or has
    /// generic arguments, are left untouched.
    fn self_to_qself(&self, qself: &mut Option<QSelf>, path: &mut Path) {
        if path.leading_colon.is_some() {
            return;
        }

        // NOTE(review): indexing assumes the path has at least one segment.
        let first = &path.segments[0];
        if first.ident != "Self" || !first.arguments.is_empty() {
            return;
        }

        if path.segments.len() == 1 {
            self.self_to_expr_path(path);
            return;
        }

        let span = first.ident.span();
        // `position: 0` with no `as` token: `<SelfTy>::Rest`, not `<SelfTy as Trait>::Rest`.
        *qself = Some(QSelf {
            lt_token: Token![<](span),
            ty: Box::new(self.self_ty(span).into()),
            position: 0,
            as_token: None,
            gt_token: Token![>](span),
        });

        // Reuse the `::` that followed `Self` as the separator after `<SelfTy>`.
        // Safe to unwrap: `len() > 1` was checked above, so the first pair has
        // trailing punctuation.
        path.leading_colon = Some(**path.segments.pairs().next().unwrap().punct().unwrap());

        // Drop the leading `Self` segment (and its `::`).
        let segments = mem::replace(&mut path.segments, Punctuated::new());
        path.segments = segments.into_pairs().skip(1).collect();
    }

    /// Replaces a leading bare `Self` segment with the self type's path, keeping
    /// any trailing segments (`Self::Variant` becomes `SelfTy::Variant`).
    ///
    /// Generic arguments on the self type's segments are given a turbofish
    /// (`Foo<T>` becomes `Foo::<T>`) so the result is usable where an
    /// expression/pattern path is expected. Only `self_ty.path` is used; any
    /// `qself` on the self type is not carried over.
    fn self_to_expr_path(&self, path: &mut Path) {
        if path.leading_colon.is_some() {
            return;
        }

        let first = &path.segments[0];
        if first.ident != "Self" || !first.arguments.is_empty() {
            return;
        }

        let self_ty = self.self_ty(first.ident.span());
        // `variant` keeps the original path (starting with `Self`).
        let variant = mem::replace(path, self_ty.path);
        for segment in &mut path.segments {
            if let PathArguments::AngleBracketed(bracketed) = &mut segment.arguments {
                if bracketed.colon2_token.is_none() && !bracketed.args.is_empty() {
                    bracketed.colon2_token = Some(<Token![::]>::default());
                }
            }
        }
        // Re-append everything that followed `Self` in the original path.
        if variant.segments.len() > 1 {
            path.segments.push_punct(<Token![::]>::default());
            path.segments.extend(variant.segments.into_pairs().skip(1));
        }
    }

    /// Rewrites receiver references inside a raw token stream (e.g. macro
    /// input), recursing into delimited groups:
    ///
    /// - every `self` identifier becomes `__self`;
    /// - `Self` followed by `::` becomes `<SelfTy>::`;
    /// - any other `Self` becomes `SelfTy`.
    ///
    /// `tokens` is overwritten only when something changed. Returns whether it did.
    fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool {
        let mut out = Vec::new();
        let mut modified = false;
        let mut iter = tokens.clone().into_iter().peekable();
        while let Some(tt) = iter.next() {
            match tt {
                TokenTree::Ident(mut ident) => {
                    modified |= prepend_underscore_to_self(&mut ident);
                    if ident == "Self" {
                        modified = true;
                        let self_ty = self.self_ty(ident.span());
                        match iter.peek() {
                            // A joint `:` may be the first half of `::`.
                            Some(TokenTree::Punct(p))
                                if p.as_char() == ':' && p.spacing() == Spacing::Joint =>
                            {
                                let next = iter.next().unwrap();
                                match iter.peek() {
                                    // `Self::...` -> `<SelfTy>::...`; the angle brackets
                                    // keep a generic self type valid before `::`.
                                    Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
                                        let span = ident.span();
                                        out.extend(quote_spanned!(span=> <#self_ty>));
                                    }
                                    _ => out.extend(quote!(#self_ty)),
                                }
                                // Emit the consumed `:`; a second `:` (if any) is
                                // pushed unchanged on the next loop iteration.
                                out.push(next);
                            }
                            _ => out.extend(quote!(#self_ty)),
                        }
                    } else {
                        out.push(TokenTree::Ident(ident));
                    }
                }
                TokenTree::Group(group) => {
                    let mut content = group.stream();
                    modified |= self.visit_token_stream(&mut content);
                    // Rebuild the group around the (possibly rewritten) content,
                    // keeping the original delimiter span.
                    let mut new = Group::new(group.delimiter(), content);
                    new.set_span(group.span());
                    out.push(TokenTree::Group(new));
                }
                other => out.push(other),
            }
        }
        if modified {
            *tokens = TokenStream::from_iter(out);
        }
        modified
    }
}
302 
impl VisitMut for ReplaceReceiver<'_> {
    // `Self` -> `Receiver`
    /// A type that is exactly the unqualified path `Self` is replaced wholesale
    /// by the self type. Other path types go through `visit_type_path_mut`, and
    /// non-path types use the default traversal.
    fn visit_type_mut(&mut self, ty: &mut Type) {
        if let Type::Path(node) = ty {
            if node.qself.is_none() && node.path.is_ident("Self") {
                *ty = self.self_ty(node.path.segments[0].ident.span()).into();
            } else {
                self.visit_type_path_mut(node);
            }
        } else {
            visit_mut::visit_type_mut(self, ty);
        }
    }

    // `Self::Assoc` -> `<Receiver>::Assoc`
    fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
        if ty.qself.is_none() {
            self.self_to_qself(&mut ty.qself, &mut ty.path);
        }
        visit_mut::visit_type_path_mut(self, ty);
    }

    // `Self::method` -> `<Receiver>::method`
    fn visit_expr_path_mut(&mut self, expr: &mut ExprPath) {
        if expr.qself.is_none() {
            self.self_to_qself(&mut expr.qself, &mut expr.path);
        }
        visit_mut::visit_expr_path_mut(self, expr);
    }

    // `Self { .. }` / `Self::Variant { .. }` in expressions -> self type's path.
    fn visit_expr_struct_mut(&mut self, expr: &mut ExprStruct) {
        self.self_to_expr_path(&mut expr.path);
        visit_mut::visit_expr_struct_mut(self, expr);
    }

    // `Self::CONST` in patterns -> `<Receiver>::CONST`
    fn visit_pat_path_mut(&mut self, pat: &mut PatPath) {
        if pat.qself.is_none() {
            self.self_to_qself(&mut pat.qself, &mut pat.path);
        }
        visit_mut::visit_pat_path_mut(self, pat);
    }

    // `Self { .. }` in patterns -> self type's path.
    fn visit_pat_struct_mut(&mut self, pat: &mut PatStruct) {
        self.self_to_expr_path(&mut pat.path);
        visit_mut::visit_pat_struct_mut(self, pat);
    }

    // `Self(..)` in patterns -> self type's path.
    fn visit_pat_tuple_struct_mut(&mut self, pat: &mut PatTupleStruct) {
        self.self_to_expr_path(&mut pat.path);
        visit_mut::visit_pat_tuple_struct_mut(self, pat);
    }

    /// Renames a single-segment `self` path to `__self`. This deliberately does
    /// not call the default `visit_path_mut`: only the generic arguments of each
    /// segment are visited further.
    fn visit_path_mut(&mut self, path: &mut Path) {
        if path.segments.len() == 1 {
            // Replace `self`, but not `self::function`.
            prepend_underscore_to_self(&mut path.segments[0].ident);
        }
        for segment in &mut path.segments {
            self.visit_path_arguments_mut(&mut segment.arguments);
        }
    }

    fn visit_item_mut(&mut self, item: &mut Item) {
        match item {
            // Visit `macro_rules!` because locally defined macros can refer to `self`.
            Item::Macro(item) if item.mac.path.is_ident("macro_rules") => {
                self.visit_macro_mut(&mut item.mac);
            }
            // Otherwise, do not recurse into nested items.
            _ => {}
        }
    }

    fn visit_macro_mut(&mut self, mac: &mut Macro) {
        // We can't tell in general whether `self` inside a macro invocation
        // refers to the self in the argument list or a different self
        // introduced within the macro. Heuristic: if the macro input contains
        // `fn`, then `self` is more likely to refer to something other than the
        // outer function's self argument.
        if !contains_fn(mac.tokens.clone()) {
            self.visit_token_stream(&mut mac.tokens);
        }
    }
}
387 
contains_fn(tokens: TokenStream) -> bool388 fn contains_fn(tokens: TokenStream) -> bool {
389     tokens.into_iter().any(|tt| match tt {
390         TokenTree::Ident(ident) => ident == "fn",
391         TokenTree::Group(group) => contains_fn(group.stream()),
392         _ => false,
393     })
394 }
395 
prepend_underscore_to_self(ident: &mut Ident) -> bool396 pub(crate) fn prepend_underscore_to_self(ident: &mut Ident) -> bool {
397     let modified = ident == "self";
398     if modified {
399         *ident = Ident::new("__self", ident.span());
400     }
401     modified
402 }
403