• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 extern crate proc_macro;
16 
17 mod config;
18 mod discriminant;
19 mod repr;
20 
21 use config::Config;
22 
23 use discriminant::Discriminant;
24 use proc_macro2::{Span, TokenStream};
25 use quote::{format_ident, quote, ToTokens};
26 use repr::Repr;
27 use std::collections::HashSet;
28 use syn::Attribute;
29 use syn::{
30     parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility,
31 };
32 
33 /// Sets the span for every token tree in the token stream
set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream34 fn set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream {
35     tokens
36         .into_iter()
37         .map(|mut tt| {
38             tt.set_span(span);
39             tt
40         })
41         .collect()
42 }
43 
44 /// Checks that there are no duplicate discriminant values. If all variants are literals, return an `Err` so we can have
45 /// more clear error messages. Otherwise, emit a static check that ensures no duplicates.
check_no_alias<'a>( enum_: &ItemEnum, variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone, ) -> syn::Result<TokenStream>46 fn check_no_alias<'a>(
47     enum_: &ItemEnum,
48     variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone,
49 ) -> syn::Result<TokenStream> {
50     // If they're all literals, we can give better error messages by checking at proc macro time.
51     let mut values: HashSet<i128> = HashSet::new();
52     for (_, variant, span) in variants {
53         if let &Discriminant::Literal(value) = variant {
54             if !values.insert(value) {
55                 return Err(Error::new(
56                     span,
57                     format!("discriminant value `{value}` assigned more than once"),
58                 ));
59             }
60         } else {
61             let mut checking_enum = syn::ItemEnum {
62                 ident: format_ident!("_Check{}", enum_.ident),
63                 vis: Visibility::Inherited,
64                 ..enum_.clone()
65             };
66             checking_enum.attrs.retain(|attr| {
67                 matches!(
68                     attr.path().to_token_stream().to_string().as_str(),
69                     "repr" | "allow" | "warn" | "deny" | "forbid"
70                 )
71             });
72             return Ok(quote!(
73                 #[allow(dead_code)]
74                 #checking_enum
75             ));
76         }
77     }
78     Ok(TokenStream::default())
79 }
80 
emit_debug_impl<'a>( ident: &Ident, variants: impl Iterator<Item = &'a Ident> + Clone, attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone, ) -> TokenStream81 fn emit_debug_impl<'a>(
82     ident: &Ident,
83     variants: impl Iterator<Item = &'a Ident> + Clone,
84     attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone,
85 ) -> TokenStream {
86     let attrs = attrs.map(|attrs| {
87         // Only allow "#[cfg(...)]" attributes
88         let iter = attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
89         quote!(#(#iter)*)
90     });
91     quote!(impl ::core::fmt::Debug for #ident {
92         fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
93             #![allow(unreachable_patterns)]
94             let s = match *self {
95                 #( #attrs Self::#variants => stringify!(#variants), )*
96                 _ => {
97                     return fmt.debug_tuple(stringify!(#ident)).field(&self.0).finish();
98                 }
99             };
100             fmt.pad(s)
101         }
102     })
103 }
104 
path_matches_prelude_derive( got_path: &syn::Path, expected_path_after_std: &[&'static str], ) -> bool105 fn path_matches_prelude_derive(
106     got_path: &syn::Path,
107     expected_path_after_std: &[&'static str],
108 ) -> bool {
109     let &[a, b] = expected_path_after_std else {
110         unimplemented!("checking against stdlib paths with != 2 parts");
111     };
112     let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect();
113     if segments
114         .iter()
115         .any(|segment| !matches!(segment.arguments, syn::PathArguments::None))
116     {
117         return false;
118     }
119     match &segments[..] {
120         // `core::fmt::Debug` or `some_crate::module::Name`
121         [maybe_core_or_std, maybe_a, maybe_b] => {
122             (maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std")
123                 && maybe_a.ident == a
124                 && maybe_b.ident == b
125         }
126         // `fmt::Debug` or `module::Name`
127         [maybe_a, maybe_b] => {
128             maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none()
129         }
130         // `Debug` or `Name``
131         [maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(),
132         _ => false,
133     }
134 }
135 
136 fn open_enum_impl(
137     enum_: ItemEnum,
138     Config {
139         allow_alias,
140         repr_visibility,
141     }: Config,
142 ) -> Result<TokenStream, Error> {
143     // Does the enum define a `#[repr()]`?
144     let mut struct_attrs: Vec<TokenStream> = Vec::with_capacity(enum_.attrs.len() + 5);
145     struct_attrs.push(quote!(#[allow(clippy::exhaustive_structs)]));
146 
147     if !enum_.generics.params.is_empty() {
148         return Err(Error::new(enum_.generics.span(), "enum cannot be generic"));
149     }
150     let mut variants = Vec::with_capacity(enum_.variants.len());
151     let mut last_field = Discriminant::Literal(-1);
152     for variant in &enum_.variants {
153         if !matches!(variant.fields, syn::Fields::Unit) {
154             return Err(Error::new(variant.span(), "enum cannot contain fields"));
155         }
156 
157         let (value, value_span) = if let Some((_, discriminant)) = &variant.discriminant {
158             let span = discriminant.span();
159             (Discriminant::new(discriminant.clone())?, span)
160         } else {
161             last_field = last_field
162                 .next_value()
163                 .ok_or_else(|| Error::new(variant.span(), "enum discriminant overflowed"))?;
164             (last_field.clone(), variant.ident.span())
165         };
166         last_field = value.clone();
167         variants.push((&variant.ident, value, value_span, &variant.attrs))
168     }
169 
170     let mut impl_attrs: Vec<TokenStream> = vec![quote!(#[allow(non_upper_case_globals)])];
171     let mut explicit_repr: Option<Repr> = None;
172 
173     // To make `match` seamless, derive(PartialEq, Eq) if they aren't already.
174     let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)];
175 
176     let mut make_custom_debug_impl = false;
177     for attr in &enum_.attrs {
178         let mut include_in_struct = true;
179         // Turns out `is_ident` does a `to_string` every time
180         match attr.path().to_token_stream().to_string().as_str() {
181             "derive" => {
182                 if let Ok(derive_paths) =
183                     attr.parse_args_with(Punctuated::<syn::Path, syn::Token![,]>::parse_terminated)
184                 {
185                     for derive in &derive_paths {
186                         // These derives are treated specially
187                         const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
188                         const EQ_PATH: &[&str] = &["cmp", "Eq"];
189                         const DEBUG_PATH: &[&str] = &["fmt", "Debug"];
190 
191                         if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH)
192                             || path_matches_prelude_derive(derive, EQ_PATH)
193                         {
194                             // This derive is always included, exclude it.
195                             continue;
196                         }
197                         if path_matches_prelude_derive(derive, DEBUG_PATH) && !allow_alias {
198                             make_custom_debug_impl = true;
199                             // Don't include this derive since we're generating a special one.
200                             continue;
201                         }
202                         extra_derives.push(derive.to_token_stream());
203                     }
204                     include_in_struct = false;
205                 }
206             }
207             // Copy linting attribute to the impl.
208             "allow" | "warn" | "deny" | "forbid" => impl_attrs.push(attr.to_token_stream()),
209             "repr" => {
210                 assert!(explicit_repr.is_none(), "duplicate explicit repr");
211                 explicit_repr = Some(attr.parse_args()?);
212                 include_in_struct = false;
213             }
214             "non_exhaustive" => {
215                 // technically it's exhaustive if the enum covers the full integer range
216                 return Err(Error::new(attr.path().span(), "`non_exhaustive` cannot be applied to an open enum; it is already non-exhaustive"));
217             }
218             _ => {}
219         }
220         if include_in_struct {
221             struct_attrs.push(attr.to_token_stream());
222         }
223     }
224 
225     // The proper repr to type-check against
226     let typecheck_repr: Repr = explicit_repr.unwrap_or(Repr::Isize);
227 
228     // The actual representation of the value.
229     let inner_repr = match explicit_repr {
230         Some(explicit_repr) => {
231             // If there is an explicit repr, emit #[repr(transparent)].
232             struct_attrs.push(quote!(#[repr(transparent)]));
233             explicit_repr
234         }
235         None => {
236             // If there isn't an explicit repr, determine an appropriate sized integer that will fit.
237             // Interpret all discriminant expressions as isize.
238             repr::autodetect_inner_repr(variants.iter().map(|v| &v.1))
239         }
240     };
241 
242     if !extra_derives.is_empty() {
243         struct_attrs.push(quote!(#[derive(#(#extra_derives),*)]));
244     }
245 
246     let alias_check = if allow_alias {
247         TokenStream::default()
248     } else {
249         check_no_alias(&enum_, variants.iter().map(|(i, v, s, _)| (*i, v, *s)))?
250     };
251 
252     let syn::ItemEnum { ident, vis, .. } = enum_;
253 
254     let debug_impl = if make_custom_debug_impl {
255         emit_debug_impl(
256             &ident,
257             variants.iter().map(|(i, _, _, _)| *i),
258             variants.iter().map(|(_, _, _, a)| *a),
259         )
260     } else {
261         TokenStream::default()
262     };
263 
264     let fields = variants
265         .into_iter()
266         .map(|(name, value, value_span, attrs)| {
267             let mut value = value.into_token_stream();
268             value = set_token_stream_span(value, value_span);
269             let inner = if typecheck_repr == inner_repr {
270                 value
271             } else {
272                 quote!(::core::convert::identity::<#typecheck_repr>(#value) as #inner_repr)
273             };
274             quote!(
275                 #(#attrs)*
276                 pub const #name: #ident = #ident(#inner);
277             )
278         });
279 
280     Ok(quote! {
281         #(#struct_attrs)*
282         #vis struct #ident(#repr_visibility #inner_repr);
283 
284         #(#impl_attrs)*
285         impl #ident {
286             #(
287                 #fields
288             )*
289         }
290         #debug_impl
291         #alias_check
292     })
293 }
294 
295 #[proc_macro_attribute]
open_enum( attrs: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream296 pub fn open_enum(
297     attrs: proc_macro::TokenStream,
298     input: proc_macro::TokenStream,
299 ) -> proc_macro::TokenStream {
300     let enum_ = parse_macro_input!(input as syn::ItemEnum);
301     let config = parse_macro_input!(attrs as Config);
302     open_enum_impl(enum_, config)
303         .unwrap_or_else(Error::into_compile_error)
304         .into()
305 }
306 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_matches_stdlib_derive() {
        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

        for success_case in [
            "::core::fmt::Debug",
            "::std::fmt::Debug",
            "core::fmt::Debug",
            "std::fmt::Debug",
            "fmt::Debug",
            "Debug",
        ] {
            assert!(
                path_matches_prelude_derive(&syn::parse_str(success_case).unwrap(), DEBUG_PATH),
                "{success_case}"
            );
        }

        for fail_case in [
            "::fmt::Debug",
            "::Debug",
            "zerocopy::AsBytes",
            "::zerocopy::AsBytes",
            "PartialEq",
            "core::cmp::Eq",
            // Wrong crate root, wrong item, or too many segments.
            "alloc::fmt::Debug",
            "std::fmt::Display",
            "core::fmt::Debug::Inner",
            // A segment with generic arguments never names the derive.
            "Debug<u8>",
            "fmt::Debug<u8>",
        ] {
            assert!(
                !path_matches_prelude_derive(&syn::parse_str(fail_case).unwrap(), DEBUG_PATH),
                "{fail_case}"
            );
        }
    }

    #[test]
    fn test_path_matches_cmp_derives() {
        const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
        const EQ_PATH: &[&str] = &["cmp", "Eq"];

        for (path, expected, should_match) in [
            ("PartialEq", PARTIAL_EQ_PATH, true),
            ("::core::cmp::PartialEq", PARTIAL_EQ_PATH, true),
            ("Eq", EQ_PATH, true),
            ("std::cmp::Eq", EQ_PATH, true),
            ("Eq", PARTIAL_EQ_PATH, false),
            ("PartialEq", EQ_PATH, false),
        ] {
            assert_eq!(
                path_matches_prelude_derive(&syn::parse_str(path).unwrap(), expected),
                should_match,
                "{path} vs {expected:?}"
            );
        }
    }
}
344