1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 /// Implementation of the `FromArgs` and `argh(...)` derive attributes.
7 ///
8 /// For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10
11 use {
12 crate::{
13 errors::Errors,
14 parse_attrs::{check_long_name, FieldAttrs, FieldKind, TypeAttrs},
15 },
16 proc_macro2::{Span, TokenStream},
17 quote::{quote, quote_spanned, ToTokens},
18 std::{collections::HashMap, str::FromStr},
19 syn::{spanned::Spanned, GenericArgument, LitStr, PathArguments, Type},
20 };
21
22 mod errors;
23 mod help;
24 mod parse_attrs;
25
26 /// Entrypoint for `#[derive(FromArgs)]`.
27 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream28 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
30 let gen = impl_from_args(&ast);
31 gen.into()
32 }
33
34 /// Transform the input into a token stream containing any generated implementations,
35 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream36 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
37 let errors = &Errors::default();
38 let type_attrs = &TypeAttrs::parse(errors, input);
39 let mut output_tokens = match &input.data {
40 syn::Data::Struct(ds) => {
41 impl_from_args_struct(errors, &input.ident, type_attrs, &input.generics, ds)
42 }
43 syn::Data::Enum(de) => {
44 impl_from_args_enum(errors, &input.ident, type_attrs, &input.generics, de)
45 }
46 syn::Data::Union(_) => {
47 errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
48 TokenStream::new()
49 }
50 };
51 errors.to_tokens(&mut output_tokens);
52 output_tokens
53 }
54
55 /// The kind of optionality a parameter has.
56 enum Optionality {
57 None,
58 Defaulted(TokenStream),
59 Optional,
60 Repeating,
61 }
62
63 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool64 fn eq(&self, other: &Optionality) -> bool {
65 use Optionality::*;
66 // NB: (Defaulted, Defaulted) can't contain the same token streams
67 matches!((self, other), (Optional, Optional) | (Repeating, Repeating))
68 }
69 }
70
71 impl Optionality {
72 /// Whether or not this is `Optionality::None`
is_required(&self) -> bool73 fn is_required(&self) -> bool {
74 matches!(self, Optionality::None)
75 }
76 }
77
78 /// A field of a `#![derive(FromArgs)]` struct with attributes and some other
79 /// notable metadata appended.
80 struct StructField<'a> {
81 /// The original parsed field
82 field: &'a syn::Field,
83 /// The parsed attributes of the field
84 attrs: FieldAttrs,
85 /// The field name. This is contained optionally inside `field`,
86 /// but is duplicated non-optionally here to indicate that all field that
87 /// have reached this point must have a field name, and it no longer
88 /// needs to be unwrapped.
89 name: &'a syn::Ident,
90 /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
91 /// but here is fully present to indicate that we only have to consider fields
92 /// with a valid `kind` at this point.
93 kind: FieldKind,
94 // If `field.ty` is `Vec<T>` or `Option<T>`, this is `T`, otherwise it's `&field.ty`.
95 // This is used to enable consistent parsing code between optional and non-optional
96 // keyed and subcommand fields.
97 ty_without_wrapper: &'a syn::Type,
98 // Whether the field represents an optional value, such as an `Option` subcommand field
99 // or an `Option` or `Vec` keyed argument, or if it has a `default`.
100 optionality: Optionality,
101 // The `--`-prefixed name of the option, if one exists.
102 long_name: Option<String>,
103 }
104
105 impl<'a> StructField<'a> {
106 /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
107 /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>108 fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
109 let name = field.ident.as_ref().expect("missing ident for named field");
110
111 // Ensure that one "kind" is present (switch, option, subcommand, positional)
112 let kind = if let Some(field_type) = &attrs.field_type {
113 field_type.kind
114 } else {
115 errors.err(
116 field,
117 concat!(
118 "Missing `argh` field kind attribute.\n",
119 "Expected one of: `switch`, `option`, `remaining`, `subcommand`, `positional`",
120 ),
121 );
122 return None;
123 };
124
125 // Parse out whether a field is optional (`Option` or `Vec`).
126 let optionality;
127 let ty_without_wrapper;
128 match kind {
129 FieldKind::Switch => {
130 if !ty_expect_switch(errors, &field.ty) {
131 return None;
132 }
133 optionality = Optionality::Optional;
134 ty_without_wrapper = &field.ty;
135 }
136 FieldKind::Option | FieldKind::Positional => {
137 if let Some(default) = &attrs.default {
138 let tokens = match TokenStream::from_str(&default.value()) {
139 Ok(tokens) => tokens,
140 Err(_) => {
141 errors.err(&default, "Invalid tokens: unable to lex `default` value");
142 return None;
143 }
144 };
145 // Set the span of the generated tokens to the string literal
146 let tokens: TokenStream = tokens
147 .into_iter()
148 .map(|mut tree| {
149 tree.set_span(default.span());
150 tree
151 })
152 .collect();
153 optionality = Optionality::Defaulted(tokens);
154 ty_without_wrapper = &field.ty;
155 } else {
156 let mut inner = None;
157 optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
158 inner = Some(x);
159 Optionality::Optional
160 } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
161 inner = Some(x);
162 Optionality::Repeating
163 } else {
164 Optionality::None
165 };
166 ty_without_wrapper = inner.unwrap_or(&field.ty);
167 }
168 }
169 FieldKind::SubCommand => {
170 let inner = ty_inner(&["Option"], &field.ty);
171 optionality =
172 if inner.is_some() { Optionality::Optional } else { Optionality::None };
173 ty_without_wrapper = inner.unwrap_or(&field.ty);
174 }
175 }
176
177 // Determine the "long" name of options and switches.
178 // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
179 let long_name = match kind {
180 FieldKind::Switch | FieldKind::Option => {
181 let long_name = attrs.long.as_ref().map(syn::LitStr::value).unwrap_or_else(|| {
182 let kebab_name = to_kebab_case(&name.to_string());
183 check_long_name(errors, name, &kebab_name);
184 kebab_name
185 });
186 if long_name == "help" {
187 errors.err(field, "Custom `--help` flags are not supported.");
188 }
189 let long_name = format!("--{}", long_name);
190 Some(long_name)
191 }
192 FieldKind::SubCommand | FieldKind::Positional => None,
193 };
194
195 Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
196 }
197
positional_arg_name(&self) -> String198 pub(crate) fn positional_arg_name(&self) -> String {
199 self.attrs
200 .arg_name
201 .as_ref()
202 .map(LitStr::value)
203 .unwrap_or_else(|| self.name.to_string().trim_matches('_').to_owned())
204 }
205 }
206
/// Converts a `snake_case` identifier into `kebab-case`.
///
/// Each run of underscores acts as a single separator. Leading and trailing underscores
/// are dropped, and letter case is preserved (`foo__Bar` becomes `foo-Bar`).
fn to_kebab_case(s: &str) -> String {
    s.split('_').filter(|segment| !segment.is_empty()).collect::<Vec<_>>().join("-")
}
218
#[test]
fn test_kebabs() {
    // (input, expected kebab-case output)
    let cases = [
        ("", ""),
        ("_", ""),
        ("foo", "foo"),
        ("__foo_", "foo"),
        ("foo_bar", "foo-bar"),
        ("foo__Bar", "foo-Bar"),
        ("foo_bar__baz_", "foo-bar-baz"),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(to_kebab_case(input).as_str(), expected, "kebab-casing {:?}", input);
    }
}
234
235 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ds: &syn::DataStruct, ) -> TokenStream236 fn impl_from_args_struct(
237 errors: &Errors,
238 name: &syn::Ident,
239 type_attrs: &TypeAttrs,
240 generic_args: &syn::Generics,
241 ds: &syn::DataStruct,
242 ) -> TokenStream {
243 let fields = match &ds.fields {
244 syn::Fields::Named(fields) => fields,
245 syn::Fields::Unnamed(_) => {
246 errors.err(
247 &ds.struct_token,
248 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
249 );
250 return TokenStream::new();
251 }
252 syn::Fields::Unit => {
253 errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
254 return TokenStream::new();
255 }
256 };
257
258 let fields: Vec<_> = fields
259 .named
260 .iter()
261 .filter_map(|field| {
262 let attrs = FieldAttrs::parse(errors, field);
263 StructField::new(errors, field, attrs)
264 })
265 .collect();
266
267 ensure_unique_names(errors, &fields);
268 ensure_only_last_positional_is_optional(errors, &fields);
269
270 let impl_span = Span::call_site();
271
272 let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
273
274 let redact_arg_values_method =
275 impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
276
277 let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs, generic_args);
278
279 let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
280 let trait_impl = quote_spanned! { impl_span =>
281 #[automatically_derived]
282 impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
283 #from_args_method
284
285 #redact_arg_values_method
286 }
287
288 #top_or_sub_cmd_impl
289 };
290
291 trait_impl
292 }
293
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream294 fn impl_from_args_struct_from_args<'a>(
295 errors: &Errors,
296 type_attrs: &TypeAttrs,
297 fields: &'a [StructField<'a>],
298 ) -> TokenStream {
299 let init_fields = declare_local_storage_for_from_args_fields(fields);
300 let unwrap_fields = unwrap_from_args_fields(fields);
301 let positional_fields: Vec<&StructField<'_>> =
302 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
303 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
304 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
305 let last_positional_is_repeating = positional_fields
306 .last()
307 .map(|field| field.optionality == Optionality::Repeating)
308 .unwrap_or(false);
309 let last_positional_is_greedy = positional_fields
310 .last()
311 .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
312 .unwrap_or(false);
313
314 let flag_output_table = fields.iter().filter_map(|field| {
315 let field_name = &field.field.ident;
316 match field.kind {
317 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
318 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
319 FieldKind::SubCommand | FieldKind::Positional => None,
320 }
321 });
322
323 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
324
325 let mut subcommands_iter =
326 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
327
328 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
329 for dup_subcommand in subcommands_iter {
330 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
331 }
332
333 let impl_span = Span::call_site();
334
335 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
336
337 let append_missing_requirements =
338 append_missing_requirements(&missing_requirements_ident, fields);
339
340 let parse_subcommands = if let Some(subcommand) = subcommand {
341 let name = subcommand.name;
342 let ty = subcommand.ty_without_wrapper;
343 quote_spanned! { impl_span =>
344 Some(argh::ParseStructSubCommand {
345 subcommands: <#ty as argh::SubCommands>::COMMANDS,
346 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
347 parse_func: &mut |__command, __remaining_args| {
348 #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
349 Ok(())
350 },
351 })
352 }
353 } else {
354 quote_spanned! { impl_span => None }
355 };
356
357 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
358 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
359 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
360
361 let method_impl = quote_spanned! { impl_span =>
362 fn from_args(__cmd_name: &[&str], __args: &[&str])
363 -> std::result::Result<Self, argh::EarlyExit>
364 {
365 #![allow(clippy::unwrap_in_result)]
366
367 #( #init_fields )*
368
369 argh::parse_struct_args(
370 __cmd_name,
371 __args,
372 argh::ParseStructOptions {
373 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
374 slots: &mut [ #( #flag_output_table, )* ],
375 },
376 argh::ParseStructPositionals {
377 positionals: &mut [
378 #(
379 argh::ParseStructPositional {
380 name: #positional_field_names,
381 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
382 },
383 )*
384 ],
385 last_is_repeating: #last_positional_is_repeating,
386 last_is_greedy: #last_positional_is_greedy,
387 },
388 #parse_subcommands,
389 &|| #help,
390 )?;
391
392 let mut #missing_requirements_ident = argh::MissingRequirements::default();
393 #(
394 #append_missing_requirements
395 )*
396 #missing_requirements_ident.err_on_any()?;
397
398 Ok(Self {
399 #( #unwrap_fields, )*
400 })
401 }
402 };
403
404 method_impl
405 }
406
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream407 fn impl_from_args_struct_redact_arg_values<'a>(
408 errors: &Errors,
409 type_attrs: &TypeAttrs,
410 fields: &'a [StructField<'a>],
411 ) -> TokenStream {
412 let init_fields = declare_local_storage_for_redacted_fields(fields);
413 let unwrap_fields = unwrap_redacted_fields(fields);
414
415 let positional_fields: Vec<&StructField<'_>> =
416 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
417 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
418 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
419 let last_positional_is_repeating = positional_fields
420 .last()
421 .map(|field| field.optionality == Optionality::Repeating)
422 .unwrap_or(false);
423 let last_positional_is_greedy = positional_fields
424 .last()
425 .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
426 .unwrap_or(false);
427
428 let flag_output_table = fields.iter().filter_map(|field| {
429 let field_name = &field.field.ident;
430 match field.kind {
431 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
432 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
433 FieldKind::SubCommand | FieldKind::Positional => None,
434 }
435 });
436
437 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
438
439 let mut subcommands_iter =
440 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
441
442 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
443 for dup_subcommand in subcommands_iter {
444 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
445 }
446
447 let impl_span = Span::call_site();
448
449 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
450
451 let append_missing_requirements =
452 append_missing_requirements(&missing_requirements_ident, fields);
453
454 let redact_subcommands = if let Some(subcommand) = subcommand {
455 let name = subcommand.name;
456 let ty = subcommand.ty_without_wrapper;
457 quote_spanned! { impl_span =>
458 Some(argh::ParseStructSubCommand {
459 subcommands: <#ty as argh::SubCommands>::COMMANDS,
460 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
461 parse_func: &mut |__command, __remaining_args| {
462 #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
463 Ok(())
464 },
465 })
466 }
467 } else {
468 quote_spanned! { impl_span => None }
469 };
470
471 let unwrap_cmd_name_err_string = if type_attrs.is_subcommand.is_none() {
472 quote! { "no command name" }
473 } else {
474 quote! { "no subcommand name" }
475 };
476
477 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
478 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
479 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
480
481 let method_impl = quote_spanned! { impl_span =>
482 fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
483 #( #init_fields )*
484
485 argh::parse_struct_args(
486 __cmd_name,
487 __args,
488 argh::ParseStructOptions {
489 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
490 slots: &mut [ #( #flag_output_table, )* ],
491 },
492 argh::ParseStructPositionals {
493 positionals: &mut [
494 #(
495 argh::ParseStructPositional {
496 name: #positional_field_names,
497 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
498 },
499 )*
500 ],
501 last_is_repeating: #last_positional_is_repeating,
502 last_is_greedy: #last_positional_is_greedy,
503 },
504 #redact_subcommands,
505 &|| #help,
506 )?;
507
508 let mut #missing_requirements_ident = argh::MissingRequirements::default();
509 #(
510 #append_missing_requirements
511 )*
512 #missing_requirements_ident.err_on_any()?;
513
514 let mut __redacted = vec![
515 if let Some(cmd_name) = __cmd_name.last() {
516 (*cmd_name).to_owned()
517 } else {
518 return Err(argh::EarlyExit::from(#unwrap_cmd_name_err_string.to_owned()));
519 }
520 ];
521
522 #( #unwrap_fields )*
523
524 Ok(__redacted)
525 }
526 };
527
528 method_impl
529 }
530
531 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])532 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
533 let mut first_non_required_span = None;
534 for field in fields {
535 if field.kind == FieldKind::Positional {
536 if let Some(first) = first_non_required_span {
537 errors.err_span(
538 first,
539 "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
540 );
541 errors.err(&field.field, "Later positional argument declared here.");
542 return;
543 }
544 if !field.optionality.is_required() {
545 first_non_required_span = Some(field.field.span());
546 }
547 }
548 }
549 }
550
551 /// Ensures that only one short or long name is used.
ensure_unique_names(errors: &Errors, fields: &[StructField<'_>])552 fn ensure_unique_names(errors: &Errors, fields: &[StructField<'_>]) {
553 let mut seen_short_names = HashMap::new();
554 let mut seen_long_names = HashMap::new();
555
556 for field in fields {
557 if let Some(short_name) = &field.attrs.short {
558 let short_name = short_name.value();
559 if let Some(first_use_field) = seen_short_names.get(&short_name) {
560 errors.err_span_tokens(
561 first_use_field,
562 &format!("The short name of \"-{}\" was already used here.", short_name),
563 );
564 errors.err_span_tokens(field.field, "Later usage here.");
565 }
566
567 seen_short_names.insert(short_name, &field.field);
568 }
569
570 if let Some(long_name) = &field.long_name {
571 if let Some(first_use_field) = seen_long_names.get(&long_name) {
572 errors.err_span_tokens(
573 *first_use_field,
574 &format!("The long name of \"{}\" was already used here.", long_name),
575 );
576 errors.err_span_tokens(field.field, "Later usage here.");
577 }
578
579 seen_long_names.insert(long_name, field.field);
580 }
581 }
582 }
583
584 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ) -> TokenStream585 fn top_or_sub_cmd_impl(
586 errors: &Errors,
587 name: &syn::Ident,
588 type_attrs: &TypeAttrs,
589 generic_args: &syn::Generics,
590 ) -> TokenStream {
591 let description =
592 help::require_description(errors, name.span(), &type_attrs.description, "type");
593 let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
594 if type_attrs.is_subcommand.is_none() {
595 // Not a subcommand
596 quote! {
597 #[automatically_derived]
598 impl #impl_generics argh::TopLevelCommand for #name #ty_generics #where_clause {}
599 }
600 } else {
601 let empty_str = syn::LitStr::new("", Span::call_site());
602 let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
603 errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
604 &empty_str
605 });
606 quote! {
607 #[automatically_derived]
608 impl #impl_generics argh::SubCommand for #name #ty_generics #where_clause {
609 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
610 name: #subcommand_name,
611 description: #description,
612 };
613 }
614 }
615 }
616 }
617
618 /// Declare a local slots to store each field in during parsing.
619 ///
620 /// Most fields are stored in `Option<FieldType>` locals.
621 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
622 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a623 fn declare_local_storage_for_from_args_fields<'a>(
624 fields: &'a [StructField<'a>],
625 ) -> impl Iterator<Item = TokenStream> + 'a {
626 fields.iter().map(|field| {
627 let field_name = &field.field.ident;
628 let field_type = &field.ty_without_wrapper;
629
630 // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
631 let field_slot_type = match field.optionality {
632 Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
633 Optionality::None | Optionality::Defaulted(_) => {
634 quote! { std::option::Option<#field_type> }
635 }
636 };
637
638 match field.kind {
639 FieldKind::Option | FieldKind::Positional => {
640 let from_str_fn = match &field.attrs.from_str_fn {
641 Some(from_str_fn) => from_str_fn.into_token_stream(),
642 None => {
643 quote! {
644 <#field_type as argh::FromArgValue>::from_arg_value
645 }
646 }
647 };
648
649 quote! {
650 let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
651 = argh::ParseValueSlotTy {
652 slot: std::default::Default::default(),
653 parse_func: |_, value| { #from_str_fn(value) },
654 };
655 }
656 }
657 FieldKind::SubCommand => {
658 quote! { let mut #field_name: #field_slot_type = None; }
659 }
660 FieldKind::Switch => {
661 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
662 }
663 }
664 })
665 }
666
667 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a668 fn unwrap_from_args_fields<'a>(
669 fields: &'a [StructField<'a>],
670 ) -> impl Iterator<Item = TokenStream> + 'a {
671 fields.iter().map(|field| {
672 let field_name = field.name;
673 match field.kind {
674 FieldKind::Option | FieldKind::Positional => match &field.optionality {
675 Optionality::None => quote! {
676 #field_name: #field_name.slot.unwrap()
677 },
678 Optionality::Optional | Optionality::Repeating => {
679 quote! { #field_name: #field_name.slot }
680 }
681 Optionality::Defaulted(tokens) => {
682 quote! {
683 #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
684 }
685 }
686 },
687 FieldKind::Switch => field_name.into_token_stream(),
688 FieldKind::SubCommand => match field.optionality {
689 Optionality::None => quote! { #field_name: #field_name.unwrap() },
690 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
691 Optionality::Defaulted(_) => unreachable!(),
692 },
693 }
694 })
695 }
696
697 /// Declare a local slots to store each field in during parsing.
698 ///
699 /// Most fields are stored in `Option<FieldType>` locals.
700 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
701 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a702 fn declare_local_storage_for_redacted_fields<'a>(
703 fields: &'a [StructField<'a>],
704 ) -> impl Iterator<Item = TokenStream> + 'a {
705 fields.iter().map(|field| {
706 let field_name = &field.field.ident;
707
708 match field.kind {
709 FieldKind::Switch => {
710 quote! {
711 let mut #field_name = argh::RedactFlag {
712 slot: None,
713 };
714 }
715 }
716 FieldKind::Option => {
717 let field_slot_type = match field.optionality {
718 Optionality::Repeating => {
719 quote! { std::vec::Vec<String> }
720 }
721 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
722 quote! { std::option::Option<String> }
723 }
724 };
725
726 quote! {
727 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
728 argh::ParseValueSlotTy {
729 slot: std::default::Default::default(),
730 parse_func: |arg, _| { Ok(arg.to_owned()) },
731 };
732 }
733 }
734 FieldKind::Positional => {
735 let field_slot_type = match field.optionality {
736 Optionality::Repeating => {
737 quote! { std::vec::Vec<String> }
738 }
739 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
740 quote! { std::option::Option<String> }
741 }
742 };
743
744 let arg_name = field.positional_arg_name();
745 quote! {
746 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
747 argh::ParseValueSlotTy {
748 slot: std::default::Default::default(),
749 parse_func: |_, _| { Ok(#arg_name.to_owned()) },
750 };
751 }
752 }
753 FieldKind::SubCommand => {
754 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
755 }
756 }
757 })
758 }
759
760 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a761 fn unwrap_redacted_fields<'a>(
762 fields: &'a [StructField<'a>],
763 ) -> impl Iterator<Item = TokenStream> + 'a {
764 fields.iter().map(|field| {
765 let field_name = field.name;
766
767 match field.kind {
768 FieldKind::Switch => {
769 quote! {
770 if let Some(__field_name) = #field_name.slot {
771 __redacted.push(__field_name);
772 }
773 }
774 }
775 FieldKind::Option => match field.optionality {
776 Optionality::Repeating => {
777 quote! {
778 __redacted.extend(#field_name.slot.into_iter());
779 }
780 }
781 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
782 quote! {
783 if let Some(__field_name) = #field_name.slot {
784 __redacted.push(__field_name);
785 }
786 }
787 }
788 },
789 FieldKind::Positional => {
790 quote! {
791 __redacted.extend(#field_name.slot.into_iter());
792 }
793 }
794 FieldKind::SubCommand => {
795 quote! {
796 if let Some(__subcommand_args) = #field_name {
797 __redacted.extend(__subcommand_args.into_iter());
798 }
799 }
800 }
801 }
802 })
803 }
804
805 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
806 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>807 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
808 let mut flag_str_to_output_table_map = vec![];
809 for (i, (field, long_name)) in fields
810 .iter()
811 .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
812 .enumerate()
813 {
814 if let Some(short) = &field.attrs.short {
815 let short = format!("-{}", short.value());
816 flag_str_to_output_table_map.push(quote! { (#short, #i) });
817 }
818
819 flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
820 }
821 flag_str_to_output_table_map
822 }
823
824 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a825 fn append_missing_requirements<'a>(
826 // missing_requirements_ident
827 mri: &syn::Ident,
828 fields: &'a [StructField<'a>],
829 ) -> impl Iterator<Item = TokenStream> + 'a {
830 let mri = mri.clone();
831 fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
832 let field_name = field.name;
833 match field.kind {
834 FieldKind::Switch => unreachable!("switches are always optional"),
835 FieldKind::Positional => {
836 let name = field.positional_arg_name();
837 quote! {
838 if #field_name.slot.is_none() {
839 #mri.missing_positional_arg(#name)
840 }
841 }
842 }
843 FieldKind::Option => {
844 let name = field.long_name.as_ref().expect("options always have a long name");
845 quote! {
846 if #field_name.slot.is_none() {
847 #mri.missing_option(#name)
848 }
849 }
850 }
851 FieldKind::SubCommand => {
852 let ty = field.ty_without_wrapper;
853 quote! {
854 if #field_name.is_none() {
855 #mri.missing_subcommands(
856 <#ty as argh::SubCommands>::COMMANDS
857 .iter()
858 .cloned()
859 .chain(
860 <#ty as argh::SubCommands>::dynamic_commands()
861 .iter()
862 .copied()
863 ),
864 )
865 }
866 }
867 }
868 }
869 })
870 }
871
872 /// Require that a type can be a `switch`.
873 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool874 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
875 fn ty_can_be_switch(ty: &syn::Type) -> bool {
876 if let syn::Type::Path(path) = ty {
877 if path.qself.is_some() {
878 return false;
879 }
880 if path.path.segments.len() != 1 {
881 return false;
882 }
883 let ident = &path.path.segments[0].ident;
884 // `Option<bool>` can be used as a `switch`.
885 if ident == "Option" {
886 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
887 if let GenericArgument::Type(Type::Path(p)) = &args.args[0] {
888 if p.path.segments[0].ident == "bool" {
889 return true;
890 }
891 }
892 }
893 }
894 ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
895 .iter()
896 .any(|path| ident == path)
897 } else {
898 false
899 }
900 }
901
902 let res = ty_can_be_switch(ty);
903 if !res {
904 errors.err(ty, "switches must be of type `bool`, `Option<bool>`, or integer type");
905 }
906 res
907 }
908
909 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>910 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
911 if let syn::Type::Path(path) = ty {
912 if path.qself.is_some() {
913 return None;
914 }
915 // Since we only check the last path segment, it isn't necessarily the case that
916 // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
917 // a fool proof way to check these since name resolution happens after macro expansion,
918 // so this is likely "good enough" (so long as people don't have their own types called
919 // `Option` or `Vec` that take one generic parameter they're looking to parse).
920 let last_segment = path.path.segments.last()?;
921 if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
922 return None;
923 }
924 if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
925 let generic_arg = gen_args.args.first()?;
926 if let syn::GenericArgument::Type(ty) = &generic_arg {
927 return Some(ty);
928 }
929 }
930 }
931 None
932 }
933
934 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, de: &syn::DataEnum, ) -> TokenStream935 fn impl_from_args_enum(
936 errors: &Errors,
937 name: &syn::Ident,
938 type_attrs: &TypeAttrs,
939 generic_args: &syn::Generics,
940 de: &syn::DataEnum,
941 ) -> TokenStream {
942 parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
943
944 // An enum variant like `<name>(<ty>)`
945 struct SubCommandVariant<'a> {
946 name: &'a syn::Ident,
947 ty: &'a syn::Type,
948 }
949
950 let mut dynamic_type_and_variant = None;
951
952 let variants: Vec<SubCommandVariant<'_>> = de
953 .variants
954 .iter()
955 .filter_map(|variant| {
956 let name = &variant.ident;
957 let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
958 if parse_attrs::VariantAttrs::parse(errors, variant).is_dynamic.is_some() {
959 if dynamic_type_and_variant.is_some() {
960 errors.err(variant, "Only one variant can have the `dynamic` attribute");
961 }
962 dynamic_type_and_variant = Some((ty, name));
963 None
964 } else {
965 Some(SubCommandVariant { name, ty })
966 }
967 })
968 .collect();
969
970 let name_repeating = std::iter::repeat(name.clone());
971 let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
972 let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
973 let dynamic_from_args =
974 dynamic_type_and_variant.as_ref().map(|(dynamic_type, dynamic_variant)| {
975 quote! {
976 if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_from_args(
977 command_name, args) {
978 return result.map(#name::#dynamic_variant);
979 }
980 }
981 });
982 let dynamic_redact_arg_values = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
983 quote! {
984 if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_redact_arg_values(
985 command_name, args) {
986 return result;
987 }
988 }
989 });
990 let dynamic_commands = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
991 quote! {
992 fn dynamic_commands() -> &'static [&'static argh::CommandInfo] {
993 <#dynamic_type as argh::DynamicSubCommand>::commands()
994 }
995 }
996 });
997
998 let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
999 quote! {
1000 impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
1001 fn from_args(command_name: &[&str], args: &[&str])
1002 -> std::result::Result<Self, argh::EarlyExit>
1003 {
1004 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1005 *subcommand_name
1006 } else {
1007 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1008 };
1009
1010 #(
1011 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1012 return Ok(#name_repeating::#variant_names(
1013 <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
1014 ));
1015 }
1016 )*
1017
1018 #dynamic_from_args
1019
1020 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1021 }
1022
1023 fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
1024 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1025 *subcommand_name
1026 } else {
1027 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1028 };
1029
1030 #(
1031 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1032 return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
1033 }
1034 )*
1035
1036 #dynamic_redact_arg_values
1037
1038 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1039 }
1040 }
1041
1042 impl #impl_generics argh::SubCommands for #name #ty_generics #where_clause {
1043 const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
1044 <#variant_ty as argh::SubCommand>::COMMAND,
1045 )*];
1046
1047 #dynamic_commands
1048 }
1049 }
1050 }
1051
1052 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
1053 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>1054 fn enum_only_single_field_unnamed_variants<'a>(
1055 errors: &Errors,
1056 variant_fields: &'a syn::Fields,
1057 ) -> Option<&'a syn::Type> {
1058 macro_rules! with_enum_suggestion {
1059 ($help_text:literal) => {
1060 concat!(
1061 $help_text,
1062 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
1063 " enum MyCommandEnum {\n",
1064 " SubCommandOne(SubCommandOne),\n",
1065 " SubCommandTwo(SubCommandTwo),\n",
1066 " }",
1067 )
1068 };
1069 }
1070
1071 match variant_fields {
1072 syn::Fields::Named(fields) => {
1073 errors.err(
1074 fields,
1075 with_enum_suggestion!(
1076 "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
1077 ),
1078 );
1079 None
1080 }
1081 syn::Fields::Unit => {
1082 errors.err(
1083 variant_fields,
1084 with_enum_suggestion!(
1085 "`#![derive(FromArgs)]` does not support `enum`s with no variants."
1086 ),
1087 );
1088 None
1089 }
1090 syn::Fields::Unnamed(fields) => {
1091 if fields.unnamed.len() != 1 {
1092 errors.err(
1093 fields,
1094 with_enum_suggestion!(
1095 "`#![derive(FromArgs)]` `enum` variants must only contain one field."
1096 ),
1097 );
1098 None
1099 } else {
1100 // `unwrap` is okay because of the length check above.
1101 let first_field = fields.unnamed.first().unwrap();
1102 Some(&first_field.ty)
1103 }
1104 }
1105 }
1106 }
1107