• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 /// Implementation of the `FromArgs` and `argh(...)` derive attributes.
7 ///
8 /// For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10 
11 use {
12     crate::{
13         errors::Errors,
14         parse_attrs::{FieldAttrs, FieldKind, TypeAttrs},
15     },
16     proc_macro2::{Span, TokenStream},
17     quote::{quote, quote_spanned, ToTokens},
18     std::str::FromStr,
19     syn::{spanned::Spanned, LitStr},
20 };
21 
22 mod errors;
23 mod help;
24 mod parse_attrs;
25 
26 /// Entrypoint for `#[derive(FromArgs)]`.
27 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream28 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
30     let gen = impl_from_args(&ast);
31     gen.into()
32 }
33 
34 /// Transform the input into a token stream containing any generated implementations,
35 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream36 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
37     let errors = &Errors::default();
38     if input.generics.params.len() != 0 {
39         errors.err(
40             &input.generics,
41             "`#![derive(FromArgs)]` cannot be applied to types with generic parameters",
42         );
43     }
44     let type_attrs = &TypeAttrs::parse(errors, input);
45     let mut output_tokens = match &input.data {
46         syn::Data::Struct(ds) => impl_from_args_struct(errors, &input.ident, type_attrs, ds),
47         syn::Data::Enum(de) => impl_from_args_enum(errors, &input.ident, type_attrs, de),
48         syn::Data::Union(_) => {
49             errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
50             TokenStream::new()
51         }
52     };
53     errors.to_tokens(&mut output_tokens);
54     output_tokens
55 }
56 
/// The kind of optionality a parameter has.
enum Optionality {
    /// The value is required; this is the only variant for which
    /// `is_required` returns `true`.
    None,
    /// A `default` attribute was supplied; holds the tokens of the default expression.
    Defaulted(TokenStream),
    /// The value may be omitted (used for `Option` fields and for switches).
    Optional,
    /// The value may be supplied multiple times (used for `Vec` fields).
    Repeating,
}
64 
65 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool66     fn eq(&self, other: &Optionality) -> bool {
67         use Optionality::*;
68         match (self, other) {
69             (None, None) | (Optional, Optional) | (Repeating, Repeating) => true,
70             // NB: (Defaulted, Defaulted) can't contain the same token streams
71             _ => false,
72         }
73     }
74 }
75 
76 impl Optionality {
77     /// Whether or not this is `Optionality::None`
is_required(&self) -> bool78     fn is_required(&self) -> bool {
79         if let Optionality::None = self {
80             true
81         } else {
82             false
83         }
84     }
85 }
86 
/// A field of a `#[derive(FromArgs)]` struct with attributes and some other
/// notable metadata appended.
struct StructField<'a> {
    /// The original parsed field
    field: &'a syn::Field,
    /// The parsed attributes of the field
    attrs: FieldAttrs,
    /// The field name. This is contained optionally inside `field`,
    /// but is duplicated non-optionally here to indicate that all fields that
    /// have reached this point must have a field name, and it no longer
    /// needs to be unwrapped.
    name: &'a syn::Ident,
    /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
    /// but here is fully present to indicate that we only have to consider fields
    /// with a valid `kind` at this point.
    kind: FieldKind,
    /// If `field.ty` is `Vec<T>` or `Option<T>`, this is `T`, otherwise it's `&field.ty`.
    /// This is used to enable consistent parsing code between optional and non-optional
    /// keyed and subcommand fields.
    ty_without_wrapper: &'a syn::Type,
    /// Whether the field represents an optional value, such as an `Option` subcommand field
    /// or an `Option` or `Vec` keyed argument, or if it has a `default`.
    optionality: Optionality,
    /// The `--`-prefixed name of the option, if one exists.
    long_name: Option<String>,
}
113 
114 impl<'a> StructField<'a> {
115     /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
116     /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>117     fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
118         let name = field.ident.as_ref().expect("missing ident for named field");
119 
120         // Ensure that one "kind" is present (switch, option, subcommand, positional)
121         let kind = if let Some(field_type) = &attrs.field_type {
122             field_type.kind
123         } else {
124             errors.err(
125                 field,
126                 concat!(
127                     "Missing `argh` field kind attribute.\n",
128                     "Expected one of: `switch`, `option`, `subcommand`, `positional`",
129                 ),
130             );
131             return None;
132         };
133 
134         // Parse out whether a field is optional (`Option` or `Vec`).
135         let optionality;
136         let ty_without_wrapper;
137         match kind {
138             FieldKind::Switch => {
139                 if !ty_expect_switch(errors, &field.ty) {
140                     return None;
141                 }
142                 optionality = Optionality::Optional;
143                 ty_without_wrapper = &field.ty;
144             }
145             FieldKind::Option | FieldKind::Positional => {
146                 if let Some(default) = &attrs.default {
147                     let tokens = match TokenStream::from_str(&default.value()) {
148                         Ok(tokens) => tokens,
149                         Err(_) => {
150                             errors.err(&default, "Invalid tokens: unable to lex `default` value");
151                             return None;
152                         }
153                     };
154                     // Set the span of the generated tokens to the string literal
155                     let tokens: TokenStream = tokens
156                         .into_iter()
157                         .map(|mut tree| {
158                             tree.set_span(default.span());
159                             tree
160                         })
161                         .collect();
162                     optionality = Optionality::Defaulted(tokens);
163                     ty_without_wrapper = &field.ty;
164                 } else {
165                     let mut inner = None;
166                     optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
167                         inner = Some(x);
168                         Optionality::Optional
169                     } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
170                         inner = Some(x);
171                         Optionality::Repeating
172                     } else {
173                         Optionality::None
174                     };
175                     ty_without_wrapper = inner.unwrap_or(&field.ty);
176                 }
177             }
178             FieldKind::SubCommand => {
179                 let inner = ty_inner(&["Option"], &field.ty);
180                 optionality =
181                     if inner.is_some() { Optionality::Optional } else { Optionality::None };
182                 ty_without_wrapper = inner.unwrap_or(&field.ty);
183             }
184         }
185 
186         // Determine the "long" name of options and switches.
187         // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
188         let long_name = match kind {
189             FieldKind::Switch | FieldKind::Option => {
190                 let long_name = attrs
191                     .long
192                     .as_ref()
193                     .map(syn::LitStr::value)
194                     .unwrap_or_else(|| heck::KebabCase::to_kebab_case(&*name.to_string()));
195                 if long_name == "help" {
196                     errors.err(field, "Custom `--help` flags are not supported.");
197                 }
198                 let long_name = format!("--{}", long_name);
199                 Some(long_name)
200             }
201             FieldKind::SubCommand | FieldKind::Positional => None,
202         };
203 
204         Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
205     }
206 
arg_name(&self) -> String207     pub(crate) fn arg_name(&self) -> String {
208         self.attrs.arg_name.as_ref().map(LitStr::value).unwrap_or_else(|| self.name.to_string())
209     }
210 }
211 
212 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, ds: &syn::DataStruct, ) -> TokenStream213 fn impl_from_args_struct(
214     errors: &Errors,
215     name: &syn::Ident,
216     type_attrs: &TypeAttrs,
217     ds: &syn::DataStruct,
218 ) -> TokenStream {
219     let fields = match &ds.fields {
220         syn::Fields::Named(fields) => fields,
221         syn::Fields::Unnamed(_) => {
222             errors.err(
223                 &ds.struct_token,
224                 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
225             );
226             return TokenStream::new();
227         }
228         syn::Fields::Unit => {
229             errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
230             return TokenStream::new();
231         }
232     };
233 
234     let fields: Vec<_> = fields
235         .named
236         .iter()
237         .filter_map(|field| {
238             let attrs = FieldAttrs::parse(errors, field);
239             StructField::new(errors, field, attrs)
240         })
241         .collect();
242 
243     ensure_only_last_positional_is_optional(errors, &fields);
244 
245     let impl_span = Span::call_site();
246 
247     let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
248 
249     let redact_arg_values_method =
250         impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
251 
252     let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs);
253 
254     let trait_impl = quote_spanned! { impl_span =>
255         impl argh::FromArgs for #name {
256             #from_args_method
257 
258             #redact_arg_values_method
259         }
260 
261         #top_or_sub_cmd_impl
262     };
263 
264     trait_impl
265 }
266 
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream267 fn impl_from_args_struct_from_args<'a>(
268     errors: &Errors,
269     type_attrs: &TypeAttrs,
270     fields: &'a [StructField<'a>],
271 ) -> TokenStream {
272     let init_fields = declare_local_storage_for_from_args_fields(&fields);
273     let unwrap_fields = unwrap_from_args_fields(&fields);
274     let positional_fields: Vec<&StructField<'_>> =
275         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
276     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
277     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
278     let last_positional_is_repeating = positional_fields
279         .last()
280         .map(|field| field.optionality == Optionality::Repeating)
281         .unwrap_or(false);
282 
283     let flag_output_table = fields.iter().filter_map(|field| {
284         let field_name = &field.field.ident;
285         match field.kind {
286             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
287             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
288             FieldKind::SubCommand | FieldKind::Positional => None,
289         }
290     });
291 
292     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
293 
294     let mut subcommands_iter =
295         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
296 
297     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
298     while let Some(dup_subcommand) = subcommands_iter.next() {
299         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
300     }
301 
302     let impl_span = Span::call_site();
303 
304     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
305 
306     let append_missing_requirements =
307         append_missing_requirements(&missing_requirements_ident, &fields);
308 
309     let parse_subcommands = if let Some(subcommand) = subcommand {
310         let name = subcommand.name;
311         let ty = subcommand.ty_without_wrapper;
312         quote_spanned! { impl_span =>
313             Some(argh::ParseStructSubCommand {
314                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
315                 parse_func: &mut |__command, __remaining_args| {
316                     #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
317                     Ok(())
318                 },
319             })
320         }
321     } else {
322         quote_spanned! { impl_span => None }
323     };
324 
325     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
326     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
327     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
328 
329     let method_impl = quote_spanned! { impl_span =>
330         fn from_args(__cmd_name: &[&str], __args: &[&str])
331             -> std::result::Result<Self, argh::EarlyExit>
332         {
333             #( #init_fields )*
334 
335             argh::parse_struct_args(
336                 __cmd_name,
337                 __args,
338                 argh::ParseStructOptions {
339                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
340                     slots: &mut [ #( #flag_output_table, )* ],
341                 },
342                 argh::ParseStructPositionals {
343                     positionals: &mut [
344                         #(
345                             argh::ParseStructPositional {
346                                 name: #positional_field_names,
347                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
348                             },
349                         )*
350                     ],
351                     last_is_repeating: #last_positional_is_repeating,
352                 },
353                 #parse_subcommands,
354                 &|| #help,
355             )?;
356 
357             let mut #missing_requirements_ident = argh::MissingRequirements::default();
358             #(
359                 #append_missing_requirements
360             )*
361             #missing_requirements_ident.err_on_any()?;
362 
363             Ok(Self {
364                 #( #unwrap_fields, )*
365             })
366         }
367     };
368 
369     method_impl
370 }
371 
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream372 fn impl_from_args_struct_redact_arg_values<'a>(
373     errors: &Errors,
374     type_attrs: &TypeAttrs,
375     fields: &'a [StructField<'a>],
376 ) -> TokenStream {
377     let init_fields = declare_local_storage_for_redacted_fields(&fields);
378     let unwrap_fields = unwrap_redacted_fields(&fields);
379 
380     let positional_fields: Vec<&StructField<'_>> =
381         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
382     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
383     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
384     let last_positional_is_repeating = positional_fields
385         .last()
386         .map(|field| field.optionality == Optionality::Repeating)
387         .unwrap_or(false);
388 
389     let flag_output_table = fields.iter().filter_map(|field| {
390         let field_name = &field.field.ident;
391         match field.kind {
392             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
393             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
394             FieldKind::SubCommand | FieldKind::Positional => None,
395         }
396     });
397 
398     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
399 
400     let mut subcommands_iter =
401         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
402 
403     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
404     while let Some(dup_subcommand) = subcommands_iter.next() {
405         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
406     }
407 
408     let impl_span = Span::call_site();
409 
410     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
411 
412     let append_missing_requirements =
413         append_missing_requirements(&missing_requirements_ident, &fields);
414 
415     let redact_subcommands = if let Some(subcommand) = subcommand {
416         let name = subcommand.name;
417         let ty = subcommand.ty_without_wrapper;
418         quote_spanned! { impl_span =>
419             Some(argh::ParseStructSubCommand {
420                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
421                 parse_func: &mut |__command, __remaining_args| {
422                     #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
423                     Ok(())
424                 },
425             })
426         }
427     } else {
428         quote_spanned! { impl_span => None }
429     };
430 
431     let cmd_name = if type_attrs.is_subcommand.is_none() {
432         quote! { __cmd_name.last().expect("no command name").to_string() }
433     } else {
434         quote! { __cmd_name.last().expect("no subcommand name").to_string() }
435     };
436 
437     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
438     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
439     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
440 
441     let method_impl = quote_spanned! { impl_span =>
442         fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
443             #( #init_fields )*
444 
445             argh::parse_struct_args(
446                 __cmd_name,
447                 __args,
448                 argh::ParseStructOptions {
449                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
450                     slots: &mut [ #( #flag_output_table, )* ],
451                 },
452                 argh::ParseStructPositionals {
453                     positionals: &mut [
454                         #(
455                             argh::ParseStructPositional {
456                                 name: #positional_field_names,
457                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
458                             },
459                         )*
460                     ],
461                     last_is_repeating: #last_positional_is_repeating,
462                 },
463                 #redact_subcommands,
464                 &|| #help,
465             )?;
466 
467             let mut #missing_requirements_ident = argh::MissingRequirements::default();
468             #(
469                 #append_missing_requirements
470             )*
471             #missing_requirements_ident.err_on_any()?;
472 
473             let mut __redacted = vec![
474                 #cmd_name,
475             ];
476 
477             #( #unwrap_fields )*
478 
479             Ok(__redacted)
480         }
481     };
482 
483     method_impl
484 }
485 
486 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])487 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
488     let mut first_non_required_span = None;
489     for field in fields {
490         if field.kind == FieldKind::Positional {
491             if let Some(first) = first_non_required_span {
492                 errors.err_span(
493                     first,
494                     "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
495                 );
496                 errors.err(&field.field, "Later positional argument declared here.");
497                 return;
498             }
499             if !field.optionality.is_required() {
500                 first_non_required_span = Some(field.field.span());
501             }
502         }
503     }
504 }
505 
506 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream507 fn top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream {
508     let description =
509         help::require_description(errors, name.span(), &type_attrs.description, "type");
510     if type_attrs.is_subcommand.is_none() {
511         // Not a subcommand
512         quote! {
513             impl argh::TopLevelCommand for #name {}
514         }
515     } else {
516         let empty_str = syn::LitStr::new("", Span::call_site());
517         let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
518             errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
519             &empty_str
520         });
521         quote! {
522             impl argh::SubCommand for #name {
523                 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
524                     name: #subcommand_name,
525                     description: #description,
526                 };
527             }
528         }
529     }
530 }
531 
532 /// Declare a local slots to store each field in during parsing.
533 ///
534 /// Most fields are stored in `Option<FieldType>` locals.
535 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
536 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a537 fn declare_local_storage_for_from_args_fields<'a>(
538     fields: &'a [StructField<'a>],
539 ) -> impl Iterator<Item = TokenStream> + 'a {
540     fields.iter().map(|field| {
541         let field_name = &field.field.ident;
542         let field_type = &field.ty_without_wrapper;
543 
544         // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
545         let field_slot_type = match field.optionality {
546             Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
547             Optionality::None | Optionality::Defaulted(_) => {
548                 quote! { std::option::Option<#field_type> }
549             }
550         };
551 
552         match field.kind {
553             FieldKind::Option | FieldKind::Positional => {
554                 let from_str_fn = match &field.attrs.from_str_fn {
555                     Some(from_str_fn) => from_str_fn.into_token_stream(),
556                     None => {
557                         quote! {
558                             <#field_type as argh::FromArgValue>::from_arg_value
559                         }
560                     }
561                 };
562 
563                 quote! {
564                     let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
565                         = argh::ParseValueSlotTy {
566                             slot: std::default::Default::default(),
567                             parse_func: |_, value| { #from_str_fn(value) },
568                         };
569                 }
570             }
571             FieldKind::SubCommand => {
572                 quote! { let mut #field_name: #field_slot_type = None; }
573             }
574             FieldKind::Switch => {
575                 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
576             }
577         }
578     })
579 }
580 
581 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a582 fn unwrap_from_args_fields<'a>(
583     fields: &'a [StructField<'a>],
584 ) -> impl Iterator<Item = TokenStream> + 'a {
585     fields.iter().map(|field| {
586         let field_name = field.name;
587         match field.kind {
588             FieldKind::Option | FieldKind::Positional => match &field.optionality {
589                 Optionality::None => quote! { #field_name: #field_name.slot.unwrap() },
590                 Optionality::Optional | Optionality::Repeating => {
591                     quote! { #field_name: #field_name.slot }
592                 }
593                 Optionality::Defaulted(tokens) => {
594                     quote! {
595                         #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
596                     }
597                 }
598             },
599             FieldKind::Switch => field_name.into_token_stream(),
600             FieldKind::SubCommand => match field.optionality {
601                 Optionality::None => quote! { #field_name: #field_name.unwrap() },
602                 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
603                 Optionality::Defaulted(_) => unreachable!(),
604             },
605         }
606     })
607 }
608 
609 /// Declare a local slots to store each field in during parsing.
610 ///
611 /// Most fields are stored in `Option<FieldType>` locals.
612 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
613 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a614 fn declare_local_storage_for_redacted_fields<'a>(
615     fields: &'a [StructField<'a>],
616 ) -> impl Iterator<Item = TokenStream> + 'a {
617     fields.iter().map(|field| {
618         let field_name = &field.field.ident;
619 
620         match field.kind {
621             FieldKind::Switch => {
622                 quote! {
623                     let mut #field_name = argh::RedactFlag {
624                         slot: None,
625                     };
626                 }
627             }
628             FieldKind::Option => {
629                 let field_slot_type = match field.optionality {
630                     Optionality::Repeating => {
631                         quote! { std::vec::Vec<String> }
632                     }
633                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
634                         quote! { std::option::Option<String> }
635                     }
636                 };
637 
638                 quote! {
639                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
640                         argh::ParseValueSlotTy {
641                         slot: std::default::Default::default(),
642                         parse_func: |arg, _| { Ok(arg.to_string()) },
643                     };
644                 }
645             }
646             FieldKind::Positional => {
647                 let field_slot_type = match field.optionality {
648                     Optionality::Repeating => {
649                         quote! { std::vec::Vec<String> }
650                     }
651                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
652                         quote! { std::option::Option<String> }
653                     }
654                 };
655 
656                 let arg_name = field.arg_name();
657                 quote! {
658                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
659                         argh::ParseValueSlotTy {
660                         slot: std::default::Default::default(),
661                         parse_func: |_, _| { Ok(#arg_name.to_string()) },
662                     };
663                 }
664             }
665             FieldKind::SubCommand => {
666                 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
667             }
668         }
669     })
670 }
671 
672 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a673 fn unwrap_redacted_fields<'a>(
674     fields: &'a [StructField<'a>],
675 ) -> impl Iterator<Item = TokenStream> + 'a {
676     fields.iter().map(|field| {
677         let field_name = field.name;
678 
679         match field.kind {
680             FieldKind::Switch => {
681                 quote! {
682                     if let Some(__field_name) = #field_name.slot {
683                         __redacted.push(__field_name);
684                     }
685                 }
686             }
687             FieldKind::Option => match field.optionality {
688                 Optionality::Repeating => {
689                     quote! {
690                         __redacted.extend(#field_name.slot.into_iter());
691                     }
692                 }
693                 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
694                     quote! {
695                         if let Some(__field_name) = #field_name.slot {
696                             __redacted.push(__field_name);
697                         }
698                     }
699                 }
700             },
701             FieldKind::Positional => {
702                 quote! {
703                     __redacted.extend(#field_name.slot.into_iter());
704                 }
705             }
706             FieldKind::SubCommand => {
707                 quote! {
708                     if let Some(__subcommand_args) = #field_name {
709                         __redacted.extend(__subcommand_args.into_iter());
710                     }
711                 }
712             }
713         }
714     })
715 }
716 
717 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
718 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>719 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
720     let mut flag_str_to_output_table_map = vec![];
721     for (i, (field, long_name)) in fields
722         .iter()
723         .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
724         .enumerate()
725     {
726         if let Some(short) = &field.attrs.short {
727             let short = format!("-{}", short.value());
728             flag_str_to_output_table_map.push(quote! { (#short, #i) });
729         }
730 
731         flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
732     }
733     flag_str_to_output_table_map
734 }
735 
736 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a737 fn append_missing_requirements<'a>(
738     // missing_requirements_ident
739     mri: &syn::Ident,
740     fields: &'a [StructField<'a>],
741 ) -> impl Iterator<Item = TokenStream> + 'a {
742     let mri = mri.clone();
743     fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
744         let field_name = field.name;
745         match field.kind {
746             FieldKind::Switch => unreachable!("switches are always optional"),
747             FieldKind::Positional => {
748                 let name = field.arg_name();
749                 quote! {
750                     if #field_name.slot.is_none() {
751                         #mri.missing_positional_arg(#name)
752                     }
753                 }
754             }
755             FieldKind::Option => {
756                 let name = field.long_name.as_ref().expect("options always have a long name");
757                 quote! {
758                     if #field_name.slot.is_none() {
759                         #mri.missing_option(#name)
760                     }
761                 }
762             }
763             FieldKind::SubCommand => {
764                 let ty = field.ty_without_wrapper;
765                 quote! {
766                     if #field_name.is_none() {
767                         #mri.missing_subcommands(
768                             <#ty as argh::SubCommands>::COMMANDS,
769                         )
770                     }
771                 }
772             }
773         }
774     })
775 }
776 
777 /// Require that a type can be a `switch`.
778 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool779 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
780     fn ty_can_be_switch(ty: &syn::Type) -> bool {
781         if let syn::Type::Path(path) = ty {
782             if path.qself.is_some() {
783                 return false;
784             }
785             if path.path.segments.len() != 1 {
786                 return false;
787             }
788             let ident = &path.path.segments[0].ident;
789             ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
790                 .iter()
791                 .any(|path| ident == path)
792         } else {
793             false
794         }
795     }
796 
797     let res = ty_can_be_switch(ty);
798     if !res {
799         errors.err(ty, "switches must be of type `bool` or integer type");
800     }
801     res
802 }
803 
804 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>805 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
806     if let syn::Type::Path(path) = ty {
807         if path.qself.is_some() {
808             return None;
809         }
810         // Since we only check the last path segment, it isn't necessarily the case that
811         // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
812         // a fool proof way to check these since name resolution happens after macro expansion,
813         // so this is likely "good enough" (so long as people don't have their own types called
814         // `Option` or `Vec` that take one generic parameter they're looking to parse).
815         let last_segment = path.path.segments.last()?;
816         if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
817             return None;
818         }
819         if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
820             let generic_arg = gen_args.args.first()?;
821             if let syn::GenericArgument::Type(ty) = &generic_arg {
822                 return Some(ty);
823             }
824         }
825     }
826     None
827 }
828 
829 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, de: &syn::DataEnum, ) -> TokenStream830 fn impl_from_args_enum(
831     errors: &Errors,
832     name: &syn::Ident,
833     type_attrs: &TypeAttrs,
834     de: &syn::DataEnum,
835 ) -> TokenStream {
836     parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
837 
838     // An enum variant like `<name>(<ty>)`
839     struct SubCommandVariant<'a> {
840         name: &'a syn::Ident,
841         ty: &'a syn::Type,
842     }
843 
844     let variants: Vec<SubCommandVariant<'_>> = de
845         .variants
846         .iter()
847         .filter_map(|variant| {
848             parse_attrs::check_enum_variant_attrs(errors, variant);
849             let name = &variant.ident;
850             let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
851             Some(SubCommandVariant { name, ty })
852         })
853         .collect();
854 
855     let name_repeating = std::iter::repeat(name.clone());
856     let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
857     let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
858 
859     quote! {
860         impl argh::FromArgs for #name {
861             fn from_args(command_name: &[&str], args: &[&str])
862                 -> std::result::Result<Self, argh::EarlyExit>
863             {
864                 let subcommand_name = *command_name.last().expect("no subcommand name");
865                 #(
866                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
867                         return Ok(#name_repeating::#variant_names(
868                             <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
869                         ));
870                     }
871                 )*
872                 unreachable!("no subcommand matched")
873             }
874 
875             fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
876                 let subcommand_name = *command_name.last().expect("no subcommand name");
877                 #(
878                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
879                         return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
880                     }
881                 )*
882                 unreachable!("no subcommand matched")
883             }
884         }
885 
886         impl argh::SubCommands for #name {
887             const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
888                 <#variant_ty as argh::SubCommand>::COMMAND,
889             )*];
890         }
891     }
892 }
893 
894 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
895 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>896 fn enum_only_single_field_unnamed_variants<'a>(
897     errors: &Errors,
898     variant_fields: &'a syn::Fields,
899 ) -> Option<&'a syn::Type> {
900     macro_rules! with_enum_suggestion {
901         ($help_text:literal) => {
902             concat!(
903                 $help_text,
904                 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
905                 "    enum MyCommandEnum {\n",
906                 "        SubCommandOne(SubCommandOne),\n",
907                 "        SubCommandTwo(SubCommandTwo),\n",
908                 "    }",
909             )
910         };
911     }
912 
913     match variant_fields {
914         syn::Fields::Named(fields) => {
915             errors.err(
916                 fields,
917                 with_enum_suggestion!(
918                     "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
919                 ),
920             );
921             None
922         }
923         syn::Fields::Unit => {
924             errors.err(
925                 variant_fields,
926                 with_enum_suggestion!(
927                     "`#![derive(FromArgs)]` does not support `enum`s with no variants."
928                 ),
929             );
930             None
931         }
932         syn::Fields::Unnamed(fields) => {
933             if fields.unnamed.len() != 1 {
934                 errors.err(
935                     fields,
936                     with_enum_suggestion!(
937                         "`#![derive(FromArgs)]` `enum` variants must only contain one field."
938                     ),
939                 );
940                 None
941             } else {
942                 // `unwrap` is okay because of the length check above.
943                 let first_field = fields.unnamed.first().unwrap();
944                 Some(&first_field.ty)
945             }
946         }
947     }
948 }
949