1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use quote::quote;
6
7 /// A helper derive proc macro to flatten multiple subcommand enums into one
8 /// Note that it is unable to check for duplicate commands and they will be
9 /// tried in order of declaration
10 #[proc_macro_derive(FlattenSubcommand)]
flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream11 pub fn flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
13 let de = match ast.data {
14 syn::Data::Enum(v) => v,
15 _ => unreachable!(),
16 };
17 let name = &ast.ident;
18
19 // An enum variant like `<name>(<ty>)`
20 struct SubCommandVariant<'a> {
21 name: &'a syn::Ident,
22 ty: &'a syn::Type,
23 }
24
25 let variants: Vec<SubCommandVariant<'_>> = de
26 .variants
27 .iter()
28 .map(|variant| {
29 let name = &variant.ident;
30 let ty = match &variant.fields {
31 syn::Fields::Unnamed(field) => {
32 if field.unnamed.len() != 1 {
33 unreachable!()
34 }
35
36 &field.unnamed.first().unwrap().ty
37 }
38 _ => unreachable!(),
39 };
40 SubCommandVariant { name, ty }
41 })
42 .collect();
43
44 let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
45 let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
46
47 (quote! {
48 impl argh::FromArgs for #name {
49 fn from_args(command_name: &[&str], args: &[&str])
50 -> std::result::Result<Self, argh::EarlyExit>
51 {
52 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
53 *subcommand_name
54 } else {
55 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
56 };
57
58 #(
59 if <#variant_ty as argh::SubCommands>::COMMANDS
60 .iter()
61 .find(|ci| ci.name.eq(subcommand_name))
62 .is_some()
63 {
64 return <#variant_ty as argh::FromArgs>::from_args(command_name, args)
65 .map(|v| Self::#variant_names(v));
66 }
67 )*
68
69 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
70 }
71
72 fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
73 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
74 *subcommand_name
75 } else {
76 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
77 };
78
79 #(
80 if <#variant_ty as argh::SubCommands>::COMMANDS
81 .iter()
82 .find(|ci| ci.name.eq(subcommand_name))
83 .is_some()
84 {
85 return <#variant_ty as argh::FromArgs>::redact_arg_values(
86 command_name,
87 args,
88 );
89 }
90
91 )*
92
93 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
94 }
95 }
96
97 impl argh::SubCommands for #name {
98 const COMMANDS: &'static [&'static argh::CommandInfo] = {
99 const TOTAL_LEN: usize = #(<#variant_ty as argh::SubCommands>::COMMANDS.len())+*;
100 const COMMANDS: [&'static argh::CommandInfo; TOTAL_LEN] = {
101 let slices = &[#(<#variant_ty as argh::SubCommands>::COMMANDS,)*];
102 // Its not possible for slices[0][0] to be invalid
103 let mut output = [slices[0][0]; TOTAL_LEN];
104
105 let mut output_index = 0;
106 let mut which_slice = 0;
107 while which_slice < slices.len() {
108 let slice = &slices[which_slice];
109 let mut index_in_slice = 0;
110 while index_in_slice < slice.len() {
111 output[output_index] = slice[index_in_slice];
112 output_index += 1;
113 index_in_slice += 1;
114 }
115 which_slice += 1;
116 }
117 output
118 };
119 &COMMANDS
120 };
121 }
122 })
123 .into()
124 }
125
126 /// A helper proc macro to pad strings so that argh would break them at intended points
127 #[proc_macro_attribute]
pad_description_for_argh( _attr: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream128 pub fn pad_description_for_argh(
129 _attr: proc_macro::TokenStream,
130 item: proc_macro::TokenStream,
131 ) -> proc_macro::TokenStream {
132 let mut item = syn::parse_macro_input!(item as syn::Item);
133 if let syn::Item::Struct(s) = &mut item {
134 if let syn::Fields::Named(fields) = &mut s.fields {
135 for f in fields.named.iter_mut() {
136 for a in f.attrs.iter_mut() {
137 if a.path
138 .get_ident()
139 .map(|i| i.to_string())
140 .unwrap_or_default()
141 == *"doc"
142 {
143 if let Ok(syn::Meta::NameValue(nv)) = a.parse_meta() {
144 if let syn::Lit::Str(s) = nv.lit {
145 let doc = s
146 .value()
147 .lines()
148 .map(|s| format!("{: <61}", s))
149 .collect::<String>();
150 *a = syn::parse_quote! { #[doc= #doc] };
151 }
152 }
153 }
154 }
155 }
156 } else {
157 unreachable!()
158 }
159 } else {
160 unreachable!()
161 }
162 quote! {
163 #item
164 }
165 .into()
166 }
167