• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 /// Implementation of the `FromArgs` and `argh(...)` derive attributes.
7 ///
8 /// For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10 
11 use {
12     crate::{
13         errors::Errors,
14         parse_attrs::{check_long_name, FieldAttrs, FieldKind, TypeAttrs},
15     },
16     proc_macro2::{Span, TokenStream},
17     quote::{quote, quote_spanned, ToTokens},
18     std::{collections::HashMap, str::FromStr},
19     syn::{spanned::Spanned, GenericArgument, LitStr, PathArguments, Type},
20 };
21 
22 mod errors;
23 mod help;
24 mod parse_attrs;
25 
26 /// Entrypoint for `#[derive(FromArgs)]`.
27 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream28 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
30     let gen = impl_from_args(&ast);
31     gen.into()
32 }
33 
34 /// Transform the input into a token stream containing any generated implementations,
35 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream36 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
37     let errors = &Errors::default();
38     let type_attrs = &TypeAttrs::parse(errors, input);
39     let mut output_tokens = match &input.data {
40         syn::Data::Struct(ds) => {
41             impl_from_args_struct(errors, &input.ident, type_attrs, &input.generics, ds)
42         }
43         syn::Data::Enum(de) => {
44             impl_from_args_enum(errors, &input.ident, type_attrs, &input.generics, de)
45         }
46         syn::Data::Union(_) => {
47             errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
48             TokenStream::new()
49         }
50     };
51     errors.to_tokens(&mut output_tokens);
52     output_tokens
53 }
54 
/// The kind of optionality a parameter has.
enum Optionality {
    /// The parameter is required; this is the only variant for which
    /// `is_required` returns `true`.
    None,
    /// The parameter has a `default`; holds the tokens of the default-value
    /// expression.
    Defaulted(TokenStream),
    /// The parameter may be omitted (an `Option`-wrapped field, or a switch).
    Optional,
    /// The parameter may be supplied repeatedly (a `Vec`-wrapped field).
    Repeating,
}
62 
63 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool64     fn eq(&self, other: &Optionality) -> bool {
65         use Optionality::*;
66         // NB: (Defaulted, Defaulted) can't contain the same token streams
67         matches!((self, other), (Optional, Optional) | (Repeating, Repeating))
68     }
69 }
70 
71 impl Optionality {
72     /// Whether or not this is `Optionality::None`
is_required(&self) -> bool73     fn is_required(&self) -> bool {
74         matches!(self, Optionality::None)
75     }
76 }
77 
/// A field of a `#[derive(FromArgs)]` struct with attributes and some other
/// notable metadata appended.
struct StructField<'a> {
    /// The original parsed field
    field: &'a syn::Field,
    /// The parsed attributes of the field
    attrs: FieldAttrs,
    /// The field name. This is contained optionally inside `field`,
    /// but is duplicated non-optionally here to indicate that all fields that
    /// have reached this point must have a field name, and it no longer
    /// needs to be unwrapped.
    name: &'a syn::Ident,
    /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
    /// but here is fully present to indicate that we only have to consider fields
    /// with a valid `kind` at this point.
    kind: FieldKind,
    /// If `field.ty` is `Vec<T>` or `Option<T>`, this is `T`, otherwise it's `&field.ty`.
    /// This is used to enable consistent parsing code between optional and non-optional
    /// keyed and subcommand fields.
    ty_without_wrapper: &'a syn::Type,
    /// Whether the field represents an optional value, such as an `Option` subcommand field
    /// or an `Option` or `Vec` keyed argument, or if it has a `default`.
    optionality: Optionality,
    /// The `--`-prefixed name of the option, if one exists.
    long_name: Option<String>,
}
104 
105 impl<'a> StructField<'a> {
106     /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
107     /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>108     fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
109         let name = field.ident.as_ref().expect("missing ident for named field");
110 
111         // Ensure that one "kind" is present (switch, option, subcommand, positional)
112         let kind = if let Some(field_type) = &attrs.field_type {
113             field_type.kind
114         } else {
115             errors.err(
116                 field,
117                 concat!(
118                     "Missing `argh` field kind attribute.\n",
119                     "Expected one of: `switch`, `option`, `remaining`, `subcommand`, `positional`",
120                 ),
121             );
122             return None;
123         };
124 
125         // Parse out whether a field is optional (`Option` or `Vec`).
126         let optionality;
127         let ty_without_wrapper;
128         match kind {
129             FieldKind::Switch => {
130                 if !ty_expect_switch(errors, &field.ty) {
131                     return None;
132                 }
133                 optionality = Optionality::Optional;
134                 ty_without_wrapper = &field.ty;
135             }
136             FieldKind::Option | FieldKind::Positional => {
137                 if let Some(default) = &attrs.default {
138                     let tokens = match TokenStream::from_str(&default.value()) {
139                         Ok(tokens) => tokens,
140                         Err(_) => {
141                             errors.err(&default, "Invalid tokens: unable to lex `default` value");
142                             return None;
143                         }
144                     };
145                     // Set the span of the generated tokens to the string literal
146                     let tokens: TokenStream = tokens
147                         .into_iter()
148                         .map(|mut tree| {
149                             tree.set_span(default.span());
150                             tree
151                         })
152                         .collect();
153                     optionality = Optionality::Defaulted(tokens);
154                     ty_without_wrapper = &field.ty;
155                 } else {
156                     let mut inner = None;
157                     optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
158                         inner = Some(x);
159                         Optionality::Optional
160                     } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
161                         inner = Some(x);
162                         Optionality::Repeating
163                     } else {
164                         Optionality::None
165                     };
166                     ty_without_wrapper = inner.unwrap_or(&field.ty);
167                 }
168             }
169             FieldKind::SubCommand => {
170                 let inner = ty_inner(&["Option"], &field.ty);
171                 optionality =
172                     if inner.is_some() { Optionality::Optional } else { Optionality::None };
173                 ty_without_wrapper = inner.unwrap_or(&field.ty);
174             }
175         }
176 
177         // Determine the "long" name of options and switches.
178         // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
179         let long_name = match kind {
180             FieldKind::Switch | FieldKind::Option => {
181                 let long_name = attrs.long.as_ref().map(syn::LitStr::value).unwrap_or_else(|| {
182                     let kebab_name = to_kebab_case(&name.to_string());
183                     check_long_name(errors, name, &kebab_name);
184                     kebab_name
185                 });
186                 if long_name == "help" {
187                     errors.err(field, "Custom `--help` flags are not supported.");
188                 }
189                 let long_name = format!("--{}", long_name);
190                 Some(long_name)
191             }
192             FieldKind::SubCommand | FieldKind::Positional => None,
193         };
194 
195         Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
196     }
197 
positional_arg_name(&self) -> String198     pub(crate) fn positional_arg_name(&self) -> String {
199         self.attrs
200             .arg_name
201             .as_ref()
202             .map(LitStr::value)
203             .unwrap_or_else(|| self.name.to_string().trim_matches('_').to_owned())
204     }
205 }
206 
/// Converts a `snake_case` identifier to `kebab-case`.
///
/// The input is split on `_`; empty segments (produced by leading, trailing,
/// or doubled underscores) are discarded and the remaining segments are joined
/// with `-`. Letter case is left untouched.
fn to_kebab_case(s: &str) -> String {
    s.split('_').filter(|segment| !segment.is_empty()).collect::<Vec<_>>().join("-")
}
218 
#[test]
fn test_kebabs() {
    // (snake_case input, expected kebab-case output)
    let cases = [
        ("", ""),
        ("_", ""),
        ("foo", "foo"),
        ("__foo_", "foo"),
        ("foo_bar", "foo-bar"),
        ("foo__Bar", "foo-Bar"),
        ("foo_bar__baz_", "foo-bar-baz"),
    ];
    for (input, want) in cases.iter() {
        let got = to_kebab_case(input);
        assert_eq!(got.as_str(), *want, "input: {:?}", input);
    }
}
234 
235 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ds: &syn::DataStruct, ) -> TokenStream236 fn impl_from_args_struct(
237     errors: &Errors,
238     name: &syn::Ident,
239     type_attrs: &TypeAttrs,
240     generic_args: &syn::Generics,
241     ds: &syn::DataStruct,
242 ) -> TokenStream {
243     let fields = match &ds.fields {
244         syn::Fields::Named(fields) => fields,
245         syn::Fields::Unnamed(_) => {
246             errors.err(
247                 &ds.struct_token,
248                 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
249             );
250             return TokenStream::new();
251         }
252         syn::Fields::Unit => {
253             errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
254             return TokenStream::new();
255         }
256     };
257 
258     let fields: Vec<_> = fields
259         .named
260         .iter()
261         .filter_map(|field| {
262             let attrs = FieldAttrs::parse(errors, field);
263             StructField::new(errors, field, attrs)
264         })
265         .collect();
266 
267     ensure_unique_names(errors, &fields);
268     ensure_only_last_positional_is_optional(errors, &fields);
269 
270     let impl_span = Span::call_site();
271 
272     let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
273 
274     let redact_arg_values_method =
275         impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
276 
277     let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs, generic_args);
278 
279     let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
280     let trait_impl = quote_spanned! { impl_span =>
281         #[automatically_derived]
282         impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
283             #from_args_method
284 
285             #redact_arg_values_method
286         }
287 
288         #top_or_sub_cmd_impl
289     };
290 
291     trait_impl
292 }
293 
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream294 fn impl_from_args_struct_from_args<'a>(
295     errors: &Errors,
296     type_attrs: &TypeAttrs,
297     fields: &'a [StructField<'a>],
298 ) -> TokenStream {
299     let init_fields = declare_local_storage_for_from_args_fields(fields);
300     let unwrap_fields = unwrap_from_args_fields(fields);
301     let positional_fields: Vec<&StructField<'_>> =
302         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
303     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
304     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
305     let last_positional_is_repeating = positional_fields
306         .last()
307         .map(|field| field.optionality == Optionality::Repeating)
308         .unwrap_or(false);
309     let last_positional_is_greedy = positional_fields
310         .last()
311         .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
312         .unwrap_or(false);
313 
314     let flag_output_table = fields.iter().filter_map(|field| {
315         let field_name = &field.field.ident;
316         match field.kind {
317             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
318             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
319             FieldKind::SubCommand | FieldKind::Positional => None,
320         }
321     });
322 
323     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
324 
325     let mut subcommands_iter =
326         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
327 
328     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
329     for dup_subcommand in subcommands_iter {
330         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
331     }
332 
333     let impl_span = Span::call_site();
334 
335     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
336 
337     let append_missing_requirements =
338         append_missing_requirements(&missing_requirements_ident, fields);
339 
340     let parse_subcommands = if let Some(subcommand) = subcommand {
341         let name = subcommand.name;
342         let ty = subcommand.ty_without_wrapper;
343         quote_spanned! { impl_span =>
344             Some(argh::ParseStructSubCommand {
345                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
346                 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
347                 parse_func: &mut |__command, __remaining_args| {
348                     #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
349                     Ok(())
350                 },
351             })
352         }
353     } else {
354         quote_spanned! { impl_span => None }
355     };
356 
357     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
358     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
359     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
360 
361     let method_impl = quote_spanned! { impl_span =>
362         fn from_args(__cmd_name: &[&str], __args: &[&str])
363             -> std::result::Result<Self, argh::EarlyExit>
364         {
365             #![allow(clippy::unwrap_in_result)]
366 
367             #( #init_fields )*
368 
369             argh::parse_struct_args(
370                 __cmd_name,
371                 __args,
372                 argh::ParseStructOptions {
373                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
374                     slots: &mut [ #( #flag_output_table, )* ],
375                 },
376                 argh::ParseStructPositionals {
377                     positionals: &mut [
378                         #(
379                             argh::ParseStructPositional {
380                                 name: #positional_field_names,
381                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
382                             },
383                         )*
384                     ],
385                     last_is_repeating: #last_positional_is_repeating,
386                     last_is_greedy: #last_positional_is_greedy,
387                 },
388                 #parse_subcommands,
389                 &|| #help,
390             )?;
391 
392             let mut #missing_requirements_ident = argh::MissingRequirements::default();
393             #(
394                 #append_missing_requirements
395             )*
396             #missing_requirements_ident.err_on_any()?;
397 
398             Ok(Self {
399                 #( #unwrap_fields, )*
400             })
401         }
402     };
403 
404     method_impl
405 }
406 
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream407 fn impl_from_args_struct_redact_arg_values<'a>(
408     errors: &Errors,
409     type_attrs: &TypeAttrs,
410     fields: &'a [StructField<'a>],
411 ) -> TokenStream {
412     let init_fields = declare_local_storage_for_redacted_fields(fields);
413     let unwrap_fields = unwrap_redacted_fields(fields);
414 
415     let positional_fields: Vec<&StructField<'_>> =
416         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
417     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
418     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
419     let last_positional_is_repeating = positional_fields
420         .last()
421         .map(|field| field.optionality == Optionality::Repeating)
422         .unwrap_or(false);
423     let last_positional_is_greedy = positional_fields
424         .last()
425         .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
426         .unwrap_or(false);
427 
428     let flag_output_table = fields.iter().filter_map(|field| {
429         let field_name = &field.field.ident;
430         match field.kind {
431             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
432             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
433             FieldKind::SubCommand | FieldKind::Positional => None,
434         }
435     });
436 
437     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
438 
439     let mut subcommands_iter =
440         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
441 
442     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
443     for dup_subcommand in subcommands_iter {
444         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
445     }
446 
447     let impl_span = Span::call_site();
448 
449     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
450 
451     let append_missing_requirements =
452         append_missing_requirements(&missing_requirements_ident, fields);
453 
454     let redact_subcommands = if let Some(subcommand) = subcommand {
455         let name = subcommand.name;
456         let ty = subcommand.ty_without_wrapper;
457         quote_spanned! { impl_span =>
458             Some(argh::ParseStructSubCommand {
459                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
460                 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
461                 parse_func: &mut |__command, __remaining_args| {
462                     #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
463                     Ok(())
464                 },
465             })
466         }
467     } else {
468         quote_spanned! { impl_span => None }
469     };
470 
471     let unwrap_cmd_name_err_string = if type_attrs.is_subcommand.is_none() {
472         quote! { "no command name" }
473     } else {
474         quote! { "no subcommand name" }
475     };
476 
477     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
478     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
479     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
480 
481     let method_impl = quote_spanned! { impl_span =>
482         fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
483             #( #init_fields )*
484 
485             argh::parse_struct_args(
486                 __cmd_name,
487                 __args,
488                 argh::ParseStructOptions {
489                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
490                     slots: &mut [ #( #flag_output_table, )* ],
491                 },
492                 argh::ParseStructPositionals {
493                     positionals: &mut [
494                         #(
495                             argh::ParseStructPositional {
496                                 name: #positional_field_names,
497                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
498                             },
499                         )*
500                     ],
501                     last_is_repeating: #last_positional_is_repeating,
502                     last_is_greedy: #last_positional_is_greedy,
503                 },
504                 #redact_subcommands,
505                 &|| #help,
506             )?;
507 
508             let mut #missing_requirements_ident = argh::MissingRequirements::default();
509             #(
510                 #append_missing_requirements
511             )*
512             #missing_requirements_ident.err_on_any()?;
513 
514             let mut __redacted = vec![
515                 if let Some(cmd_name) = __cmd_name.last() {
516                     (*cmd_name).to_owned()
517                 } else {
518                     return Err(argh::EarlyExit::from(#unwrap_cmd_name_err_string.to_owned()));
519                 }
520             ];
521 
522             #( #unwrap_fields )*
523 
524             Ok(__redacted)
525         }
526     };
527 
528     method_impl
529 }
530 
531 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])532 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
533     let mut first_non_required_span = None;
534     for field in fields {
535         if field.kind == FieldKind::Positional {
536             if let Some(first) = first_non_required_span {
537                 errors.err_span(
538                     first,
539                     "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
540                 );
541                 errors.err(&field.field, "Later positional argument declared here.");
542                 return;
543             }
544             if !field.optionality.is_required() {
545                 first_non_required_span = Some(field.field.span());
546             }
547         }
548     }
549 }
550 
551 /// Ensures that only one short or long name is used.
ensure_unique_names(errors: &Errors, fields: &[StructField<'_>])552 fn ensure_unique_names(errors: &Errors, fields: &[StructField<'_>]) {
553     let mut seen_short_names = HashMap::new();
554     let mut seen_long_names = HashMap::new();
555 
556     for field in fields {
557         if let Some(short_name) = &field.attrs.short {
558             let short_name = short_name.value();
559             if let Some(first_use_field) = seen_short_names.get(&short_name) {
560                 errors.err_span_tokens(
561                     first_use_field,
562                     &format!("The short name of \"-{}\" was already used here.", short_name),
563                 );
564                 errors.err_span_tokens(field.field, "Later usage here.");
565             }
566 
567             seen_short_names.insert(short_name, &field.field);
568         }
569 
570         if let Some(long_name) = &field.long_name {
571             if let Some(first_use_field) = seen_long_names.get(&long_name) {
572                 errors.err_span_tokens(
573                     *first_use_field,
574                     &format!("The long name of \"{}\" was already used here.", long_name),
575                 );
576                 errors.err_span_tokens(field.field, "Later usage here.");
577             }
578 
579             seen_long_names.insert(long_name, field.field);
580         }
581     }
582 }
583 
584 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ) -> TokenStream585 fn top_or_sub_cmd_impl(
586     errors: &Errors,
587     name: &syn::Ident,
588     type_attrs: &TypeAttrs,
589     generic_args: &syn::Generics,
590 ) -> TokenStream {
591     let description =
592         help::require_description(errors, name.span(), &type_attrs.description, "type");
593     let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
594     if type_attrs.is_subcommand.is_none() {
595         // Not a subcommand
596         quote! {
597             #[automatically_derived]
598             impl #impl_generics argh::TopLevelCommand for #name #ty_generics #where_clause {}
599         }
600     } else {
601         let empty_str = syn::LitStr::new("", Span::call_site());
602         let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
603             errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
604             &empty_str
605         });
606         quote! {
607             #[automatically_derived]
608             impl #impl_generics argh::SubCommand for #name #ty_generics #where_clause {
609                 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
610                     name: #subcommand_name,
611                     description: #description,
612                 };
613             }
614         }
615     }
616 }
617 
618 /// Declare a local slots to store each field in during parsing.
619 ///
620 /// Most fields are stored in `Option<FieldType>` locals.
621 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
622 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a623 fn declare_local_storage_for_from_args_fields<'a>(
624     fields: &'a [StructField<'a>],
625 ) -> impl Iterator<Item = TokenStream> + 'a {
626     fields.iter().map(|field| {
627         let field_name = &field.field.ident;
628         let field_type = &field.ty_without_wrapper;
629 
630         // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
631         let field_slot_type = match field.optionality {
632             Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
633             Optionality::None | Optionality::Defaulted(_) => {
634                 quote! { std::option::Option<#field_type> }
635             }
636         };
637 
638         match field.kind {
639             FieldKind::Option | FieldKind::Positional => {
640                 let from_str_fn = match &field.attrs.from_str_fn {
641                     Some(from_str_fn) => from_str_fn.into_token_stream(),
642                     None => {
643                         quote! {
644                             <#field_type as argh::FromArgValue>::from_arg_value
645                         }
646                     }
647                 };
648 
649                 quote! {
650                     let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
651                         = argh::ParseValueSlotTy {
652                             slot: std::default::Default::default(),
653                             parse_func: |_, value| { #from_str_fn(value) },
654                         };
655                 }
656             }
657             FieldKind::SubCommand => {
658                 quote! { let mut #field_name: #field_slot_type = None; }
659             }
660             FieldKind::Switch => {
661                 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
662             }
663         }
664     })
665 }
666 
667 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a668 fn unwrap_from_args_fields<'a>(
669     fields: &'a [StructField<'a>],
670 ) -> impl Iterator<Item = TokenStream> + 'a {
671     fields.iter().map(|field| {
672         let field_name = field.name;
673         match field.kind {
674             FieldKind::Option | FieldKind::Positional => match &field.optionality {
675                 Optionality::None => quote! {
676                     #field_name: #field_name.slot.unwrap()
677                 },
678                 Optionality::Optional | Optionality::Repeating => {
679                     quote! { #field_name: #field_name.slot }
680                 }
681                 Optionality::Defaulted(tokens) => {
682                     quote! {
683                         #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
684                     }
685                 }
686             },
687             FieldKind::Switch => field_name.into_token_stream(),
688             FieldKind::SubCommand => match field.optionality {
689                 Optionality::None => quote! { #field_name: #field_name.unwrap() },
690                 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
691                 Optionality::Defaulted(_) => unreachable!(),
692             },
693         }
694     })
695 }
696 
697 /// Declare a local slots to store each field in during parsing.
698 ///
699 /// Most fields are stored in `Option<FieldType>` locals.
700 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
701 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a702 fn declare_local_storage_for_redacted_fields<'a>(
703     fields: &'a [StructField<'a>],
704 ) -> impl Iterator<Item = TokenStream> + 'a {
705     fields.iter().map(|field| {
706         let field_name = &field.field.ident;
707 
708         match field.kind {
709             FieldKind::Switch => {
710                 quote! {
711                     let mut #field_name = argh::RedactFlag {
712                         slot: None,
713                     };
714                 }
715             }
716             FieldKind::Option => {
717                 let field_slot_type = match field.optionality {
718                     Optionality::Repeating => {
719                         quote! { std::vec::Vec<String> }
720                     }
721                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
722                         quote! { std::option::Option<String> }
723                     }
724                 };
725 
726                 quote! {
727                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
728                         argh::ParseValueSlotTy {
729                         slot: std::default::Default::default(),
730                         parse_func: |arg, _| { Ok(arg.to_owned()) },
731                     };
732                 }
733             }
734             FieldKind::Positional => {
735                 let field_slot_type = match field.optionality {
736                     Optionality::Repeating => {
737                         quote! { std::vec::Vec<String> }
738                     }
739                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
740                         quote! { std::option::Option<String> }
741                     }
742                 };
743 
744                 let arg_name = field.positional_arg_name();
745                 quote! {
746                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
747                         argh::ParseValueSlotTy {
748                         slot: std::default::Default::default(),
749                         parse_func: |_, _| { Ok(#arg_name.to_owned()) },
750                     };
751                 }
752             }
753             FieldKind::SubCommand => {
754                 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
755             }
756         }
757     })
758 }
759 
760 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a761 fn unwrap_redacted_fields<'a>(
762     fields: &'a [StructField<'a>],
763 ) -> impl Iterator<Item = TokenStream> + 'a {
764     fields.iter().map(|field| {
765         let field_name = field.name;
766 
767         match field.kind {
768             FieldKind::Switch => {
769                 quote! {
770                     if let Some(__field_name) = #field_name.slot {
771                         __redacted.push(__field_name);
772                     }
773                 }
774             }
775             FieldKind::Option => match field.optionality {
776                 Optionality::Repeating => {
777                     quote! {
778                         __redacted.extend(#field_name.slot.into_iter());
779                     }
780                 }
781                 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
782                     quote! {
783                         if let Some(__field_name) = #field_name.slot {
784                             __redacted.push(__field_name);
785                         }
786                     }
787                 }
788             },
789             FieldKind::Positional => {
790                 quote! {
791                     __redacted.extend(#field_name.slot.into_iter());
792                 }
793             }
794             FieldKind::SubCommand => {
795                 quote! {
796                     if let Some(__subcommand_args) = #field_name {
797                         __redacted.extend(__subcommand_args.into_iter());
798                     }
799                 }
800             }
801         }
802     })
803 }
804 
805 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
806 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>807 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
808     let mut flag_str_to_output_table_map = vec![];
809     for (i, (field, long_name)) in fields
810         .iter()
811         .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
812         .enumerate()
813     {
814         if let Some(short) = &field.attrs.short {
815             let short = format!("-{}", short.value());
816             flag_str_to_output_table_map.push(quote! { (#short, #i) });
817         }
818 
819         flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
820     }
821     flag_str_to_output_table_map
822 }
823 
824 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a825 fn append_missing_requirements<'a>(
826     // missing_requirements_ident
827     mri: &syn::Ident,
828     fields: &'a [StructField<'a>],
829 ) -> impl Iterator<Item = TokenStream> + 'a {
830     let mri = mri.clone();
831     fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
832         let field_name = field.name;
833         match field.kind {
834             FieldKind::Switch => unreachable!("switches are always optional"),
835             FieldKind::Positional => {
836                 let name = field.positional_arg_name();
837                 quote! {
838                     if #field_name.slot.is_none() {
839                         #mri.missing_positional_arg(#name)
840                     }
841                 }
842             }
843             FieldKind::Option => {
844                 let name = field.long_name.as_ref().expect("options always have a long name");
845                 quote! {
846                     if #field_name.slot.is_none() {
847                         #mri.missing_option(#name)
848                     }
849                 }
850             }
851             FieldKind::SubCommand => {
852                 let ty = field.ty_without_wrapper;
853                 quote! {
854                     if #field_name.is_none() {
855                         #mri.missing_subcommands(
856                             <#ty as argh::SubCommands>::COMMANDS
857                                 .iter()
858                                 .cloned()
859                                 .chain(
860                                     <#ty as argh::SubCommands>::dynamic_commands()
861                                         .iter()
862                                         .copied()
863                                 ),
864                         )
865                     }
866                 }
867             }
868         }
869     })
870 }
871 
872 /// Require that a type can be a `switch`.
873 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool874 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
875     fn ty_can_be_switch(ty: &syn::Type) -> bool {
876         if let syn::Type::Path(path) = ty {
877             if path.qself.is_some() {
878                 return false;
879             }
880             if path.path.segments.len() != 1 {
881                 return false;
882             }
883             let ident = &path.path.segments[0].ident;
884             // `Option<bool>` can be used as a `switch`.
885             if ident == "Option" {
886                 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
887                     if let GenericArgument::Type(Type::Path(p)) = &args.args[0] {
888                         if p.path.segments[0].ident == "bool" {
889                             return true;
890                         }
891                     }
892                 }
893             }
894             ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
895                 .iter()
896                 .any(|path| ident == path)
897         } else {
898             false
899         }
900     }
901 
902     let res = ty_can_be_switch(ty);
903     if !res {
904         errors.err(ty, "switches must be of type `bool`, `Option<bool>`, or integer type");
905     }
906     res
907 }
908 
909 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>910 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
911     if let syn::Type::Path(path) = ty {
912         if path.qself.is_some() {
913             return None;
914         }
915         // Since we only check the last path segment, it isn't necessarily the case that
916         // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
917         // a fool proof way to check these since name resolution happens after macro expansion,
918         // so this is likely "good enough" (so long as people don't have their own types called
919         // `Option` or `Vec` that take one generic parameter they're looking to parse).
920         let last_segment = path.path.segments.last()?;
921         if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
922             return None;
923         }
924         if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
925             let generic_arg = gen_args.args.first()?;
926             if let syn::GenericArgument::Type(ty) = &generic_arg {
927                 return Some(ty);
928             }
929         }
930     }
931     None
932 }
933 
934 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, de: &syn::DataEnum, ) -> TokenStream935 fn impl_from_args_enum(
936     errors: &Errors,
937     name: &syn::Ident,
938     type_attrs: &TypeAttrs,
939     generic_args: &syn::Generics,
940     de: &syn::DataEnum,
941 ) -> TokenStream {
942     parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
943 
944     // An enum variant like `<name>(<ty>)`
945     struct SubCommandVariant<'a> {
946         name: &'a syn::Ident,
947         ty: &'a syn::Type,
948     }
949 
950     let mut dynamic_type_and_variant = None;
951 
952     let variants: Vec<SubCommandVariant<'_>> = de
953         .variants
954         .iter()
955         .filter_map(|variant| {
956             let name = &variant.ident;
957             let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
958             if parse_attrs::VariantAttrs::parse(errors, variant).is_dynamic.is_some() {
959                 if dynamic_type_and_variant.is_some() {
960                     errors.err(variant, "Only one variant can have the `dynamic` attribute");
961                 }
962                 dynamic_type_and_variant = Some((ty, name));
963                 None
964             } else {
965                 Some(SubCommandVariant { name, ty })
966             }
967         })
968         .collect();
969 
970     let name_repeating = std::iter::repeat(name.clone());
971     let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
972     let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
973     let dynamic_from_args =
974         dynamic_type_and_variant.as_ref().map(|(dynamic_type, dynamic_variant)| {
975             quote! {
976                 if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_from_args(
977                     command_name, args) {
978                     return result.map(#name::#dynamic_variant);
979                 }
980             }
981         });
982     let dynamic_redact_arg_values = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
983         quote! {
984             if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_redact_arg_values(
985                 command_name, args) {
986                 return result;
987             }
988         }
989     });
990     let dynamic_commands = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
991         quote! {
992             fn dynamic_commands() -> &'static [&'static argh::CommandInfo] {
993                 <#dynamic_type as argh::DynamicSubCommand>::commands()
994             }
995         }
996     });
997 
998     let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
999     quote! {
1000         impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
1001             fn from_args(command_name: &[&str], args: &[&str])
1002                 -> std::result::Result<Self, argh::EarlyExit>
1003             {
1004                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1005                     *subcommand_name
1006                 } else {
1007                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1008                 };
1009 
1010                 #(
1011                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1012                         return Ok(#name_repeating::#variant_names(
1013                             <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
1014                         ));
1015                     }
1016                 )*
1017 
1018                 #dynamic_from_args
1019 
1020                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1021             }
1022 
1023             fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
1024                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1025                     *subcommand_name
1026                 } else {
1027                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1028                 };
1029 
1030                 #(
1031                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1032                         return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
1033                     }
1034                 )*
1035 
1036                 #dynamic_redact_arg_values
1037 
1038                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1039             }
1040         }
1041 
1042         impl #impl_generics argh::SubCommands for #name #ty_generics #where_clause {
1043             const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
1044                 <#variant_ty as argh::SubCommand>::COMMAND,
1045             )*];
1046 
1047             #dynamic_commands
1048         }
1049     }
1050 }
1051 
1052 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
1053 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>1054 fn enum_only_single_field_unnamed_variants<'a>(
1055     errors: &Errors,
1056     variant_fields: &'a syn::Fields,
1057 ) -> Option<&'a syn::Type> {
1058     macro_rules! with_enum_suggestion {
1059         ($help_text:literal) => {
1060             concat!(
1061                 $help_text,
1062                 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
1063                 "    enum MyCommandEnum {\n",
1064                 "        SubCommandOne(SubCommandOne),\n",
1065                 "        SubCommandTwo(SubCommandTwo),\n",
1066                 "    }",
1067             )
1068         };
1069     }
1070 
1071     match variant_fields {
1072         syn::Fields::Named(fields) => {
1073             errors.err(
1074                 fields,
1075                 with_enum_suggestion!(
1076                     "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
1077                 ),
1078             );
1079             None
1080         }
1081         syn::Fields::Unit => {
1082             errors.err(
1083                 variant_fields,
1084                 with_enum_suggestion!(
1085                     "`#![derive(FromArgs)]` does not support `enum`s with no variants."
1086                 ),
1087             );
1088             None
1089         }
1090         syn::Fields::Unnamed(fields) => {
1091             if fields.unnamed.len() != 1 {
1092                 errors.err(
1093                     fields,
1094                     with_enum_suggestion!(
1095                         "`#![derive(FromArgs)]` `enum` variants must only contain one field."
1096                     ),
1097                 );
1098                 None
1099             } else {
1100                 // `unwrap` is okay because of the length check above.
1101                 let first_field = fields.unnamed.first().unwrap();
1102                 Some(&first_field.ty)
1103             }
1104         }
1105     }
1106 }
1107