1 // Copyright 2024 The Fuchsia Authors
2 //
3 // Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4 // <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6 // This file may not be copied, modified, or distributed except according to
7 // those terms.
8
9 use proc_macro2::{Span, TokenStream};
10 use quote::quote;
11 use syn::{parse_quote, DataEnum, Error, Fields, Generics, Ident};
12
13 use crate::{derive_try_from_bytes_inner, repr::EnumRepr, Trait};
14
15 /// Generates a tag enum for the given enum. This generates an enum with the
16 /// same non-align `repr`s, variants, and corresponding discriminants, but none
17 /// of the fields.
generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream18 pub(crate) fn generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream {
19 let variants = data.variants.iter().map(|v| {
20 let ident = &v.ident;
21 if let Some((eq, discriminant)) = &v.discriminant {
22 quote! { #ident #eq #discriminant }
23 } else {
24 quote! { #ident }
25 }
26 });
27
28 // Don't include any `repr(align)` when generating the tag enum, as that
29 // could add padding after the tag but before any variants, which is not the
30 // correct behavior.
31 let repr = match repr {
32 EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
33 EnumRepr::Compound(c, _) => quote! { #c },
34 };
35
36 quote! {
37 #repr
38 #[allow(dead_code, non_camel_case_types)]
39 enum ___ZerocopyTag {
40 #(#variants,)*
41 }
42 }
43 }
44
tag_ident(variant_ident: &Ident) -> Ident45 fn tag_ident(variant_ident: &Ident) -> Ident {
46 Ident::new(&format!("___ZEROCOPY_TAG_{}", variant_ident), variant_ident.span())
47 }
48
49 /// Generates a constant for the tag associated with each variant of the enum.
50 /// When we match on the enum's tag, each arm matches one of these constants. We
51 /// have to use constants here because:
52 ///
53 /// - The type that we're matching on is not the type of the tag, it's an
54 /// integer of the same size as the tag type and with the same bit patterns.
55 /// - We can't read the enum tag as an enum because the bytes may not represent
56 /// a valid variant.
57 /// - Patterns do not currently support const expressions, so we have to assign
58 /// these constants to names rather than use them inline in the `match`
59 /// statement.
generate_tag_consts(data: &DataEnum) -> TokenStream60 fn generate_tag_consts(data: &DataEnum) -> TokenStream {
61 let tags = data.variants.iter().map(|v| {
62 let variant_ident = &v.ident;
63 let tag_ident = tag_ident(variant_ident);
64
65 quote! {
66 // This casts the enum variant to its discriminant, and then
67 // converts the discriminant to the target integral type via a
68 // numeric cast [1].
69 //
70 // Because these are the same size, this is defined to be a no-op
71 // and therefore is a lossless conversion [2].
72 //
73 // [1]: https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#enum-cast
74 // [2]: https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#numeric-cast
75 #[allow(non_upper_case_globals)]
76 const #tag_ident: ___ZerocopyTagPrimitive =
77 ___ZerocopyTag::#variant_ident as ___ZerocopyTagPrimitive;
78 }
79 });
80
81 quote! {
82 #(#tags)*
83 }
84 }
85
variant_struct_ident(variant_ident: &Ident) -> Ident86 fn variant_struct_ident(variant_ident: &Ident) -> Ident {
87 Ident::new(&format!("___ZerocopyVariantStruct_{}", variant_ident), variant_ident.span())
88 }
89
90 /// Generates variant structs for the given enum variant.
91 ///
92 /// These are structs associated with each variant of an enum. They are
93 /// `repr(C)` tuple structs with the same fields as the variant after a
94 /// `MaybeUninit<___ZerocopyInnerTag>`.
95 ///
96 /// In order to unify the generated types for `repr(C)` and `repr(int)` enums,
97 /// we use a "fused" representation with fields for both an inner tag and an
98 /// outer tag. Depending on the repr, we will set one of these tags to the tag
99 /// type and the other to `()`. This lets us generate the same code but put the
100 /// tags in different locations.
generate_variant_structs( enum_name: &Ident, generics: &Generics, data: &DataEnum, ) -> TokenStream101 fn generate_variant_structs(
102 enum_name: &Ident,
103 generics: &Generics,
104 data: &DataEnum,
105 ) -> TokenStream {
106 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
107
108 // All variant structs have a `PhantomData<MyEnum<...>>` field because we
109 // don't know which generic parameters each variant will use, and unused
110 // generic parameters are a compile error.
111 let phantom_ty = quote! {
112 core_reexport::marker::PhantomData<#enum_name #ty_generics>
113 };
114
115 let variant_structs = data.variants.iter().filter_map(|variant| {
116 // We don't generate variant structs for unit variants because we only
117 // need to check the tag. This helps cut down our generated code a bit.
118 if matches!(variant.fields, Fields::Unit) {
119 return None;
120 }
121
122 let variant_struct_ident = variant_struct_ident(&variant.ident);
123 let field_types = variant.fields.iter().map(|f| &f.ty);
124
125 let variant_struct = parse_quote! {
126 #[repr(C)]
127 #[allow(non_snake_case)]
128 struct #variant_struct_ident #impl_generics (
129 core_reexport::mem::MaybeUninit<___ZerocopyInnerTag>,
130 #(#field_types,)*
131 #phantom_ty,
132 ) #where_clause;
133 };
134
135 // We do this rather than emitting `#[derive(::zerocopy::TryFromBytes)]`
136 // because that is not hygienic, and this is also more performant.
137 let try_from_bytes_impl = derive_try_from_bytes_inner(&variant_struct, Trait::TryFromBytes)
138 .expect("derive_try_from_bytes_inner should not fail on synthesized type");
139
140 Some(quote! {
141 #variant_struct
142 #try_from_bytes_impl
143 })
144 });
145
146 quote! {
147 #(#variant_structs)*
148 }
149 }
150
generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream151 fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream {
152 let (_, ty_generics, _) = generics.split_for_impl();
153
154 let fields = data.variants.iter().filter_map(|variant| {
155 // We don't generate variant structs for unit variants because we only
156 // need to check the tag. This helps cut down our generated code a bit.
157 if matches!(variant.fields, Fields::Unit) {
158 return None;
159 }
160
161 // Field names are prefixed with `__field_` to prevent name collision with
162 // the `__nonempty` field.
163 let field_name = Ident::new(&format!("__field_{}", &variant.ident), variant.ident.span());
164 let variant_struct_ident = variant_struct_ident(&variant.ident);
165
166 Some(quote! {
167 #field_name: core_reexport::mem::ManuallyDrop<
168 #variant_struct_ident #ty_generics
169 >,
170 })
171 });
172
173 quote! {
174 #[repr(C)]
175 #[allow(non_snake_case)]
176 union ___ZerocopyVariants #generics {
177 #(#fields)*
178 // Enums can have variants with no fields, but unions must
179 // have at least one field. So we just add a trailing unit
180 // to ensure that this union always has at least one field.
181 // Because this union is `repr(C)`, this unit type does not
182 // affect the layout.
183 __nonempty: (),
184 }
185 }
186 }
187
188 /// Generates an implementation of `is_bit_valid` for an arbitrary enum.
189 ///
190 /// The general process is:
191 ///
192 /// 1. Generate a tag enum. This is an enum with the same repr, variants, and
193 /// corresponding discriminants as the original enum, but without any fields
194 /// on the variants. This gives us access to an enum where the variants have
195 /// the same discriminants as the one we're writing `is_bit_valid` for.
196 /// 2. Make constants from the variants of the tag enum. We need these because
197 /// we can't put const exprs in match arms.
198 /// 3. Generate variant structs. These are structs which have the same fields as
199 /// each variant of the enum, and are `#[repr(C)]` with an optional "inner
200 /// tag".
201 /// 4. Generate a variants union, with one field for each variant struct type.
202 /// 5. And finally, our raw enum is a `#[repr(C)]` struct of an "outer tag" and
203 /// the variants union.
204 ///
205 /// See these reference links for fully-worked example decompositions.
206 ///
207 /// - `repr(C)`: <https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields>
208 /// - `repr(int)`: <https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields>
209 /// - `repr(C, int)`: <https://doc.rust-lang.org/reference/type-layout.html#combining-primitive-representations-of-enums-with-fields-and-reprc>
derive_is_bit_valid( enum_ident: &Ident, repr: &EnumRepr, generics: &Generics, data: &DataEnum, ) -> Result<TokenStream, Error>210 pub(crate) fn derive_is_bit_valid(
211 enum_ident: &Ident,
212 repr: &EnumRepr,
213 generics: &Generics,
214 data: &DataEnum,
215 ) -> Result<TokenStream, Error> {
216 let trait_path = Trait::TryFromBytes.crate_path();
217 let tag_enum = generate_tag_enum(repr, data);
218 let tag_consts = generate_tag_consts(data);
219
220 let (outer_tag_type, inner_tag_type) = if repr.is_c() {
221 (quote! { ___ZerocopyTag }, quote! { () })
222 } else if repr.is_primitive() {
223 (quote! { () }, quote! { ___ZerocopyTag })
224 } else {
225 return Err(Error::new(
226 Span::call_site(),
227 "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
228 ));
229 };
230
231 let variant_structs = generate_variant_structs(enum_ident, generics, data);
232 let variants_union = generate_variants_union(generics, data);
233
234 let (_, ty_generics, _) = generics.split_for_impl();
235
236 let match_arms = data.variants.iter().map(|variant| {
237 let tag_ident = tag_ident(&variant.ident);
238 let variant_struct_ident = variant_struct_ident(&variant.ident);
239
240 if matches!(variant.fields, Fields::Unit) {
241 // Unit variants don't need any further validation beyond checking
242 // the tag.
243 quote! {
244 #tag_ident => true
245 }
246 } else {
247 quote! {
248 #tag_ident => {
249 // SAFETY:
250 // - This cast is from a `repr(C)` union which has a field
251 // of type `variant_struct_ident` to that variant struct
252 // type itself. This addresses a subset of the bytes
253 // addressed by `variants`.
254 // - The returned pointer is cast from `p`, and so has the
255 // same provenance as `p`.
256 // - We checked that the tag of the enum matched the
257 // constant for this variant, so this cast preserves
258 // types and locations of all fields. Therefore, any
259 // `UnsafeCell`s will have the same location as in the
260 // original type.
261 let variant = unsafe {
262 variants.cast_unsized_unchecked(
263 |p: *mut ___ZerocopyVariants #ty_generics| {
264 p as *mut #variant_struct_ident #ty_generics
265 }
266 )
267 };
268 // SAFETY: `cast_unsized_unchecked` removes the
269 // initialization invariant from `p`, so we re-assert that
270 // all of the bytes are initialized.
271 let variant = unsafe { variant.assume_initialized() };
272 <
273 #variant_struct_ident #ty_generics as #trait_path
274 >::is_bit_valid(variant)
275 }
276 }
277 }
278 });
279
280 Ok(quote! {
281 // SAFETY: We use `is_bit_valid` to validate that the bit pattern of the
282 // enum's tag corresponds to one of the enum's discriminants. Then, we
283 // check the bit validity of each field of the corresponding variant.
284 // Thus, this is a sound implementation of `is_bit_valid`.
285 fn is_bit_valid<___ZerocopyAliasing>(
286 mut candidate: ::zerocopy::Maybe<'_, Self, ___ZerocopyAliasing>,
287 ) -> ::zerocopy::util::macro_util::core_reexport::primitive::bool
288 where
289 ___ZerocopyAliasing: ::zerocopy::pointer::invariant::Reference,
290 {
291 use ::zerocopy::util::macro_util::core_reexport;
292
293 #tag_enum
294
295 type ___ZerocopyTagPrimitive = ::zerocopy::util::macro_util::SizeToTag<
296 { core_reexport::mem::size_of::<___ZerocopyTag>() },
297 >;
298
299 #tag_consts
300
301 type ___ZerocopyOuterTag = #outer_tag_type;
302 type ___ZerocopyInnerTag = #inner_tag_type;
303
304 #variant_structs
305
306 #variants_union
307
308 #[repr(C)]
309 struct ___ZerocopyRawEnum #generics {
310 tag: ___ZerocopyOuterTag,
311 variants: ___ZerocopyVariants #ty_generics,
312 }
313
314 let tag = {
315 // SAFETY:
316 // - The provided cast addresses a subset of the bytes addressed
317 // by `candidate` because it addresses the starting tag of the
318 // enum.
319 // - Because the pointer is cast from `candidate`, it has the
320 // same provenance as it.
321 // - There are no `UnsafeCell`s in the tag because it is a
322 // primitive integer.
323 let tag_ptr = unsafe {
324 candidate.reborrow().cast_unsized_unchecked(|p: *mut Self| {
325 p as *mut ___ZerocopyTagPrimitive
326 })
327 };
328 // SAFETY: `tag_ptr` is casted from `candidate`, whose referent
329 // is `Initialized`. Since we have not written uninitialized
330 // bytes into the referent, `tag_ptr` is also `Initialized`.
331 let tag_ptr = unsafe { tag_ptr.assume_initialized() };
332 tag_ptr.bikeshed_recall_valid().read_unaligned::<::zerocopy::BecauseImmutable>()
333 };
334
335 // SAFETY:
336 // - The raw enum has the same fields in the same locations as the
337 // input enum, and may have a lower alignment. This guarantees
338 // that it addresses a subset of the bytes addressed by
339 // `candidate`.
340 // - The returned pointer is cast from `p`, and so has the same
341 // provenance as `p`.
342 // - The raw enum has the same types at the same locations as the
343 // original enum, and so preserves the locations of any
344 // `UnsafeCell`s.
345 let raw_enum = unsafe {
346 candidate.cast_unsized_unchecked(|p: *mut Self| {
347 p as *mut ___ZerocopyRawEnum #ty_generics
348 })
349 };
350 // SAFETY: `cast_unsized_unchecked` removes the initialization
351 // invariant from `p`, so we re-assert that all of the bytes are
352 // initialized.
353 let raw_enum = unsafe { raw_enum.assume_initialized() };
354 // SAFETY:
355 // - This projection returns a subfield of `this` using
356 // `addr_of_mut!`.
357 // - Because the subfield pointer is derived from `this`, it has the
358 // same provenance.
359 // - The locations of `UnsafeCell`s in the subfield match the
360 // locations of `UnsafeCell`s in `this`. This is because the
361 // subfield pointer just points to a smaller portion of the
362 // overall struct.
363 let variants = unsafe {
364 raw_enum.cast_unsized_unchecked(|p: *mut ___ZerocopyRawEnum #ty_generics| {
365 core_reexport::ptr::addr_of_mut!((*p).variants)
366 })
367 };
368
369 #[allow(non_upper_case_globals)]
370 match tag {
371 #(#match_arms,)*
372 _ => false,
373 }
374 }
375 })
376 }
377