• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![allow(unused_imports)]
2 use std::{cmp, convert::TryFrom};
3 
4 use proc_macro2::{Ident, Span, TokenStream, TokenTree};
5 use quote::{quote, ToTokens};
6 use syn::{
7   parse::{Parse, ParseStream, Parser},
8   punctuated::Punctuated,
9   spanned::Spanned,
10   Result, *,
11 };
12 
/// Returns early from the enclosing function with an `Err(syn::Error)`.
///
/// * `bail!(message)` — the error is attached to `Span::call_site()`.
/// * `bail!(message => tokens)` — the error is spanned over `tokens`
///   (anything accepted by `Error::new_spanned`).
///
/// A trailing comma is accepted in both forms.
macro_rules! bail {
  ($message:expr $(,)?) => {
    return Err(Error::new(Span::call_site(), &$message[..]))
  };

  ($message:expr => $blamed:expr $(,)?) => {
    return Err(Error::new_spanned(&$blamed, $message))
  };
}
22 
23 pub trait Derivable {
ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>24   fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
implies_trait(_crate_name: &TokenStream) -> Option<TokenStream>25   fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
26     None
27   }
asserts( _input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<TokenStream>28   fn asserts(
29     _input: &DeriveInput, _crate_name: &TokenStream,
30   ) -> Result<TokenStream> {
31     Ok(quote!())
32   }
check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()>33   fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
34     Ok(())
35   }
trait_impl( _input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>36   fn trait_impl(
37     _input: &DeriveInput, _crate_name: &TokenStream,
38   ) -> Result<(TokenStream, TokenStream)> {
39     Ok((quote!(), quote!()))
40   }
requires_where_clause() -> bool41   fn requires_where_clause() -> bool {
42     true
43   }
explicit_bounds_attribute_name() -> Option<&'static str>44   fn explicit_bounds_attribute_name() -> Option<&'static str> {
45     None
46   }
47 
48   /// If this trait has a custom meaning for "perfect derive", this function
49   /// should be overridden to return `Some`.
50   ///
51   /// The default is "the fields of a struct; unions and enums not supported".
perfect_derive_fields(_input: &DeriveInput) -> Option<Fields>52   fn perfect_derive_fields(_input: &DeriveInput) -> Option<Fields> {
53     None
54   }
55 }
56 
57 pub struct Pod;
58 
59 impl Derivable for Pod {
ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>60   fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
61     Ok(syn::parse_quote!(#crate_name::Pod))
62   }
63 
asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<TokenStream>64   fn asserts(
65     input: &DeriveInput, crate_name: &TokenStream,
66   ) -> Result<TokenStream> {
67     let repr = get_repr(&input.attrs)?;
68 
69     let completly_packed =
70       repr.packed == Some(1) || repr.repr == Repr::Transparent;
71 
72     if !completly_packed && !input.generics.params.is_empty() {
73       bail!("\
74         Pod requires cannot be derived for non-packed types containing \
75         generic parameters because the padding requirements can't be verified \
76         for generic non-packed structs\
77       " => input.generics.params.first().unwrap());
78     }
79 
80     match &input.data {
81       Data::Struct(_) => {
82         let assert_no_padding = if !completly_packed {
83           Some(generate_assert_no_padding(input)?)
84         } else {
85           None
86         };
87         let assert_fields_are_pod = generate_fields_are_trait(
88           input,
89           None,
90           Self::ident(input, crate_name)?,
91         )?;
92 
93         Ok(quote!(
94           #assert_no_padding
95           #assert_fields_are_pod
96         ))
97       }
98       Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
99       Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
100     }
101   }
102 
check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()>103   fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
104     let repr = get_repr(attributes)?;
105     match repr.repr {
106       Repr::C => Ok(()),
107       Repr::Transparent => Ok(()),
108       _ => {
109         bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
110       }
111     }
112   }
113 }
114 
115 pub struct AnyBitPattern;
116 
117 impl Derivable for AnyBitPattern {
ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>118   fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
119     Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
120   }
121 
implies_trait(crate_name: &TokenStream) -> Option<TokenStream>122   fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
123     Some(quote!(#crate_name::Zeroable))
124   }
125 
asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<TokenStream>126   fn asserts(
127     input: &DeriveInput, crate_name: &TokenStream,
128   ) -> Result<TokenStream> {
129     match &input.data {
130       Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
131       Data::Struct(_) => {
132         generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
133       }
134       Data::Enum(_) => {
135         bail!("Deriving AnyBitPattern is not supported for enums")
136       }
137     }
138   }
139 }
140 
141 pub struct Zeroable;
142 
143 /// Helper function to get the variant with discriminant zero (implicit or
144 /// explicit).
get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>>145 fn get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>> {
146   let iter = VariantDiscriminantIterator::new(enum_.variants.iter());
147   let mut zero_variant = None;
148   for res in iter {
149     let (discriminant, variant) = res?;
150     if discriminant == 0 {
151       zero_variant = Some(variant);
152       break;
153     }
154   }
155   Ok(zero_variant)
156 }
157 
158 impl Derivable for Zeroable {
ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>159   fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
160     Ok(syn::parse_quote!(#crate_name::Zeroable))
161   }
162 
check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()>163   fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
164     let repr = get_repr(attributes)?;
165     match ty {
166       Data::Struct(_) => Ok(()),
167       Data::Enum(_) => {
168         if !matches!(
169           repr.repr,
170           Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_)
171         ) {
172           bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]")
173         }
174 
175         // We ensure there is a zero variant in `asserts`, since it is needed
176         // there anyway.
177 
178         Ok(())
179       }
180       Data::Union(_) => Ok(()),
181     }
182   }
183 
asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<TokenStream>184   fn asserts(
185     input: &DeriveInput, crate_name: &TokenStream,
186   ) -> Result<TokenStream> {
187     match &input.data {
188       Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
189       Data::Struct(_) => {
190         generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
191       }
192       Data::Enum(enum_) => {
193         let zero_variant = get_zero_variant(enum_)?;
194 
195         if zero_variant.is_none() {
196           bail!("No variant's discriminant is 0")
197         };
198 
199         generate_fields_are_trait(
200           input,
201           zero_variant,
202           Self::ident(input, crate_name)?,
203         )
204       }
205     }
206   }
207 
explicit_bounds_attribute_name() -> Option<&'static str>208   fn explicit_bounds_attribute_name() -> Option<&'static str> {
209     Some("zeroable")
210   }
211 
perfect_derive_fields(input: &DeriveInput) -> Option<Fields>212   fn perfect_derive_fields(input: &DeriveInput) -> Option<Fields> {
213     match &input.data {
214       Data::Struct(struct_) => Some(struct_.fields.clone()),
215       Data::Enum(enum_) => {
216         // We handle `Err` returns from `get_zero_variant` in `asserts`, so it's
217         // fine to just ignore them here and return `None`.
218         // Otherwise, we clone the `fields` of the zero variant (if any).
219         Some(get_zero_variant(enum_).ok()??.fields.clone())
220       }
221       Data::Union(_) => None,
222     }
223   }
224 }
225 
226 pub struct NoUninit;
227 
228 impl Derivable for NoUninit {
ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>229   fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
230     Ok(syn::parse_quote!(#crate_name::NoUninit))
231   }
232 
check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()>233   fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
234     let repr = get_repr(attributes)?;
235     match ty {
236       Data::Struct(_) => match repr.repr {
237         Repr::C | Repr::Transparent => Ok(()),
238         _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
239       },
240       Data::Enum(_) => if repr.repr.is_integer() {
241         Ok(())
242       } else {
243         bail!("NoUninit requires the enum to be an explicit #[repr(Int)]")
244       },
245       Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
246     }
247   }
248 
asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<TokenStream>249   fn asserts(
250     input: &DeriveInput, crate_name: &TokenStream,
251   ) -> Result<TokenStream> {
252     if !input.generics.params.is_empty() {
253       bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
254     }
255 
256     match &input.data {
257       Data::Struct(DataStruct { .. }) => {
258         let assert_no_padding = generate_assert_no_padding(&input)?;
259         let assert_fields_are_no_padding = generate_fields_are_trait(
260           &input,
261           None,
262           Self::ident(input, crate_name)?,
263         )?;
264 
265         Ok(quote!(
266             #assert_no_padding
267             #assert_fields_are_no_padding
268         ))
269       }
270       Data::Enum(DataEnum { variants, .. }) => {
271         if variants.iter().any(|variant| !variant.fields.is_empty()) {
272           bail!("Only fieldless enums are supported for NoUninit")
273         } else {
274           Ok(quote!())
275         }
276       }
277       Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
278     }
279   }
280 
trait_impl( _input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>281   fn trait_impl(
282     _input: &DeriveInput, _crate_name: &TokenStream,
283   ) -> Result<(TokenStream, TokenStream)> {
284     Ok((quote!(), quote!()))
285   }
286 }
287 
288 pub struct CheckedBitPattern;
289 
290 impl Derivable for CheckedBitPattern {
ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>291   fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
292     Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
293   }
294 
check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()>295   fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
296     let repr = get_repr(attributes)?;
297     match ty {
298       Data::Struct(_) => match repr.repr {
299         Repr::C | Repr::Transparent => Ok(()),
300         _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
301       },
302       Data::Enum(DataEnum { variants,.. }) => {
303         if !enum_has_fields(variants.iter()){
304           if repr.repr.is_integer() {
305             Ok(())
306           } else {
307             bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
308           }
309         } else if matches!(repr.repr, Repr::Rust) {
310           bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
311         } else {
312           Ok(())
313         }
314       }
315       Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
316     }
317   }
318 
asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<TokenStream>319   fn asserts(
320     input: &DeriveInput, crate_name: &TokenStream,
321   ) -> Result<TokenStream> {
322     if !input.generics.params.is_empty() {
323       bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
324     }
325 
326     match &input.data {
327       Data::Struct(DataStruct { .. }) => {
328         let assert_fields_are_maybe_pod = generate_fields_are_trait(
329           &input,
330           None,
331           Self::ident(input, crate_name)?,
332         )?;
333 
334         Ok(assert_fields_are_maybe_pod)
335       }
336       // nothing needed, already guaranteed OK by NoUninit.
337       Data::Enum(_) => Ok(quote!()),
338       Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
339     }
340   }
341 
trait_impl( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>342   fn trait_impl(
343     input: &DeriveInput, crate_name: &TokenStream,
344   ) -> Result<(TokenStream, TokenStream)> {
345     match &input.data {
346       Data::Struct(DataStruct { fields, .. }) => {
347         generate_checked_bit_pattern_struct(
348           &input.ident,
349           fields,
350           &input.attrs,
351           crate_name,
352         )
353       }
354       Data::Enum(DataEnum { variants, .. }) => {
355         generate_checked_bit_pattern_enum(input, variants, crate_name)
356       }
357       Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
358     }
359   }
360 }
361 
362 pub struct TransparentWrapper;
363 
364 impl TransparentWrapper {
get_wrapper_type( attributes: &[Attribute], fields: &Fields, ) -> Option<TokenStream>365   fn get_wrapper_type(
366     attributes: &[Attribute], fields: &Fields,
367   ) -> Option<TokenStream> {
368     let transparent_param = get_simple_attr(attributes, "transparent");
369     transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
370       let mut types = get_field_types(&fields);
371       let first_type = types.next();
372       if let Some(_) = types.next() {
373         // can't guess param type if there is more than one field
374         return None;
375       } else {
376         first_type.map(|ty| ty.to_token_stream())
377       }
378     })
379   }
380 }
381 
382 impl Derivable for TransparentWrapper {
ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>383   fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
384     let fields = get_struct_fields(input)?;
385 
386     let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
387       Some(ty) => ty,
388       None => bail!(
389         "\
390         when deriving TransparentWrapper for a struct with more than one field \
391         you need to specify the transparent field using #[transparent(T)]\
392       "
393       ),
394     };
395 
396     Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
397   }
398 
asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<TokenStream>399   fn asserts(
400     input: &DeriveInput, crate_name: &TokenStream,
401   ) -> Result<TokenStream> {
402     let (impl_generics, _ty_generics, where_clause) =
403       input.generics.split_for_impl();
404     let fields = get_struct_fields(input)?;
405     let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
406       Some(wrapped_type) => wrapped_type.to_string(),
407       None => unreachable!(), /* other code will already reject this derive */
408     };
409     let mut wrapped_field_ty = None;
410     let mut nonwrapped_field_tys = vec![];
411     for field in fields.iter() {
412       let field_ty = &field.ty;
413       if field_ty.to_token_stream().to_string() == wrapped_type {
414         if wrapped_field_ty.is_some() {
415           bail!(
416             "TransparentWrapper can only have one field of the wrapped type"
417           );
418         }
419         wrapped_field_ty = Some(field_ty);
420       } else {
421         nonwrapped_field_tys.push(field_ty);
422       }
423     }
424     if let Some(wrapped_field_ty) = wrapped_field_ty {
425       Ok(quote!(
426         const _: () = {
427           #[repr(transparent)]
428           #[allow(clippy::multiple_bound_locations)]
429           struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
430           fn assert_zeroable<Z: #crate_name::Zeroable>() {}
431           #[allow(clippy::multiple_bound_locations)]
432           fn check #impl_generics () #where_clause {
433             #(
434               assert_zeroable::<#nonwrapped_field_tys>();
435             )*
436           }
437         };
438       ))
439     } else {
440       bail!("TransparentWrapper must have one field of the wrapped type")
441     }
442   }
443 
check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()>444   fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
445     let repr = get_repr(attributes)?;
446 
447     match repr.repr {
448       Repr::Transparent => Ok(()),
449       _ => {
450         bail!(
451           "TransparentWrapper requires the struct to be #[repr(transparent)]"
452         )
453       }
454     }
455   }
456 
requires_where_clause() -> bool457   fn requires_where_clause() -> bool {
458     false
459   }
460 }
461 
462 pub struct Contiguous;
463 
464 impl Derivable for Contiguous {
ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>465   fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
466     Ok(syn::parse_quote!(#crate_name::Contiguous))
467   }
468 
trait_impl( input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>469   fn trait_impl(
470     input: &DeriveInput, _crate_name: &TokenStream,
471   ) -> Result<(TokenStream, TokenStream)> {
472     let repr = get_repr(&input.attrs)?;
473 
474     let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
475       integer_ty
476     } else {
477       bail!("Contiguous requires the enum to be #[repr(Int)]");
478     };
479 
480     let variants = get_enum_variants(input)?;
481     if enum_has_fields(variants.clone()) {
482       return Err(Error::new_spanned(
483         &input,
484         "Only fieldless enums are supported",
485       ));
486     }
487 
488     let mut variants_with_discriminant =
489       VariantDiscriminantIterator::new(variants);
490 
491     let (min, max, count) = variants_with_discriminant.try_fold(
492       (i128::MAX, i128::MIN, 0),
493       |(min, max, count), res| {
494         let (discriminant, _variant) = res?;
495         Ok::<_, Error>((
496           i128::min(min, discriminant),
497           i128::max(max, discriminant),
498           count + 1,
499         ))
500       },
501     )?;
502 
503     if max - min != count - 1 {
504       bail! {
505         "Contiguous requires the enum discriminants to be contiguous",
506       }
507     }
508 
509     let min_lit = LitInt::new(&format!("{}", min), input.span());
510     let max_lit = LitInt::new(&format!("{}", max), input.span());
511 
512     // `from_integer` and `into_integer` are usually provided by the trait's
513     // default implementation. We override this implementation because it
514     // goes through `transmute_copy`, which can lead to inefficient assembly as seen in https://github.com/Lokathor/bytemuck/issues/175 .
515 
516     Ok((
517       quote!(),
518       quote! {
519           type Int = #integer_ty;
520 
521           #[allow(clippy::missing_docs_in_private_items)]
522           const MIN_VALUE: #integer_ty = #min_lit;
523 
524           #[allow(clippy::missing_docs_in_private_items)]
525           const MAX_VALUE: #integer_ty = #max_lit;
526 
527           #[inline]
528           fn from_integer(value: Self::Int) -> Option<Self> {
529             #[allow(clippy::manual_range_contains)]
530             if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
531               Some(unsafe { ::core::mem::transmute(value) })
532             } else {
533               None
534             }
535           }
536 
537           #[inline]
538           fn into_integer(self) -> Self::Int {
539               self as #integer_ty
540           }
541       },
542     ))
543   }
544 }
545 
get_struct_fields(input: &DeriveInput) -> Result<&Fields>546 fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
547   if let Data::Struct(DataStruct { fields, .. }) = &input.data {
548     Ok(fields)
549   } else {
550     bail!("deriving this trait is only supported for structs")
551   }
552 }
553 
554 /// Extract the `Fields` off a `DeriveInput`, or, in the `enum` case, off
555 /// those of the `enum_variant`, when provided (e.g., for `Zeroable`).
556 ///
557 /// We purposely allow not providing an `enum_variant` for cases where
558 /// the caller wants to reject supporting `enum`s (e.g., `NoPadding`).
get_fields( input: &DeriveInput, enum_variant: Option<&Variant>, ) -> Result<Fields>559 fn get_fields(
560   input: &DeriveInput, enum_variant: Option<&Variant>,
561 ) -> Result<Fields> {
562   match &input.data {
563     Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
564     Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
565     Data::Enum(_) => match enum_variant {
566       Some(variant) => Ok(variant.fields.clone()),
567       None => bail!("deriving this trait is not supported for enums"),
568     },
569   }
570 }
571 
get_enum_variants<'a>( input: &'a DeriveInput, ) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a>572 fn get_enum_variants<'a>(
573   input: &'a DeriveInput,
574 ) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
575   if let Data::Enum(DataEnum { variants, .. }) = &input.data {
576     Ok(variants.iter())
577   } else {
578     bail!("deriving this trait is only supported for enums")
579   }
580 }
581 
get_field_types<'a>( fields: &'a Fields, ) -> impl Iterator<Item = &'a Type> + 'a582 fn get_field_types<'a>(
583   fields: &'a Fields,
584 ) -> impl Iterator<Item = &'a Type> + 'a {
585   fields.iter().map(|field| &field.ty)
586 }
587 
generate_checked_bit_pattern_struct( input_ident: &Ident, fields: &Fields, attrs: &[Attribute], crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>588 fn generate_checked_bit_pattern_struct(
589   input_ident: &Ident, fields: &Fields, attrs: &[Attribute],
590   crate_name: &TokenStream,
591 ) -> Result<(TokenStream, TokenStream)> {
592   let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
593 
594   let repr = get_repr(attrs)?;
595 
596   let field_names = fields
597     .iter()
598     .enumerate()
599     .map(|(i, field)| {
600       field.ident.clone().unwrap_or_else(|| {
601         Ident::new(&format!("field{}", i), input_ident.span())
602       })
603     })
604     .collect::<Vec<_>>();
605   let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
606 
607   let field_name = &field_names[..];
608   let field_ty = &field_tys[..];
609 
610   Ok((
611     quote! {
612         #[doc = #GENERATED_TYPE_DOCUMENTATION]
613         #repr
614         #[derive(Clone, Copy, #crate_name::AnyBitPattern)]
615         #[allow(missing_docs)]
616         pub struct #bits_ty {
617             #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
618         }
619 
620         #[allow(unexpected_cfgs)]
621         const _: () = {
622           #[cfg(not(target_arch = "spirv"))]
623           impl ::core::fmt::Debug for #bits_ty {
624             fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
625               let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty));
626               #(::core::fmt::DebugStruct::field(&mut debug_struct, ::core::stringify!(#field_name), &self.#field_name);)*
627               ::core::fmt::DebugStruct::finish(&mut debug_struct)
628             }
629           }
630         };
631     },
632     quote! {
633         type Bits = #bits_ty;
634 
635         #[inline]
636         #[allow(clippy::double_comparisons, unused)]
637         fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
638             #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
639         }
640     },
641   ))
642 }
643 
generate_checked_bit_pattern_enum( input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>, crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>644 fn generate_checked_bit_pattern_enum(
645   input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
646   crate_name: &TokenStream,
647 ) -> Result<(TokenStream, TokenStream)> {
648   if enum_has_fields(variants.iter()) {
649     generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
650   } else {
651     generate_checked_bit_pattern_enum_without_fields(input, variants)
652   }
653 }
654 
generate_checked_bit_pattern_enum_without_fields( input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>, ) -> Result<(TokenStream, TokenStream)>655 fn generate_checked_bit_pattern_enum_without_fields(
656   input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
657 ) -> Result<(TokenStream, TokenStream)> {
658   let span = input.span();
659   let mut variants_with_discriminant =
660     VariantDiscriminantIterator::new(variants.iter());
661 
662   let (min, max, count) = variants_with_discriminant.try_fold(
663     (i128::MAX, i128::MIN, 0),
664     |(min, max, count), res| {
665       let (discriminant, _variant) = res?;
666       Ok::<_, Error>((
667         i128::min(min, discriminant),
668         i128::max(max, discriminant),
669         count + 1,
670       ))
671     },
672   )?;
673 
674   let check = if count == 0 {
675     quote!(false)
676   } else if max - min == count - 1 {
677     // contiguous range
678     let min_lit = LitInt::new(&format!("{}", min), span);
679     let max_lit = LitInt::new(&format!("{}", max), span);
680 
681     quote!(*bits >= #min_lit && *bits <= #max_lit)
682   } else {
683     // not contiguous range, check for each
684     let variant_discriminant_lits =
685       VariantDiscriminantIterator::new(variants.iter())
686         .map(|res| {
687           let (discriminant, _variant) = res?;
688           Ok(LitInt::new(&format!("{}", discriminant), span))
689         })
690         .collect::<Result<Vec<_>>>()?;
691 
692     // count is at least 1
693     let first = &variant_discriminant_lits[0];
694     let rest = &variant_discriminant_lits[1..];
695 
696     quote!(matches!(*bits, #first #(| #rest )*))
697   };
698 
699   let repr = get_repr(&input.attrs)?;
700   let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already
701   Ok((
702     quote!(),
703     quote! {
704         type Bits = #integer;
705 
706         #[inline]
707         #[allow(clippy::double_comparisons)]
708         fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
709             #check
710         }
711     },
712   ))
713 }
714 
generate_checked_bit_pattern_enum_with_fields( input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>, crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)>715 fn generate_checked_bit_pattern_enum_with_fields(
716   input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
717   crate_name: &TokenStream,
718 ) -> Result<(TokenStream, TokenStream)> {
719   let representation = get_repr(&input.attrs)?;
720   let vis = &input.vis;
721 
722   match representation.repr {
723     Repr::Rust => unreachable!(),
724     repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
725       let integer = match repr {
726         Repr::C => quote!(::core::ffi::c_int),
727         Repr::CWithDiscriminant(integer) => quote!(#integer),
728         _ => unreachable!(),
729       };
730       let input_ident = &input.ident;
731 
732       let bits_repr = Representation { repr: Repr::C, ..representation };
733 
734       // the enum manually re-configured as the actual tagged union it
735       // represents, thus circumventing the requirements rust imposes on
736       // the tag even when using #[repr(C)] enum layout
737       // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
738       let bits_ty_ident =
739         Ident::new(&format!("{input_ident}Bits"), input.span());
740 
741       // the variants union part of the tagged union. These get put into a union
742       // which gets the AnyBitPattern derive applied to it, thus checking
743       // that the fields of the union obey the requriements of AnyBitPattern.
744       // The types that actually go in the union are one more level of
745       // indirection deep: we generate new structs for each variant
746       // (`variant_struct_definitions`) which themselves have the
747       // `CheckedBitPattern` derive applied, thus generating
748       // `{variant_struct_ident}Bits` structs, which are the ones that go
749       // into this union.
750       let variants_union_ident =
751         Ident::new(&format!("{}Variants", input.ident), input.span());
752 
753       let variant_struct_idents = variants.iter().map(|v| {
754         Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
755       });
756 
757       let variant_struct_definitions =
758         variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
759           let fields = v.fields.iter().map(|v| &v.ty);
760 
761           quote! {
762             #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
763             #[repr(C)]
764             #vis struct #variant_struct_ident(#(#fields),*);
765           }
766         });
767 
768       let union_fields = variant_struct_idents
769         .clone()
770         .zip(variants.iter())
771         .map(|(variant_struct_ident, v)| {
772           let variant_struct_bits_ident =
773             Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
774           let field_ident = &v.ident;
775           quote! {
776             #field_ident: #variant_struct_bits_ident
777           }
778         });
779 
780       let variant_checks = variant_struct_idents
781         .clone()
782         .zip(VariantDiscriminantIterator::new(variants.iter()))
783         .zip(variants.iter())
784         .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
785           let (discriminant, _variant) = discriminant?;
786           let discriminant = LitInt::new(&discriminant.to_string(), v.span());
787           let ident = &v.ident;
788           Ok(quote! {
789             #discriminant => {
790               let payload = unsafe { &bits.payload.#ident };
791               <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
792             }
793           })
794         })
795         .collect::<Result<Vec<_>>>()?;
796 
797       Ok((
798         quote! {
799           #[doc = #GENERATED_TYPE_DOCUMENTATION]
800           #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
801           #bits_repr
802           #vis struct #bits_ty_ident {
803             tag: #integer,
804             payload: #variants_union_ident,
805           }
806 
807           #[allow(unexpected_cfgs)]
808           const _: () = {
809             #[cfg(not(target_arch = "spirv"))]
810             impl ::core::fmt::Debug for #bits_ty_ident {
811               fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
812                 let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
813                 ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", &self.tag);
814                 ::core::fmt::DebugStruct::field(&mut debug_struct, "payload", &self.payload);
815                 ::core::fmt::DebugStruct::finish(&mut debug_struct)
816               }
817             }
818           };
819 
820           #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
821           #[repr(C)]
822           #[allow(non_snake_case)]
823           #vis union #variants_union_ident {
824             #(#union_fields,)*
825           }
826 
827           #[allow(unexpected_cfgs)]
828           const _: () = {
829             #[cfg(not(target_arch = "spirv"))]
830             impl ::core::fmt::Debug for #variants_union_ident {
831               fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
832                 let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
833                 ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
834               }
835             }
836           };
837 
838           #(#variant_struct_definitions)*
839         },
840         quote! {
841           type Bits = #bits_ty_ident;
842 
843           #[inline]
844           #[allow(clippy::double_comparisons)]
845           fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
846             match bits.tag {
847               #(#variant_checks)*
848               _ => false,
849             }
850           }
851         },
852       ))
853     }
854     Repr::Transparent => {
855       if variants.len() != 1 {
856         bail!("enums with more than one variant cannot be transparent")
857       }
858 
859       let variant = &variants[0];
860 
861       let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
862       let fields = variant.fields.iter().map(|v| &v.ty);
863 
864       Ok((
865         quote! {
866           #[doc = #GENERATED_TYPE_DOCUMENTATION]
867           #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
868           #[repr(C)]
869           #vis struct #bits_ty(#(#fields),*);
870         },
871         quote! {
872           type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;
873 
874           #[inline]
875           #[allow(clippy::double_comparisons)]
876           fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
877             <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
878           }
879         },
880       ))
881     }
882     Repr::Integer(integer) => {
883       let bits_repr = Representation { repr: Repr::C, ..representation };
884       let input_ident = &input.ident;
885 
886       // the enum manually re-configured as the union it represents. such a
887       // union is the union of variants as a repr(c) struct with the
888       // discriminator type inserted at the beginning. in our case we
889       // union the `Bits` representation of each variant rather than the variant
890       // itself, which we generate via a nested `CheckedBitPattern` derive
891       // on the `variant_struct_definitions` generated below.
892       //
893       // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
894       let bits_ty_ident =
895         Ident::new(&format!("{input_ident}Bits"), input.span());
896 
897       let variant_struct_idents = variants.iter().map(|v| {
898         Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
899       });
900 
901       let variant_struct_definitions =
902         variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
903           let fields = v.fields.iter().map(|v| &v.ty);
904 
905           // adding the discriminant repr integer as first field, as described above
906           quote! {
907             #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
908             #[repr(C)]
909             #vis struct #variant_struct_ident(#integer, #(#fields),*);
910           }
911         });
912 
913       let union_fields = variant_struct_idents
914         .clone()
915         .zip(variants.iter())
916         .map(|(variant_struct_ident, v)| {
917           let variant_struct_bits_ident =
918             Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
919           let field_ident = &v.ident;
920           quote! {
921             #field_ident: #variant_struct_bits_ident
922           }
923         });
924 
925       let variant_checks = variant_struct_idents
926         .clone()
927         .zip(VariantDiscriminantIterator::new(variants.iter()))
928         .zip(variants.iter())
929         .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
930           let (discriminant, _variant) = discriminant?;
931           let discriminant = LitInt::new(&discriminant.to_string(), v.span());
932           let ident = &v.ident;
933           Ok(quote! {
934             #discriminant => {
935               let payload = unsafe { &bits.#ident };
936               <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
937             }
938           })
939         })
940         .collect::<Result<Vec<_>>>()?;
941 
942       Ok((
943         quote! {
944           #[doc = #GENERATED_TYPE_DOCUMENTATION]
945           #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
946           #bits_repr
947           #[allow(non_snake_case)]
948           #vis union #bits_ty_ident {
949             __tag: #integer,
950             #(#union_fields,)*
951           }
952 
953           #[allow(unexpected_cfgs)]
954           const _: () = {
955             #[cfg(not(target_arch = "spirv"))]
956             impl ::core::fmt::Debug for #bits_ty_ident {
957               fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
958                 let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
959                 ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
960                 ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
961               }
962             }
963           };
964 
965           #(#variant_struct_definitions)*
966         },
967         quote! {
968           type Bits = #bits_ty_ident;
969 
970           #[inline]
971           #[allow(clippy::double_comparisons)]
972           fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
973             match unsafe { bits.__tag } {
974               #(#variant_checks)*
975               _ => false,
976             }
977           }
978         },
979       ))
980     }
981   }
982 }
983 
984 /// Check that a struct has no padding by asserting that the size of the struct
985 /// is equal to the sum of the size of it's fields
generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream>986 fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
987   let struct_type = &input.ident;
988   let enum_variant = None; // `no padding` check is not supported for `enum`s yet.
989   let fields = get_fields(input, enum_variant)?;
990 
991   let mut field_types = get_field_types(&fields);
992   let size_sum = if let Some(first) = field_types.next() {
993     let size_first = quote!(::core::mem::size_of::<#first>());
994     let size_rest = quote!(#( + ::core::mem::size_of::<#field_types>() )*);
995 
996     quote!(#size_first #size_rest)
997   } else {
998     quote!(0)
999   };
1000 
1001   Ok(quote! {const _: fn() = || {
1002     #[doc(hidden)]
1003     struct TypeWithoutPadding([u8; #size_sum]);
1004     let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>;
1005   };})
1006 }
1007 
1008 /// Check that all fields implement a given trait
generate_fields_are_trait( input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path, ) -> Result<TokenStream>1009 fn generate_fields_are_trait(
1010   input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path,
1011 ) -> Result<TokenStream> {
1012   let (impl_generics, _ty_generics, where_clause) =
1013     input.generics.split_for_impl();
1014   let fields = get_fields(input, enum_variant)?;
1015   let field_types = get_field_types(&fields);
1016   Ok(quote! {#(const _: fn() = || {
1017       #[allow(clippy::missing_const_for_fn)]
1018       #[doc(hidden)]
1019       fn check #impl_generics () #where_clause {
1020         fn assert_impl<T: #trait_>() {}
1021         assert_impl::<#field_types>();
1022       }
1023     };)*
1024   })
1025 }
1026 
get_ident_from_stream(tokens: TokenStream) -> Option<Ident>1027 fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
1028   match tokens.into_iter().next() {
1029     Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
1030     Some(TokenTree::Ident(ident)) => Some(ident),
1031     _ => None,
1032   }
1033 }
1034 
1035 /// get a simple #[foo(bar)] attribute, returning "bar"
get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident>1036 fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
1037   for attr in attributes {
1038     if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) {
1039       if list.path.is_ident(attr_name) {
1040         if let Some(ident) = get_ident_from_stream(list.tokens.clone()) {
1041           return Some(ident);
1042         }
1043       }
1044     }
1045   }
1046 
1047   None
1048 }
1049 
get_repr(attributes: &[Attribute]) -> Result<Representation>1050 fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
1051   attributes
1052     .iter()
1053     .filter_map(|attr| {
1054       if attr.path().is_ident("repr") {
1055         Some(attr.parse_args::<Representation>())
1056       } else {
1057         None
1058       }
1059     })
1060     .try_fold(Representation::default(), |a, b| {
1061       let b = b?;
1062       Ok(Representation {
1063         repr: match (a.repr, b.repr) {
1064           (a, Repr::Rust) => a,
1065           (Repr::Rust, b) => b,
1066           _ => bail!("conflicting representation hints"),
1067         },
1068         packed: match (a.packed, b.packed) {
1069           (a, None) => a,
1070           (None, b) => b,
1071           _ => bail!("conflicting representation hints"),
1072         },
1073         align: match (a.align, b.align) {
1074           (Some(a), Some(b)) => Some(cmp::max(a, b)),
1075           (a, None) => a,
1076           (None, b) => b,
1077         },
1078       })
1079     })
1080 }
1081 
1082 mk_repr! {
1083   U8 => u8,
1084   I8 => i8,
1085   U16 => u16,
1086   I16 => i16,
1087   U32 => u32,
1088   I32 => i32,
1089   U64 => u64,
1090   I64 => i64,
1091   I128 => i128,
1092   U128 => u128,
1093   Usize => usize,
1094   Isize => isize,
1095 }
1096 // where
1097 macro_rules! mk_repr {(
1098   $(
1099     $Xn:ident => $xn:ident
1100   ),* $(,)?
1101 ) => (
1102   #[derive(Debug, Clone, Copy, PartialEq, Eq)]
1103   enum IntegerRepr {
1104     $($Xn),*
1105   }
1106 
1107   impl<'a> TryFrom<&'a str> for IntegerRepr {
1108     type Error = &'a str;
1109 
1110     fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
1111       match value {
1112         $(
1113           stringify!($xn) => Ok(Self::$Xn),
1114         )*
1115         _ => Err(value),
1116       }
1117     }
1118   }
1119 
1120   impl ToTokens for IntegerRepr {
1121     fn to_tokens(&self, tokens: &mut TokenStream) {
1122       match self {
1123         $(
1124           Self::$Xn => tokens.extend(quote!($xn)),
1125         )*
1126       }
1127     }
1128   }
1129 )}
1130 use mk_repr;
1131 
1132 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
1133 enum Repr {
1134   Rust,
1135   C,
1136   Transparent,
1137   Integer(IntegerRepr),
1138   CWithDiscriminant(IntegerRepr),
1139 }
1140 
1141 impl Repr {
is_integer(&self) -> bool1142   fn is_integer(&self) -> bool {
1143     matches!(self, Self::Integer(..))
1144   }
1145 
as_integer(&self) -> Option<IntegerRepr>1146   fn as_integer(&self) -> Option<IntegerRepr> {
1147     if let Self::Integer(v) = self {
1148       Some(*v)
1149     } else {
1150       None
1151     }
1152   }
1153 }
1154 
1155 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
1156 struct Representation {
1157   packed: Option<u32>,
1158   align: Option<u32>,
1159   repr: Repr,
1160 }
1161 
1162 impl Default for Representation {
default() -> Self1163   fn default() -> Self {
1164     Self { packed: None, align: None, repr: Repr::Rust }
1165   }
1166 }
1167 
1168 impl Parse for Representation {
parse(input: ParseStream<'_>) -> Result<Representation>1169   fn parse(input: ParseStream<'_>) -> Result<Representation> {
1170     let mut ret = Representation::default();
1171     while !input.is_empty() {
1172       let keyword = input.parse::<Ident>()?;
1173       // preëmptively call `.to_string()` *once* (rather than on `is_ident()`)
1174       let keyword_str = keyword.to_string();
1175       let new_repr = match keyword_str.as_str() {
1176         "C" => Repr::C,
1177         "transparent" => Repr::Transparent,
1178         "packed" => {
1179           ret.packed = Some(if input.peek(token::Paren) {
1180             let contents;
1181             parenthesized!(contents in input);
1182             LitInt::base10_parse::<u32>(&contents.parse()?)?
1183           } else {
1184             1
1185           });
1186           let _: Option<Token![,]> = input.parse()?;
1187           continue;
1188         }
1189         "align" => {
1190           let contents;
1191           parenthesized!(contents in input);
1192           let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
1193           ret.align = Some(
1194             ret
1195               .align
1196               .map_or(new_align, |old_align| cmp::max(old_align, new_align)),
1197           );
1198           let _: Option<Token![,]> = input.parse()?;
1199           continue;
1200         }
1201         ident => {
1202           let primitive = IntegerRepr::try_from(ident)
1203             .map_err(|_| input.error("unrecognized representation hint"))?;
1204           Repr::Integer(primitive)
1205         }
1206       };
1207       ret.repr = match (ret.repr, new_repr) {
1208         (Repr::Rust, new_repr) => {
1209           // This is the first explicit repr.
1210           new_repr
1211         }
1212         (Repr::C, Repr::Integer(integer))
1213         | (Repr::Integer(integer), Repr::C) => {
1214           // Both the C repr and an integer repr have been specified
1215           // -> merge into a C wit discriminant.
1216           Repr::CWithDiscriminant(integer)
1217         }
1218         (_, _) => {
1219           return Err(input.error("duplicate representation hint"));
1220         }
1221       };
1222       let _: Option<Token![,]> = input.parse()?;
1223     }
1224     Ok(ret)
1225   }
1226 }
1227 
1228 impl ToTokens for Representation {
to_tokens(&self, tokens: &mut TokenStream)1229   fn to_tokens(&self, tokens: &mut TokenStream) {
1230     let mut meta = Punctuated::<_, Token![,]>::new();
1231 
1232     match self.repr {
1233       Repr::Rust => {}
1234       Repr::C => meta.push(quote!(C)),
1235       Repr::Transparent => meta.push(quote!(transparent)),
1236       Repr::Integer(primitive) => meta.push(quote!(#primitive)),
1237       Repr::CWithDiscriminant(primitive) => {
1238         meta.push(quote!(C));
1239         meta.push(quote!(#primitive));
1240       }
1241     }
1242 
1243     if let Some(packed) = self.packed.as_ref() {
1244       let lit = LitInt::new(&packed.to_string(), Span::call_site());
1245       meta.push(quote!(packed(#lit)));
1246     }
1247 
1248     if let Some(align) = self.align.as_ref() {
1249       let lit = LitInt::new(&align.to_string(), Span::call_site());
1250       meta.push(quote!(align(#lit)));
1251     }
1252 
1253     tokens.extend(quote!(
1254       #[repr(#meta)]
1255     ));
1256   }
1257 }
1258 
enum_has_fields<'a>( mut variants: impl Iterator<Item = &'a Variant>, ) -> bool1259 fn enum_has_fields<'a>(
1260   mut variants: impl Iterator<Item = &'a Variant>,
1261 ) -> bool {
1262   variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
1263 }
1264 
1265 struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
1266   inner: I,
1267   last_value: i128,
1268 }
1269 
1270 impl<'a, I: Iterator<Item = &'a Variant> + 'a>
1271   VariantDiscriminantIterator<'a, I>
1272 {
new(inner: I) -> Self1273   fn new(inner: I) -> Self {
1274     VariantDiscriminantIterator { inner, last_value: -1 }
1275   }
1276 }
1277 
1278 impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
1279   for VariantDiscriminantIterator<'a, I>
1280 {
1281   type Item = Result<(i128, &'a Variant)>;
1282 
next(&mut self) -> Option<Self::Item>1283   fn next(&mut self) -> Option<Self::Item> {
1284     let variant = self.inner.next()?;
1285 
1286     if let Some((_, discriminant)) = &variant.discriminant {
1287       let discriminant_value = match parse_int_expr(discriminant) {
1288         Ok(value) => value,
1289         Err(e) => return Some(Err(e)),
1290       };
1291       self.last_value = discriminant_value;
1292     } else {
1293       // If this wraps, then either:
1294       // 1. the enum is using repr(u128), so wrapping is correct
1295       // 2. the enum is using repr(i<=128 or u<128), so the compiler will
1296       //    already emit a "wrapping discriminant" E0370 error.
1297       self.last_value = self.last_value.wrapping_add(1);
1298       // Static assert that there is no integer repr > 128 bits. If that
1299       // changes, the above comment is inaccurate and needs to be updated!
1300       // FIXME(zachs18): maybe should also do something to ensure `isize::BITS
1301       // <= 128`?
1302       if let Some(repr) = None::<IntegerRepr> {
1303         match repr {
1304           IntegerRepr::U8
1305           | IntegerRepr::I8
1306           | IntegerRepr::U16
1307           | IntegerRepr::I16
1308           | IntegerRepr::U32
1309           | IntegerRepr::I32
1310           | IntegerRepr::U64
1311           | IntegerRepr::I64
1312           | IntegerRepr::I128
1313           | IntegerRepr::U128
1314           | IntegerRepr::Usize
1315           | IntegerRepr::Isize => (),
1316         }
1317       }
1318     }
1319 
1320     Some(Ok((self.last_value, variant)))
1321   }
1322 }
1323 
parse_int_expr(expr: &Expr) -> Result<i128>1324 fn parse_int_expr(expr: &Expr) -> Result<i128> {
1325   match expr {
1326     Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
1327       parse_int_expr(expr).map(|int| -int)
1328     }
1329     Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
1330     Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()),
1331     _ => bail!("Not an integer expression"),
1332   }
1333 }
1334 
#[cfg(test)]
mod tests {
  use syn::{parse_quote, Attribute};

  use super::{get_repr, IntegerRepr, Repr, Representation};

  /// Runs `get_repr` over `attrs`, failing the test on a parse error.
  fn repr_of(attrs: &[Attribute]) -> Representation {
    get_repr(attrs).unwrap()
  }

  #[test]
  fn parse_basic_repr() {
    let cases: [(Attribute, Representation); 7] = [
      (
        parse_quote!(#[repr(C)]),
        Representation { repr: Repr::C, ..Default::default() },
      ),
      (
        parse_quote!(#[repr(transparent)]),
        Representation { repr: Repr::Transparent, ..Default::default() },
      ),
      (
        parse_quote!(#[repr(u8)]),
        Representation {
          repr: Repr::Integer(IntegerRepr::U8),
          ..Default::default()
        },
      ),
      (
        parse_quote!(#[repr(packed)]),
        Representation { packed: Some(1), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(packed(1))]),
        Representation { packed: Some(1), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(packed(2))]),
        Representation { packed: Some(2), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(align(2))]),
        Representation { align: Some(2), ..Default::default() },
      ),
    ];

    for (index, (attr, expected)) in cases.iter().enumerate() {
      assert_eq!(
        repr_of(std::slice::from_ref(attr)),
        *expected,
        "case #{index}"
      );
    }
  }

  #[test]
  fn parse_advanced_repr() {
    // Repeated `align` hints inside one attribute keep the largest.
    let single: Attribute = parse_quote!(#[repr(align(4), align(2))]);
    assert_eq!(
      repr_of(&[single]),
      Representation { align: Some(4), ..Default::default() }
    );

    // ...and so do `align` hints spread over several attributes.
    let several: [Attribute; 3] = [
      parse_quote!(#[repr(align(1))]),
      parse_quote!(#[repr(align(4))]),
      parse_quote!(#[repr(align(2))]),
    ];
    assert_eq!(
      repr_of(&several),
      Representation { align: Some(4), ..Default::default() }
    );

    // `C` plus an integer merges regardless of order.
    let c_with_u8 = Representation {
      repr: Repr::CWithDiscriminant(IntegerRepr::U8),
      ..Default::default()
    };
    let c_first: Attribute = parse_quote!(#[repr(C, u8)]);
    let int_first: Attribute = parse_quote!(#[repr(u8, C)]);
    assert_eq!(repr_of(&[c_first]), c_with_u8);
    assert_eq!(repr_of(&[int_first]), c_with_u8);
  }
}
1414 
bytemuck_crate_name(input: &DeriveInput) -> TokenStream1415 pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
1416   const ATTR_NAME: &'static str = "crate";
1417 
1418   let mut crate_name = quote!(::bytemuck);
1419   for attr in &input.attrs {
1420     if !attr.path().is_ident("bytemuck") {
1421       continue;
1422     }
1423 
1424     attr.parse_nested_meta(|meta| {
1425       if meta.path.is_ident(ATTR_NAME) {
1426         let expr: syn::Expr = meta.value()?.parse()?;
1427         let mut value = &expr;
1428         while let syn::Expr::Group(e) = value {
1429           value = &e.expr;
1430         }
1431         if let syn::Expr::Lit(syn::ExprLit {
1432           lit: syn::Lit::Str(lit), ..
1433         }) = value
1434         {
1435           let suffix = lit.suffix();
1436           if !suffix.is_empty() {
1437             bail!(format!("Unexpected suffix `{}` on string literal", suffix))
1438           }
1439           let path: syn::Path = match lit.parse() {
1440             Ok(path) => path,
1441             Err(_) => {
1442               bail!(format!("Failed to parse path: {:?}", lit.value()))
1443             }
1444           };
1445           crate_name = path.into_token_stream();
1446         } else {
1447           bail!(
1448             "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
1449           )
1450         }
1451       }
1452       Ok(())
1453     }).unwrap();
1454   }
1455 
1456   return crate_name;
1457 }
1458 
/// Text of the `#[doc = ...]` attribute attached to helper types generated by
/// the derives. The leading space matches what a `/// ...` comment produces.
const GENERATED_TYPE_DOCUMENTATION: &str = " `bytemuck`-generated type for internal purposes only.";
1461