• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Fuchsia Authors
2 //
3 // Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4 // <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6 // This file may not be copied, modified, or distributed except according to
7 // those terms.
8 
9 use proc_macro2::{Span, TokenStream};
10 use quote::quote;
11 use syn::{parse_quote, DataEnum, Error, Fields, Generics, Ident};
12 
13 use crate::{derive_try_from_bytes_inner, repr::EnumRepr, Trait};
14 
15 /// Generates a tag enum for the given enum. This generates an enum with the
16 /// same non-align `repr`s, variants, and corresponding discriminants, but none
17 /// of the fields.
generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream18 pub(crate) fn generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream {
19     let variants = data.variants.iter().map(|v| {
20         let ident = &v.ident;
21         if let Some((eq, discriminant)) = &v.discriminant {
22             quote! { #ident #eq #discriminant }
23         } else {
24             quote! { #ident }
25         }
26     });
27 
28     // Don't include any `repr(align)` when generating the tag enum, as that
29     // could add padding after the tag but before any variants, which is not the
30     // correct behavior.
31     let repr = match repr {
32         EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
33         EnumRepr::Compound(c, _) => quote! { #c },
34     };
35 
36     quote! {
37         #repr
38         #[allow(dead_code, non_camel_case_types)]
39         enum ___ZerocopyTag {
40             #(#variants,)*
41         }
42     }
43 }
44 
tag_ident(variant_ident: &Ident) -> Ident45 fn tag_ident(variant_ident: &Ident) -> Ident {
46     Ident::new(&format!("___ZEROCOPY_TAG_{}", variant_ident), variant_ident.span())
47 }
48 
49 /// Generates a constant for the tag associated with each variant of the enum.
50 /// When we match on the enum's tag, each arm matches one of these constants. We
51 /// have to use constants here because:
52 ///
53 /// - The type that we're matching on is not the type of the tag, it's an
54 ///   integer of the same size as the tag type and with the same bit patterns.
55 /// - We can't read the enum tag as an enum because the bytes may not represent
56 ///   a valid variant.
57 /// - Patterns do not currently support const expressions, so we have to assign
58 ///   these constants to names rather than use them inline in the `match`
59 ///   statement.
generate_tag_consts(data: &DataEnum) -> TokenStream60 fn generate_tag_consts(data: &DataEnum) -> TokenStream {
61     let tags = data.variants.iter().map(|v| {
62         let variant_ident = &v.ident;
63         let tag_ident = tag_ident(variant_ident);
64 
65         quote! {
66             // This casts the enum variant to its discriminant, and then
67             // converts the discriminant to the target integral type via a
68             // numeric cast [1].
69             //
70             // Because these are the same size, this is defined to be a no-op
71             // and therefore is a lossless conversion [2].
72             //
73             // [1]: https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#enum-cast
74             // [2]: https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#numeric-cast
75             #[allow(non_upper_case_globals)]
76             const #tag_ident: ___ZerocopyTagPrimitive =
77                 ___ZerocopyTag::#variant_ident as ___ZerocopyTagPrimitive;
78         }
79     });
80 
81     quote! {
82         #(#tags)*
83     }
84 }
85 
variant_struct_ident(variant_ident: &Ident) -> Ident86 fn variant_struct_ident(variant_ident: &Ident) -> Ident {
87     Ident::new(&format!("___ZerocopyVariantStruct_{}", variant_ident), variant_ident.span())
88 }
89 
90 /// Generates variant structs for the given enum variant.
91 ///
92 /// These are structs associated with each variant of an enum. They are
93 /// `repr(C)` tuple structs with the same fields as the variant after a
94 /// `MaybeUninit<___ZerocopyInnerTag>`.
95 ///
96 /// In order to unify the generated types for `repr(C)` and `repr(int)` enums,
97 /// we use a "fused" representation with fields for both an inner tag and an
98 /// outer tag. Depending on the repr, we will set one of these tags to the tag
99 /// type and the other to `()`. This lets us generate the same code but put the
100 /// tags in different locations.
generate_variant_structs( enum_name: &Ident, generics: &Generics, data: &DataEnum, ) -> TokenStream101 fn generate_variant_structs(
102     enum_name: &Ident,
103     generics: &Generics,
104     data: &DataEnum,
105 ) -> TokenStream {
106     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
107 
108     // All variant structs have a `PhantomData<MyEnum<...>>` field because we
109     // don't know which generic parameters each variant will use, and unused
110     // generic parameters are a compile error.
111     let phantom_ty = quote! {
112         core_reexport::marker::PhantomData<#enum_name #ty_generics>
113     };
114 
115     let variant_structs = data.variants.iter().filter_map(|variant| {
116         // We don't generate variant structs for unit variants because we only
117         // need to check the tag. This helps cut down our generated code a bit.
118         if matches!(variant.fields, Fields::Unit) {
119             return None;
120         }
121 
122         let variant_struct_ident = variant_struct_ident(&variant.ident);
123         let field_types = variant.fields.iter().map(|f| &f.ty);
124 
125         let variant_struct = parse_quote! {
126             #[repr(C)]
127             #[allow(non_snake_case)]
128             struct #variant_struct_ident #impl_generics (
129                 core_reexport::mem::MaybeUninit<___ZerocopyInnerTag>,
130                 #(#field_types,)*
131                 #phantom_ty,
132             ) #where_clause;
133         };
134 
135         // We do this rather than emitting `#[derive(::zerocopy::TryFromBytes)]`
136         // because that is not hygienic, and this is also more performant.
137         let try_from_bytes_impl = derive_try_from_bytes_inner(&variant_struct, Trait::TryFromBytes)
138             .expect("derive_try_from_bytes_inner should not fail on synthesized type");
139 
140         Some(quote! {
141             #variant_struct
142             #try_from_bytes_impl
143         })
144     });
145 
146     quote! {
147         #(#variant_structs)*
148     }
149 }
150 
generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream151 fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream {
152     let (_, ty_generics, _) = generics.split_for_impl();
153 
154     let fields = data.variants.iter().filter_map(|variant| {
155         // We don't generate variant structs for unit variants because we only
156         // need to check the tag. This helps cut down our generated code a bit.
157         if matches!(variant.fields, Fields::Unit) {
158             return None;
159         }
160 
161         // Field names are prefixed with `__field_` to prevent name collision with
162         // the `__nonempty` field.
163         let field_name = Ident::new(&format!("__field_{}", &variant.ident), variant.ident.span());
164         let variant_struct_ident = variant_struct_ident(&variant.ident);
165 
166         Some(quote! {
167             #field_name: core_reexport::mem::ManuallyDrop<
168                 #variant_struct_ident #ty_generics
169             >,
170         })
171     });
172 
173     quote! {
174         #[repr(C)]
175         #[allow(non_snake_case)]
176         union ___ZerocopyVariants #generics {
177             #(#fields)*
178             // Enums can have variants with no fields, but unions must
179             // have at least one field. So we just add a trailing unit
180             // to ensure that this union always has at least one field.
181             // Because this union is `repr(C)`, this unit type does not
182             // affect the layout.
183             __nonempty: (),
184         }
185     }
186 }
187 
188 /// Generates an implementation of `is_bit_valid` for an arbitrary enum.
189 ///
190 /// The general process is:
191 ///
192 /// 1. Generate a tag enum. This is an enum with the same repr, variants, and
193 ///    corresponding discriminants as the original enum, but without any fields
194 ///    on the variants. This gives us access to an enum where the variants have
195 ///    the same discriminants as the one we're writing `is_bit_valid` for.
196 /// 2. Make constants from the variants of the tag enum. We need these because
197 ///    we can't put const exprs in match arms.
198 /// 3. Generate variant structs. These are structs which have the same fields as
199 ///    each variant of the enum, and are `#[repr(C)]` with an optional "inner
200 ///    tag".
201 /// 4. Generate a variants union, with one field for each variant struct type.
202 /// 5. And finally, our raw enum is a `#[repr(C)]` struct of an "outer tag" and
203 ///    the variants union.
204 ///
205 /// See these reference links for fully-worked example decompositions.
206 ///
207 /// - `repr(C)`: <https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields>
208 /// - `repr(int)`: <https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields>
209 /// - `repr(C, int)`: <https://doc.rust-lang.org/reference/type-layout.html#combining-primitive-representations-of-enums-with-fields-and-reprc>
derive_is_bit_valid( enum_ident: &Ident, repr: &EnumRepr, generics: &Generics, data: &DataEnum, ) -> Result<TokenStream, Error>210 pub(crate) fn derive_is_bit_valid(
211     enum_ident: &Ident,
212     repr: &EnumRepr,
213     generics: &Generics,
214     data: &DataEnum,
215 ) -> Result<TokenStream, Error> {
216     let trait_path = Trait::TryFromBytes.crate_path();
217     let tag_enum = generate_tag_enum(repr, data);
218     let tag_consts = generate_tag_consts(data);
219 
220     let (outer_tag_type, inner_tag_type) = if repr.is_c() {
221         (quote! { ___ZerocopyTag }, quote! { () })
222     } else if repr.is_primitive() {
223         (quote! { () }, quote! { ___ZerocopyTag })
224     } else {
225         return Err(Error::new(
226             Span::call_site(),
227             "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
228         ));
229     };
230 
231     let variant_structs = generate_variant_structs(enum_ident, generics, data);
232     let variants_union = generate_variants_union(generics, data);
233 
234     let (_, ty_generics, _) = generics.split_for_impl();
235 
236     let match_arms = data.variants.iter().map(|variant| {
237         let tag_ident = tag_ident(&variant.ident);
238         let variant_struct_ident = variant_struct_ident(&variant.ident);
239 
240         if matches!(variant.fields, Fields::Unit) {
241             // Unit variants don't need any further validation beyond checking
242             // the tag.
243             quote! {
244                 #tag_ident => true
245             }
246         } else {
247             quote! {
248                 #tag_ident => {
249                     // SAFETY:
250                     // - This cast is from a `repr(C)` union which has a field
251                     //   of type `variant_struct_ident` to that variant struct
252                     //   type itself. This addresses a subset of the bytes
253                     //   addressed by `variants`.
254                     // - The returned pointer is cast from `p`, and so has the
255                     //   same provenance as `p`.
256                     // - We checked that the tag of the enum matched the
257                     //   constant for this variant, so this cast preserves
258                     //   types and locations of all fields. Therefore, any
259                     //   `UnsafeCell`s will have the same location as in the
260                     //   original type.
261                     let variant = unsafe {
262                         variants.cast_unsized_unchecked(
263                             |p: *mut ___ZerocopyVariants #ty_generics| {
264                                 p as *mut #variant_struct_ident #ty_generics
265                             }
266                         )
267                     };
268                     // SAFETY: `cast_unsized_unchecked` removes the
269                     // initialization invariant from `p`, so we re-assert that
270                     // all of the bytes are initialized.
271                     let variant = unsafe { variant.assume_initialized() };
272                     <
273                         #variant_struct_ident #ty_generics as #trait_path
274                     >::is_bit_valid(variant)
275                 }
276             }
277         }
278     });
279 
280     Ok(quote! {
281         // SAFETY: We use `is_bit_valid` to validate that the bit pattern of the
282         // enum's tag corresponds to one of the enum's discriminants. Then, we
283         // check the bit validity of each field of the corresponding variant.
284         // Thus, this is a sound implementation of `is_bit_valid`.
285         fn is_bit_valid<___ZerocopyAliasing>(
286             mut candidate: ::zerocopy::Maybe<'_, Self, ___ZerocopyAliasing>,
287         ) -> ::zerocopy::util::macro_util::core_reexport::primitive::bool
288         where
289             ___ZerocopyAliasing: ::zerocopy::pointer::invariant::Reference,
290         {
291             use ::zerocopy::util::macro_util::core_reexport;
292 
293             #tag_enum
294 
295             type ___ZerocopyTagPrimitive = ::zerocopy::util::macro_util::SizeToTag<
296                 { core_reexport::mem::size_of::<___ZerocopyTag>() },
297             >;
298 
299             #tag_consts
300 
301             type ___ZerocopyOuterTag = #outer_tag_type;
302             type ___ZerocopyInnerTag = #inner_tag_type;
303 
304             #variant_structs
305 
306             #variants_union
307 
308             #[repr(C)]
309             struct ___ZerocopyRawEnum #generics {
310                 tag: ___ZerocopyOuterTag,
311                 variants: ___ZerocopyVariants #ty_generics,
312             }
313 
314             let tag = {
315                 // SAFETY:
316                 // - The provided cast addresses a subset of the bytes addressed
317                 //   by `candidate` because it addresses the starting tag of the
318                 //   enum.
319                 // - Because the pointer is cast from `candidate`, it has the
320                 //   same provenance as it.
321                 // - There are no `UnsafeCell`s in the tag because it is a
322                 //   primitive integer.
323                 let tag_ptr = unsafe {
324                     candidate.reborrow().cast_unsized_unchecked(|p: *mut Self| {
325                         p as *mut ___ZerocopyTagPrimitive
326                     })
327                 };
328                 // SAFETY: `tag_ptr` is casted from `candidate`, whose referent
329                 // is `Initialized`. Since we have not written uninitialized
330                 // bytes into the referent, `tag_ptr` is also `Initialized`.
331                 let tag_ptr = unsafe { tag_ptr.assume_initialized() };
332                 tag_ptr.bikeshed_recall_valid().read_unaligned::<::zerocopy::BecauseImmutable>()
333             };
334 
335             // SAFETY:
336             // - The raw enum has the same fields in the same locations as the
337             //   input enum, and may have a lower alignment. This guarantees
338             //   that it addresses a subset of the bytes addressed by
339             //   `candidate`.
340             // - The returned pointer is cast from `p`, and so has the same
341             //   provenance as `p`.
342             // - The raw enum has the same types at the same locations as the
343             //   original enum, and so preserves the locations of any
344             //   `UnsafeCell`s.
345             let raw_enum = unsafe {
346                 candidate.cast_unsized_unchecked(|p: *mut Self| {
347                     p as *mut ___ZerocopyRawEnum #ty_generics
348                 })
349             };
350             // SAFETY: `cast_unsized_unchecked` removes the initialization
351             // invariant from `p`, so we re-assert that all of the bytes are
352             // initialized.
353             let raw_enum = unsafe { raw_enum.assume_initialized() };
354             // SAFETY:
355             // - This projection returns a subfield of `this` using
356             //   `addr_of_mut!`.
357             // - Because the subfield pointer is derived from `this`, it has the
358             //   same provenance.
359             // - The locations of `UnsafeCell`s in the subfield match the
360             //   locations of `UnsafeCell`s in `this`. This is because the
361             //   subfield pointer just points to a smaller portion of the
362             //   overall struct.
363             let variants = unsafe {
364                 raw_enum.cast_unsized_unchecked(|p: *mut ___ZerocopyRawEnum #ty_generics| {
365                     core_reexport::ptr::addr_of_mut!((*p).variants)
366                 })
367             };
368 
369             #[allow(non_upper_case_globals)]
370             match tag {
371                 #(#match_arms,)*
372                 _ => false,
373             }
374         }
375     })
376 }
377