1 // Copyright 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 extern crate proc_macro;
16
17 mod config;
18 mod discriminant;
19 mod repr;
20
21 use config::Config;
22
23 use discriminant::Discriminant;
24 use proc_macro2::{Span, TokenStream};
25 use quote::{format_ident, quote, ToTokens};
26 use repr::Repr;
27 use std::collections::HashSet;
28 use syn::Attribute;
29 use syn::{
30 parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility,
31 };
32
33 /// Sets the span for every token tree in the token stream
set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream34 fn set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream {
35 tokens
36 .into_iter()
37 .map(|mut tt| {
38 tt.set_span(span);
39 tt
40 })
41 .collect()
42 }
43
44 /// Checks that there are no duplicate discriminant values. If all variants are literals, return an `Err` so we can have
45 /// more clear error messages. Otherwise, emit a static check that ensures no duplicates.
check_no_alias<'a>( enum_: &ItemEnum, variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone, ) -> syn::Result<TokenStream>46 fn check_no_alias<'a>(
47 enum_: &ItemEnum,
48 variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone,
49 ) -> syn::Result<TokenStream> {
50 // If they're all literals, we can give better error messages by checking at proc macro time.
51 let mut values: HashSet<i128> = HashSet::new();
52 for (_, variant, span) in variants {
53 if let &Discriminant::Literal(value) = variant {
54 if !values.insert(value) {
55 return Err(Error::new(
56 span,
57 format!("discriminant value `{value}` assigned more than once"),
58 ));
59 }
60 } else {
61 let mut checking_enum = syn::ItemEnum {
62 ident: format_ident!("_Check{}", enum_.ident),
63 vis: Visibility::Inherited,
64 ..enum_.clone()
65 };
66 checking_enum.attrs.retain(|attr| {
67 matches!(
68 attr.path().to_token_stream().to_string().as_str(),
69 "repr" | "allow" | "warn" | "deny" | "forbid"
70 )
71 });
72 return Ok(quote!(
73 #[allow(dead_code)]
74 #checking_enum
75 ));
76 }
77 }
78 Ok(TokenStream::default())
79 }
80
emit_debug_impl<'a>( ident: &Ident, variants: impl Iterator<Item = &'a Ident> + Clone, attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone, ) -> TokenStream81 fn emit_debug_impl<'a>(
82 ident: &Ident,
83 variants: impl Iterator<Item = &'a Ident> + Clone,
84 attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone,
85 ) -> TokenStream {
86 let attrs = attrs.map(|attrs| {
87 // Only allow "#[cfg(...)]" attributes
88 let iter = attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
89 quote!(#(#iter)*)
90 });
91 quote!(impl ::core::fmt::Debug for #ident {
92 fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
93 #![allow(unreachable_patterns)]
94 let s = match *self {
95 #( #attrs Self::#variants => stringify!(#variants), )*
96 _ => {
97 return fmt.debug_tuple(stringify!(#ident)).field(&self.0).finish();
98 }
99 };
100 fmt.pad(s)
101 }
102 })
103 }
104
path_matches_prelude_derive( got_path: &syn::Path, expected_path_after_std: &[&'static str], ) -> bool105 fn path_matches_prelude_derive(
106 got_path: &syn::Path,
107 expected_path_after_std: &[&'static str],
108 ) -> bool {
109 let &[a, b] = expected_path_after_std else {
110 unimplemented!("checking against stdlib paths with != 2 parts");
111 };
112 let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect();
113 if segments
114 .iter()
115 .any(|segment| !matches!(segment.arguments, syn::PathArguments::None))
116 {
117 return false;
118 }
119 match &segments[..] {
120 // `core::fmt::Debug` or `some_crate::module::Name`
121 [maybe_core_or_std, maybe_a, maybe_b] => {
122 (maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std")
123 && maybe_a.ident == a
124 && maybe_b.ident == b
125 }
126 // `fmt::Debug` or `module::Name`
127 [maybe_a, maybe_b] => {
128 maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none()
129 }
130 // `Debug` or `Name``
131 [maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(),
132 _ => false,
133 }
134 }
135
136 fn open_enum_impl(
137 enum_: ItemEnum,
138 Config {
139 allow_alias,
140 repr_visibility,
141 }: Config,
142 ) -> Result<TokenStream, Error> {
143 // Does the enum define a `#[repr()]`?
144 let mut struct_attrs: Vec<TokenStream> = Vec::with_capacity(enum_.attrs.len() + 5);
145 struct_attrs.push(quote!(#[allow(clippy::exhaustive_structs)]));
146
147 if !enum_.generics.params.is_empty() {
148 return Err(Error::new(enum_.generics.span(), "enum cannot be generic"));
149 }
150 let mut variants = Vec::with_capacity(enum_.variants.len());
151 let mut last_field = Discriminant::Literal(-1);
152 for variant in &enum_.variants {
153 if !matches!(variant.fields, syn::Fields::Unit) {
154 return Err(Error::new(variant.span(), "enum cannot contain fields"));
155 }
156
157 let (value, value_span) = if let Some((_, discriminant)) = &variant.discriminant {
158 let span = discriminant.span();
159 (Discriminant::new(discriminant.clone())?, span)
160 } else {
161 last_field = last_field
162 .next_value()
163 .ok_or_else(|| Error::new(variant.span(), "enum discriminant overflowed"))?;
164 (last_field.clone(), variant.ident.span())
165 };
166 last_field = value.clone();
167 variants.push((&variant.ident, value, value_span, &variant.attrs))
168 }
169
170 let mut impl_attrs: Vec<TokenStream> = vec![quote!(#[allow(non_upper_case_globals)])];
171 let mut explicit_repr: Option<Repr> = None;
172
173 // To make `match` seamless, derive(PartialEq, Eq) if they aren't already.
174 let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)];
175
176 let mut make_custom_debug_impl = false;
177 for attr in &enum_.attrs {
178 let mut include_in_struct = true;
179 // Turns out `is_ident` does a `to_string` every time
180 match attr.path().to_token_stream().to_string().as_str() {
181 "derive" => {
182 if let Ok(derive_paths) =
183 attr.parse_args_with(Punctuated::<syn::Path, syn::Token![,]>::parse_terminated)
184 {
185 for derive in &derive_paths {
186 // These derives are treated specially
187 const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
188 const EQ_PATH: &[&str] = &["cmp", "Eq"];
189 const DEBUG_PATH: &[&str] = &["fmt", "Debug"];
190
191 if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH)
192 || path_matches_prelude_derive(derive, EQ_PATH)
193 {
194 // This derive is always included, exclude it.
195 continue;
196 }
197 if path_matches_prelude_derive(derive, DEBUG_PATH) && !allow_alias {
198 make_custom_debug_impl = true;
199 // Don't include this derive since we're generating a special one.
200 continue;
201 }
202 extra_derives.push(derive.to_token_stream());
203 }
204 include_in_struct = false;
205 }
206 }
207 // Copy linting attribute to the impl.
208 "allow" | "warn" | "deny" | "forbid" => impl_attrs.push(attr.to_token_stream()),
209 "repr" => {
210 assert!(explicit_repr.is_none(), "duplicate explicit repr");
211 explicit_repr = Some(attr.parse_args()?);
212 include_in_struct = false;
213 }
214 "non_exhaustive" => {
215 // technically it's exhaustive if the enum covers the full integer range
216 return Err(Error::new(attr.path().span(), "`non_exhaustive` cannot be applied to an open enum; it is already non-exhaustive"));
217 }
218 _ => {}
219 }
220 if include_in_struct {
221 struct_attrs.push(attr.to_token_stream());
222 }
223 }
224
225 // The proper repr to type-check against
226 let typecheck_repr: Repr = explicit_repr.unwrap_or(Repr::Isize);
227
228 // The actual representation of the value.
229 let inner_repr = match explicit_repr {
230 Some(explicit_repr) => {
231 // If there is an explicit repr, emit #[repr(transparent)].
232 struct_attrs.push(quote!(#[repr(transparent)]));
233 explicit_repr
234 }
235 None => {
236 // If there isn't an explicit repr, determine an appropriate sized integer that will fit.
237 // Interpret all discriminant expressions as isize.
238 repr::autodetect_inner_repr(variants.iter().map(|v| &v.1))
239 }
240 };
241
242 if !extra_derives.is_empty() {
243 struct_attrs.push(quote!(#[derive(#(#extra_derives),*)]));
244 }
245
246 let alias_check = if allow_alias {
247 TokenStream::default()
248 } else {
249 check_no_alias(&enum_, variants.iter().map(|(i, v, s, _)| (*i, v, *s)))?
250 };
251
252 let syn::ItemEnum { ident, vis, .. } = enum_;
253
254 let debug_impl = if make_custom_debug_impl {
255 emit_debug_impl(
256 &ident,
257 variants.iter().map(|(i, _, _, _)| *i),
258 variants.iter().map(|(_, _, _, a)| *a),
259 )
260 } else {
261 TokenStream::default()
262 };
263
264 let fields = variants
265 .into_iter()
266 .map(|(name, value, value_span, attrs)| {
267 let mut value = value.into_token_stream();
268 value = set_token_stream_span(value, value_span);
269 let inner = if typecheck_repr == inner_repr {
270 value
271 } else {
272 quote!(::core::convert::identity::<#typecheck_repr>(#value) as #inner_repr)
273 };
274 quote!(
275 #(#attrs)*
276 pub const #name: #ident = #ident(#inner);
277 )
278 });
279
280 Ok(quote! {
281 #(#struct_attrs)*
282 #vis struct #ident(#repr_visibility #inner_repr);
283
284 #(#impl_attrs)*
285 impl #ident {
286 #(
287 #fields
288 )*
289 }
290 #debug_impl
291 #alias_check
292 })
293 }
294
295 #[proc_macro_attribute]
open_enum( attrs: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream296 pub fn open_enum(
297 attrs: proc_macro::TokenStream,
298 input: proc_macro::TokenStream,
299 ) -> proc_macro::TokenStream {
300 let enum_ = parse_macro_input!(input as syn::ItemEnum);
301 let config = parse_macro_input!(attrs as Config);
302 open_enum_impl(enum_, config)
303 .unwrap_or_else(Error::into_compile_error)
304 .into()
305 }
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `path_matches_prelude_derive` against `fmt::Debug` with the spellings that
    /// must be accepted and a set that must be rejected.
    #[test]
    fn test_path_matches_stdlib_derive() {
        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

        // Parses `text` as a path and checks it against `DEBUG_PATH`.
        let is_debug_path = |text: &str| -> bool {
            let path: syn::Path = syn::parse_str(text).unwrap();
            path_matches_prelude_derive(&path, DEBUG_PATH)
        };

        let accepted = [
            "::core::fmt::Debug",
            "::std::fmt::Debug",
            "core::fmt::Debug",
            "std::fmt::Debug",
            "fmt::Debug",
            "Debug",
        ];
        for success_case in accepted {
            assert!(is_debug_path(success_case), "{success_case}");
        }

        let rejected = [
            "::fmt::Debug",
            "::Debug",
            "zerocopy::AsBytes",
            "::zerocopy::AsBytes",
            "PartialEq",
            "core::cmp::Eq",
        ];
        for fail_case in rejected {
            assert!(!is_debug_path(fail_case), "{fail_case}");
        }
    }
}
344