• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::ast::{ContainerKind, Field};
2 use crate::attr::{Display, Trait};
3 use crate::scan_expr::scan_expr;
4 use crate::unraw::{IdentUnraw, MemberUnraw};
5 use proc_macro2::{Delimiter, TokenStream, TokenTree};
6 use quote::{format_ident, quote, quote_spanned, ToTokens as _};
7 use std::collections::{BTreeSet, HashMap};
8 use std::iter;
9 use syn::ext::IdentExt;
10 use syn::parse::discouraged::Speculative;
11 use syn::parse::{Error, ParseStream, Parser, Result};
12 use syn::{Expr, Ident, Index, LitStr, Token};
13 
impl Display<'_> {
    /// Rewrites shorthand field references in the format string (`self.fmt`)
    /// into references to generated named arguments, and records the values
    /// those arguments must be bound to.
    ///
    /// A placeholder `{member}` or `{member:spec}` whose `member` is one of
    /// `fields` (a named field, or a positional index such as `{0}`) gets
    /// `member` replaced by a generated name:
    /// - `__display...` when there is no format spec at all;
    /// - `__pointer...` when the spec ends in `p`;
    /// - `__field...` otherwise.
    ///
    /// The generated name may gain extra leading underscores to avoid
    /// clashing with a user-written `name = expr` argument. Each generated
    /// name is paired with the tokens of its value and pushed to
    /// `self.bindings`. Placeholders naming a user-written argument, or a
    /// member that is not in `fields`, are copied through unchanged.
    ///
    /// On success this overwrites `self.fmt`, `self.has_bonus_display`,
    /// `self.infinite_recursive`, `self.implied_bounds` and `self.bindings`.
    /// `self.requires_fmt_machinery` is set to `true` if the string contains
    /// a `{` or a `}`.
    ///
    /// If the string cannot be scanned, this returns `Ok(())` early without
    /// assigning any of the outputs listed above; only
    /// `requires_fmt_machinery` may already have been set. The unscannable
    /// cases are:
    /// - a trailing `{`;
    /// - a positional index that does not fit in `u32`;
    /// - a placeholder with no closing `}`.
    ///
    /// # Errors
    ///
    /// Returns an error spanned on the user's first unnamed extra argument in
    /// one case. That is when a positional placeholder like `{0}` appears, the
    /// `container` has at least one unnamed field, and the user also passed an
    /// unnamed extra argument, which makes the index ambiguous.
    pub fn expand_shorthand(&mut self, fields: &[Field], container: ContainerKind) -> Result<()> {
        let raw_args = self.args.clone();
        // `explicit_named_args` always returns `Ok` and consumes all of its
        // input (its last resort parses the remainder as a `TokenStream`), so
        // this unwrap is not expected to fire.
        let FmtArguments {
            named: user_named_args,
            first_unnamed,
        } = explicit_named_args.parse2(raw_args).unwrap();

        // Map each field's member to its index in `fields`. Positional
        // placeholders may only be mixed with extra unnamed arguments when
        // every field is named.
        let mut member_index = HashMap::new();
        let mut extra_positional_arguments_allowed = true;
        for (i, field) in fields.iter().enumerate() {
            member_index.insert(&field.member, i);
            extra_positional_arguments_allowed &= matches!(&field.member, MemberUnraw::Named(_));
        }

        let span = self.fmt.span();
        let fmt = self.fmt.value();
        // `read` is the not-yet-scanned remainder of the format string; `out`
        // is the rewritten string built so far.
        let mut read = fmt.as_str();
        let mut out = String::new();
        let mut has_bonus_display = false;
        let mut infinite_recursive = false;
        let mut implied_bounds = BTreeSet::new();
        let mut bindings = Vec::new();
        // Generated names already bound, so repeated uses share one binding.
        let mut macro_named_args = BTreeSet::new();

        // Any `}` here (and any `{`, handled inside the loop) means the literal
        // must go through the formatting machinery.
        self.requires_fmt_machinery = self.requires_fmt_machinery || fmt.contains('}');

        while let Some(brace) = read.find('{') {
            self.requires_fmt_machinery = true;
            // Copy everything up to and including the `{`.
            out += &read[..brace + 1];
            read = &read[brace + 1..];
            // `{{` is an escaped brace: copy the second `{` and move on.
            if read.starts_with('{') {
                out.push('{');
                read = &read[1..];
                continue;
            }
            // A `{` at the very end is malformed; give up without rewriting.
            let next = match read.chars().next() {
                Some(next) => next,
                None => return Ok(()),
            };
            // Work out which member, if any, this placeholder refers to.
            let member = match next {
                '0'..='9' => {
                    let int = take_int(&mut read);
                    if !extra_positional_arguments_allowed {
                        if let Some(first_unnamed) = &first_unnamed {
                            let msg = format!("ambiguous reference to positional arguments by number in a {container}; change this to a named argument");
                            return Err(Error::new_spanned(first_unnamed, msg));
                        }
                    }
                    // An index too large for `u32`: give up without rewriting.
                    match int.parse::<u32>() {
                        Ok(index) => MemberUnraw::Unnamed(Index { index, span }),
                        Err(_) => return Ok(()),
                    }
                }
                'a'..='z' | 'A'..='Z' | '_' => {
                    // Raw identifiers (`{r#...}`) are left as written.
                    if read.starts_with("r#") {
                        continue;
                    }
                    let repr = take_ident(&mut read);
                    if repr == "_" {
                        // Invalid. Let rustc produce the diagnostic.
                        out += repr;
                        continue;
                    }
                    let ident = IdentUnraw::new(Ident::new(repr, span));
                    if user_named_args.contains(&ident) {
                        // Refers to a named argument written by the user, not to field.
                        out += repr;
                        continue;
                    }
                    MemberUnraw::Named(ident)
                }
                // `{}`, `{:...}` and anything else: leave as written.
                _ => continue,
            };
            // No closing `}`: malformed; give up without rewriting.
            let end_spec = match read.find('}') {
                Some(end_spec) => end_spec,
                None => return Ok(()),
            };
            // The last character between the member and `}` selects the
            // formatting trait. When nothing is there (the placeholder is
            // exactly `{member}`) the value is bound via `.as_display()` below.
            let mut bonus_display = false;
            let bound = match read[..end_spec].chars().next_back() {
                Some('?') => Trait::Debug,
                Some('o') => Trait::Octal,
                Some('x') => Trait::LowerHex,
                Some('X') => Trait::UpperHex,
                Some('p') => Trait::Pointer,
                Some('b') => Trait::Binary,
                Some('e') => Trait::LowerExp,
                Some('E') => Trait::UpperExp,
                Some(_) => Trait::Display,
                None => {
                    bonus_display = true;
                    has_bonus_display = true;
                    Trait::Display
                }
            };
            // Record a `{self}` placeholder formatted with `Display`; this is
            // checked before the field lookup, so it applies even when `self`
            // is not a field.
            infinite_recursive |= member == *"self" && bound == Trait::Display;
            // Members that are not fields of this container are copied through.
            let field = match member_index.get(&member) {
                Some(&field) => field,
                None => {
                    out += &member.to_string();
                    continue;
                }
            };
            // Remember that the field at this index is formatted with `bound`.
            implied_bounds.insert((field, bound));
            // The prefix reflects how the bound value will be wrapped below.
            let formatvar_prefix = if bonus_display {
                "__display"
            } else if bound == Trait::Pointer {
                "__pointer"
            } else {
                "__field"
            };
            let mut formatvar = IdentUnraw::new(match &member {
                MemberUnraw::Unnamed(index) => format_ident!("{}{}", formatvar_prefix, index),
                MemberUnraw::Named(ident) => {
                    format_ident!("{}_{}", formatvar_prefix, ident.to_string())
                }
            });
            // Prepend `_` until the name no longer collides with a user-named argument.
            while user_named_args.contains(&formatvar) {
                formatvar = IdentUnraw::new(format_ident!("_{}", formatvar.to_string()));
            }
            formatvar.set_span(span);
            out += &formatvar.to_string();
            if !macro_named_args.insert(formatvar.clone()) {
                // Already added to bindings by a previous use.
                continue;
            }
            // The bound value is `_N` for positional fields, or
            // `ident.to_local()` for named ones.
            let mut binding_value = match &member {
                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                MemberUnraw::Named(ident) => ident.to_local(),
            };
            // Keep the format string's location, but resolve the name as if it
            // had been written at the field's member.
            binding_value.set_span(span.resolved_at(fields[field].member.span()));
            let wrapped_binding_value = if bonus_display {
                quote_spanned!(span=> #binding_value.as_display())
            } else if bound == Trait::Pointer {
                quote!(::thiserror::__private::Var(#binding_value))
            } else {
                binding_value.into_token_stream()
            };
            bindings.push((formatvar.to_local(), wrapped_binding_value));
        }

        // Copy whatever remains after the last placeholder.
        out += read;
        self.fmt = LitStr::new(&out, self.fmt.span());
        self.has_bonus_display = has_bonus_display;
        self.infinite_recursive = infinite_recursive;
        self.implied_bounds = implied_bounds;
        self.bindings = bindings;
        Ok(())
    }
}
164 
165 struct FmtArguments {
166     named: BTreeSet<IdentUnraw>,
167     first_unnamed: Option<TokenStream>,
168 }
169 
170 #[allow(clippy::unnecessary_wraps)]
explicit_named_args(input: ParseStream) -> Result<FmtArguments>171 fn explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
172     let ahead = input.fork();
173     if let Ok(set) = try_explicit_named_args(&ahead) {
174         input.advance_to(&ahead);
175         return Ok(set);
176     }
177 
178     let ahead = input.fork();
179     if let Ok(set) = fallback_explicit_named_args(&ahead) {
180         input.advance_to(&ahead);
181         return Ok(set);
182     }
183 
184     input.parse::<TokenStream>().unwrap();
185     Ok(FmtArguments {
186         named: BTreeSet::new(),
187         first_unnamed: None,
188     })
189 }
190 
try_explicit_named_args(input: ParseStream) -> Result<FmtArguments>191 fn try_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
192     let mut syn_full = None;
193     let mut args = FmtArguments {
194         named: BTreeSet::new(),
195         first_unnamed: None,
196     };
197 
198     while !input.is_empty() {
199         input.parse::<Token![,]>()?;
200         if input.is_empty() {
201             break;
202         }
203 
204         let mut begin_unnamed = None;
205         if input.peek(Ident::peek_any) && input.peek2(Token![=]) && !input.peek2(Token![==]) {
206             let ident: IdentUnraw = input.parse()?;
207             input.parse::<Token![=]>()?;
208             args.named.insert(ident);
209         } else {
210             begin_unnamed = Some(input.fork());
211         }
212 
213         let ahead = input.fork();
214         if *syn_full.get_or_insert_with(is_syn_full) && ahead.parse::<Expr>().is_ok() {
215             input.advance_to(&ahead);
216         } else {
217             scan_expr(input)?;
218         }
219 
220         if let Some(begin_unnamed) = begin_unnamed {
221             if args.first_unnamed.is_none() {
222                 args.first_unnamed = Some(between(&begin_unnamed, input));
223             }
224         }
225     }
226 
227     Ok(args)
228 }
229 
fallback_explicit_named_args(input: ParseStream) -> Result<FmtArguments>230 fn fallback_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
231     let mut args = FmtArguments {
232         named: BTreeSet::new(),
233         first_unnamed: None,
234     };
235 
236     while !input.is_empty() {
237         if input.peek(Token![,])
238             && input.peek2(Ident::peek_any)
239             && input.peek3(Token![=])
240             && !input.peek3(Token![==])
241         {
242             input.parse::<Token![,]>()?;
243             let ident: IdentUnraw = input.parse()?;
244             input.parse::<Token![=]>()?;
245             args.named.insert(ident);
246         } else {
247             input.parse::<TokenTree>()?;
248         }
249     }
250 
251     Ok(args)
252 }
253 
is_syn_full() -> bool254 fn is_syn_full() -> bool {
255     // Expr::Block contains syn::Block which contains Vec<syn::Stmt>. In the
256     // current version of Syn, syn::Stmt is exhaustive and could only plausibly
257     // represent `trait Trait {}` in Stmt::Item which contains syn::Item. Most
258     // of the point of syn's non-"full" mode is to avoid compiling Item and the
259     // entire expansive syntax tree it comprises. So the following expression
260     // being parsed to Expr::Block is a reliable indication that "full" is
261     // enabled.
262     let test = quote!({
263         trait Trait {}
264     });
265     match syn::parse2(test) {
266         Ok(Expr::Verbatim(_)) | Err(_) => false,
267         Ok(Expr::Block(_)) => true,
268         Ok(_) => unreachable!(),
269     }
270 }
271 
/// Splits the leading run of ASCII digits (`0`-`9`) off the front of `*read`.
///
/// Returns that run, which may be empty, and advances `*read` past it.
fn take_int<'a>(read: &mut &'a str) -> &'a str {
    // Every accepted char is a single byte, so the byte index of the first
    // non-digit equals the number of digits consumed.
    let stop = read
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(read.len());
    let (digits, remainder) = read.split_at(stop);
    *read = remainder;
    digits
}
284 
/// Splits the leading run of ASCII identifier characters (`a`-`z`, `A`-`Z`,
/// `0`-`9`, `_`) off the front of `*read`.
///
/// Returns that run, which may be empty, and advances `*read` past it.
fn take_ident<'a>(read: &mut &'a str) -> &'a str {
    let is_ident_char = |ch: char| ch.is_ascii_alphanumeric() || ch == '_';
    // Accepted chars are ASCII (one byte each), so this byte index is also
    // the number of chars taken.
    let stop = read.find(|ch: char| !is_ident_char(ch)).unwrap_or(read.len());
    let (ident, remainder) = read.split_at(stop);
    *read = remainder;
    ident
}
297 
between<'a>(begin: ParseStream<'a>, end: ParseStream<'a>) -> TokenStream298 fn between<'a>(begin: ParseStream<'a>, end: ParseStream<'a>) -> TokenStream {
299     let end = end.cursor();
300     let mut cursor = begin.cursor();
301     let mut tokens = TokenStream::new();
302 
303     while cursor < end {
304         let (tt, next) = cursor.token_tree().unwrap();
305 
306         if end < next {
307             if let Some((inside, _span, _after)) = cursor.group(Delimiter::None) {
308                 cursor = inside;
309                 continue;
310             }
311             if tokens.is_empty() {
312                 tokens.extend(iter::once(tt));
313             }
314             break;
315         }
316 
317         tokens.extend(iter::once(tt));
318         cursor = next;
319     }
320 
321     tokens
322 }
323