• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use quote::quote;
6 
7 /// A helper derive proc macro to flatten multiple subcommand enums into one
8 /// Note that it is unable to check for duplicate commands and they will be
9 /// tried in order of declaration
10 #[proc_macro_derive(FlattenSubcommand)]
flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream11 pub fn flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
13     let de = match ast.data {
14         syn::Data::Enum(v) => v,
15         _ => unreachable!(),
16     };
17     let name = &ast.ident;
18 
19     // An enum variant like `<name>(<ty>)`
20     struct SubCommandVariant<'a> {
21         name: &'a syn::Ident,
22         ty: &'a syn::Type,
23     }
24 
25     let variants: Vec<SubCommandVariant<'_>> = de
26         .variants
27         .iter()
28         .map(|variant| {
29             let name = &variant.ident;
30             let ty = match &variant.fields {
31                 syn::Fields::Unnamed(field) => {
32                     if field.unnamed.len() != 1 {
33                         unreachable!()
34                     }
35 
36                     &field.unnamed.first().unwrap().ty
37                 }
38                 _ => unreachable!(),
39             };
40             SubCommandVariant { name, ty }
41         })
42         .collect();
43 
44     let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
45     let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
46 
47     (quote! {
48         impl argh::FromArgs for #name {
49             fn from_args(command_name: &[&str], args: &[&str])
50                 -> std::result::Result<Self, argh::EarlyExit>
51             {
52                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
53                     *subcommand_name
54                 } else {
55                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
56                 };
57 
58                 #(
59                     if <#variant_ty as argh::SubCommands>::COMMANDS
60                     .iter()
61                     .find(|ci| ci.name.eq(subcommand_name))
62                     .is_some()
63                     {
64                         return <#variant_ty as argh::FromArgs>::from_args(command_name, args)
65                             .map(|v| Self::#variant_names(v));
66                     }
67                 )*
68 
69                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
70             }
71 
72             fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
73                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
74                     *subcommand_name
75                 } else {
76                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
77                 };
78 
79                 #(
80                     if <#variant_ty as argh::SubCommands>::COMMANDS
81                     .iter()
82                     .find(|ci| ci.name.eq(subcommand_name))
83                     .is_some()
84                     {
85                         return <#variant_ty as argh::FromArgs>::redact_arg_values(
86                             command_name,
87                             args,
88                         );
89                     }
90 
91                 )*
92 
93                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
94             }
95         }
96 
97         impl argh::SubCommands for #name {
98             const COMMANDS: &'static [&'static argh::CommandInfo] = {
99                 const TOTAL_LEN: usize = #(<#variant_ty as argh::SubCommands>::COMMANDS.len())+*;
100                 const COMMANDS: [&'static argh::CommandInfo; TOTAL_LEN] = {
101                     let slices = &[#(<#variant_ty as argh::SubCommands>::COMMANDS,)*];
102                     // Its not possible for slices[0][0] to be invalid
103                     let mut output = [slices[0][0]; TOTAL_LEN];
104 
105                     let mut output_index = 0;
106                     let mut which_slice = 0;
107                     while which_slice < slices.len() {
108                         let slice = &slices[which_slice];
109                         let mut index_in_slice = 0;
110                         while index_in_slice < slice.len() {
111                             output[output_index] = slice[index_in_slice];
112                             output_index += 1;
113                             index_in_slice += 1;
114                         }
115                         which_slice += 1;
116                     }
117                     output
118                 };
119                 &COMMANDS
120             };
121         }
122     })
123     .into()
124 }
125 
126 /// A helper proc macro to pad strings so that argh would break them at intended points
127 #[proc_macro_attribute]
pad_description_for_argh( _attr: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream128 pub fn pad_description_for_argh(
129     _attr: proc_macro::TokenStream,
130     item: proc_macro::TokenStream,
131 ) -> proc_macro::TokenStream {
132     let mut item = syn::parse_macro_input!(item as syn::Item);
133     if let syn::Item::Struct(s) = &mut item {
134         if let syn::Fields::Named(fields) = &mut s.fields {
135             for f in fields.named.iter_mut() {
136                 for a in f.attrs.iter_mut() {
137                     if a.path
138                         .get_ident()
139                         .map(|i| i.to_string())
140                         .unwrap_or_default()
141                         == *"doc"
142                     {
143                         if let Ok(syn::Meta::NameValue(nv)) = a.parse_meta() {
144                             if let syn::Lit::Str(s) = nv.lit {
145                                 let doc = s
146                                     .value()
147                                     .lines()
148                                     .map(|s| format!("{: <61}", s))
149                                     .collect::<String>();
150                                 *a = syn::parse_quote! { #[doc= #doc] };
151                             }
152                         }
153                     }
154                 }
155             }
156         } else {
157             unreachable!()
158         }
159     } else {
160         unreachable!()
161     }
162     quote! {
163         #item
164     }
165     .into()
166 }
167