• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits.
2 
3 extern crate proc_macro;
4 
5 mod traits;
6 
7 use proc_macro2::TokenStream;
8 use quote::quote;
9 use syn::{parse_macro_input, DeriveInput, Result};
10 
11 use crate::traits::{
12   AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, NoUninit, Pod,
13   TransparentWrapper, Zeroable,
14 };
15 
16 /// Derive the `Pod` trait for a struct
17 ///
18 /// The macro ensures that the struct follows all the the safety requirements
19 /// for the `Pod` trait.
20 ///
21 /// The following constraints need to be satisfied for the macro to succeed
22 ///
23 /// - All fields in the struct must implement `Pod`
24 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
25 /// - The struct must not contain any padding bytes
26 /// - The struct contains no generic parameters, if it is not
27 ///   `#[repr(transparent)]`
28 ///
29 /// ## Examples
30 ///
31 /// ```rust
32 /// # use std::marker::PhantomData;
33 /// # use bytemuck_derive::{Pod, Zeroable};
34 /// #[derive(Copy, Clone, Pod, Zeroable)]
35 /// #[repr(C)]
36 /// struct Test {
37 ///   a: u16,
38 ///   b: u16,
39 /// }
40 ///
41 /// #[derive(Copy, Clone, Pod, Zeroable)]
42 /// #[repr(transparent)]
43 /// struct Generic<A, B> {
44 ///   a: A,
45 ///   b: PhantomData<B>,
46 /// }
47 /// ```
48 ///
49 /// If the struct is generic, it must be `#[repr(transparent)]` also.
50 ///
51 /// ```compile_fail
52 /// # use bytemuck::{Pod, Zeroable};
53 /// # use std::marker::PhantomData;
54 /// #[derive(Copy, Clone, Pod, Zeroable)]
55 /// #[repr(C)] // must be `#[repr(transparent)]`
56 /// struct Generic<A> {
57 ///   a: A,
58 /// }
59 /// ```
60 ///
61 /// If the struct is generic and `#[repr(transparent)]`, then it is only `Pod`
62 /// when all of its generics are `Pod`, not just its fields.
63 ///
64 /// ```
65 /// # use bytemuck::{Pod, Zeroable};
66 /// # use std::marker::PhantomData;
67 /// #[derive(Copy, Clone, Pod, Zeroable)]
68 /// #[repr(transparent)]
69 /// struct Generic<A, B> {
70 ///   a: A,
71 ///   b: PhantomData<B>,
72 /// }
73 ///
74 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<u32> });
75 /// ```
76 ///
77 /// ```compile_fail
78 /// # use bytemuck::{Pod, Zeroable};
79 /// # use std::marker::PhantomData;
80 /// # #[derive(Copy, Clone, Pod, Zeroable)]
81 /// # #[repr(transparent)]
82 /// # struct Generic<A, B> {
83 /// #   a: A,
84 /// #   b: PhantomData<B>,
85 /// # }
86 /// struct NotPod;
87 ///
88 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<NotPod> });
89 /// ```
90 #[proc_macro_derive(Pod)]
derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream91 pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92   let expanded =
93     derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
94 
95   proc_macro::TokenStream::from(expanded)
96 }
97 
98 /// Derive the `AnyBitPattern` trait for a struct
99 ///
100 /// The macro ensures that the struct follows all the the safety requirements
101 /// for the `AnyBitPattern` trait.
102 ///
103 /// The following constraints need to be satisfied for the macro to succeed
104 ///
105 /// - All fields in the struct must to implement `AnyBitPattern`
106 #[proc_macro_derive(AnyBitPattern)]
derive_anybitpattern( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream107 pub fn derive_anybitpattern(
108   input: proc_macro::TokenStream,
109 ) -> proc_macro::TokenStream {
110   let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
111     input as DeriveInput
112   ));
113 
114   proc_macro::TokenStream::from(expanded)
115 }
116 
117 /// Derive the `Zeroable` trait for a struct
118 ///
119 /// The macro ensures that the struct follows all the the safety requirements
120 /// for the `Zeroable` trait.
121 ///
122 /// The following constraints need to be satisfied for the macro to succeed
123 ///
124 /// - All fields in the struct must to implement `Zeroable`
125 ///
126 /// ## Example
127 ///
128 /// ```rust
129 /// # use bytemuck_derive::{Zeroable};
130 /// #[derive(Copy, Clone, Zeroable)]
131 /// #[repr(C)]
132 /// struct Test {
133 ///   a: u16,
134 ///   b: u16,
135 /// }
136 /// ```
137 ///
138 /// # Custom bounds
139 ///
140 /// Custom bounds for the derived `Zeroable` impl can be given using the
141 /// `#[zeroable(bound = "")]` helper attribute.
142 ///
143 /// Using this attribute additionally opts-in to "perfect derive" semantics,
144 /// where instead of adding bounds for each generic type parameter, bounds are
145 /// added for each field's type.
146 ///
147 /// ## Examples
148 ///
149 /// ```rust
150 /// # use bytemuck::Zeroable;
151 /// # use std::marker::PhantomData;
152 /// #[derive(Clone, Zeroable)]
153 /// #[zeroable(bound = "")]
154 /// struct AlwaysZeroable<T> {
155 ///   a: PhantomData<T>,
156 /// }
157 ///
158 /// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
159 /// ```
160 ///
161 /// ```rust,compile_fail
162 /// # use bytemuck::Zeroable;
163 /// # use std::marker::PhantomData;
164 /// #[derive(Clone, Zeroable)]
165 /// #[zeroable(bound = "T: Copy")]
166 /// struct ZeroableWhenTIsCopy<T> {
167 ///   a: PhantomData<T>,
168 /// }
169 ///
170 /// ZeroableWhenTIsCopy::<String>::zeroed();
171 /// ```
172 ///
173 /// The restriction that all fields must be Zeroable is still applied, and this
174 /// is enforced using the mentioned "perfect derive" semantics.
175 ///
176 /// ```rust
177 /// # use bytemuck::Zeroable;
178 /// #[derive(Clone, Zeroable)]
179 /// #[zeroable(bound = "")]
180 /// struct ZeroableWhenTIsZeroable<T> {
181 ///   a: T,
182 /// }
183 /// ZeroableWhenTIsZeroable::<u32>::zeroed();
184 /// ```
185 ///
186 /// ```rust,compile_fail
187 /// # use bytemuck::Zeroable;
188 /// # #[derive(Clone, Zeroable)]
189 /// # #[zeroable(bound = "")]
190 /// # struct ZeroableWhenTIsZeroable<T> {
191 /// #   a: T,
192 /// # }
193 /// ZeroableWhenTIsZeroable::<String>::zeroed();
194 /// ```
195 #[proc_macro_derive(Zeroable, attributes(zeroable))]
derive_zeroable( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream196 pub fn derive_zeroable(
197   input: proc_macro::TokenStream,
198 ) -> proc_macro::TokenStream {
199   let expanded =
200     derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
201 
202   proc_macro::TokenStream::from(expanded)
203 }
204 
205 /// Derive the `NoUninit` trait for a struct or enum
206 ///
207 /// The macro ensures that the type follows all the the safety requirements
208 /// for the `NoUninit` trait.
209 ///
210 /// The following constraints need to be satisfied for the macro to succeed
211 /// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
212 /// bounds, i.e. the type must be `Sized + Copy + 'static`):
213 ///
214 /// If applied to a struct:
215 /// - All fields in the struct must implement `NoUninit`
216 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
217 /// - The struct must not contain any padding bytes
218 /// - The struct must contain no generic parameters
219 ///
220 /// If applied to an enum:
221 /// - The enum must be explicit `#[repr(Int)]`
222 /// - All variants must be fieldless
223 /// - The enum must contain no generic parameters
224 #[proc_macro_derive(NoUninit)]
derive_no_uninit( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream225 pub fn derive_no_uninit(
226   input: proc_macro::TokenStream,
227 ) -> proc_macro::TokenStream {
228   let expanded =
229     derive_marker_trait::<NoUninit>(parse_macro_input!(input as DeriveInput));
230 
231   proc_macro::TokenStream::from(expanded)
232 }
233 
234 /// Derive the `CheckedBitPattern` trait for a struct or enum.
235 ///
236 /// The macro ensures that the type follows all the the safety requirements
237 /// for the `CheckedBitPattern` trait and derives the required `Bits` type
238 /// definition and `is_valid_bit_pattern` method for the type automatically.
239 ///
240 /// The following constraints need to be satisfied for the macro to succeed
241 /// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
242 /// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
243 /// trait which `CheckedBitPattern` is a subtrait of):
244 ///
245 /// If applied to a struct:
246 /// - All fields must implement `CheckedBitPattern`
247 ///
248 /// If applied to an enum:
249 /// - All requirements already checked by `NoUninit`, just impls the trait
250 #[proc_macro_derive(CheckedBitPattern)]
derive_maybe_pod( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream251 pub fn derive_maybe_pod(
252   input: proc_macro::TokenStream,
253 ) -> proc_macro::TokenStream {
254   let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
255     input as DeriveInput
256   ));
257 
258   proc_macro::TokenStream::from(expanded)
259 }
260 
261 /// Derive the `TransparentWrapper` trait for a struct
262 ///
263 /// The macro ensures that the struct follows all the the safety requirements
264 /// for the `TransparentWrapper` trait.
265 ///
266 /// The following constraints need to be satisfied for the macro to succeed
267 ///
268 /// - The struct must be `#[repr(transparent)]`
269 /// - The struct must contain the `Wrapped` type
270 /// - Any ZST fields must be [`Zeroable`][derive@Zeroable].
271 ///
272 /// If the struct only contains a single field, the `Wrapped` type will
273 /// automatically be determined. If there is more then one field in the struct,
274 /// you need to specify the `Wrapped` type using `#[transparent(T)]`
275 ///
276 /// ## Examples
277 ///
278 /// ```rust
279 /// # use bytemuck_derive::TransparentWrapper;
280 /// # use std::marker::PhantomData;
281 /// #[derive(Copy, Clone, TransparentWrapper)]
282 /// #[repr(transparent)]
283 /// #[transparent(u16)]
284 /// struct Test<T> {
285 ///   inner: u16,
286 ///   extra: PhantomData<T>,
287 /// }
288 /// ```
289 ///
290 /// If the struct contains more than one field, the `Wrapped` type must be
291 /// explicitly specified.
292 ///
293 /// ```rust,compile_fail
294 /// # use bytemuck_derive::TransparentWrapper;
295 /// # use std::marker::PhantomData;
296 /// #[derive(Copy, Clone, TransparentWrapper)]
297 /// #[repr(transparent)]
298 /// // missing `#[transparent(u16)]`
299 /// struct Test<T> {
300 ///   inner: u16,
301 ///   extra: PhantomData<T>,
302 /// }
303 /// ```
304 ///
305 /// Any ZST fields must be `Zeroable`.
306 ///
307 /// ```rust,compile_fail
308 /// # use bytemuck_derive::TransparentWrapper;
309 /// # use std::marker::PhantomData;
310 /// struct NonTransparentSafeZST;
311 ///
312 /// #[derive(TransparentWrapper)]
313 /// #[repr(transparent)]
314 /// #[transparent(u16)]
315 /// struct Test<T> {
316 ///   inner: u16,
317 ///   extra: PhantomData<T>,
318 ///   another_extra: NonTransparentSafeZST, // not `Zeroable`
319 /// }
320 /// ```
321 #[proc_macro_derive(TransparentWrapper, attributes(transparent))]
derive_transparent( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream322 pub fn derive_transparent(
323   input: proc_macro::TokenStream,
324 ) -> proc_macro::TokenStream {
325   let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
326     input as DeriveInput
327   ));
328 
329   proc_macro::TokenStream::from(expanded)
330 }
331 
332 /// Derive the `Contiguous` trait for an enum
333 ///
334 /// The macro ensures that the enum follows all the the safety requirements
335 /// for the `Contiguous` trait.
336 ///
337 /// The following constraints need to be satisfied for the macro to succeed
338 ///
339 /// - The enum must be `#[repr(Int)]`
340 /// - The enum must be fieldless
341 /// - The enum discriminants must form a contiguous range
342 ///
343 /// ## Example
344 ///
345 /// ```rust
346 /// # use bytemuck_derive::{Contiguous};
347 ///
348 /// #[derive(Copy, Clone, Contiguous)]
349 /// #[repr(u8)]
350 /// enum Test {
351 ///   A = 0,
352 ///   B = 1,
353 ///   C = 2,
354 /// }
355 /// ```
356 #[proc_macro_derive(Contiguous)]
derive_contiguous( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream357 pub fn derive_contiguous(
358   input: proc_macro::TokenStream,
359 ) -> proc_macro::TokenStream {
360   let expanded =
361     derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
362 
363   proc_macro::TokenStream::from(expanded)
364 }
365 
366 /// Derive the `PartialEq` and `Eq` trait for a type
367 ///
368 /// The macro implements `PartialEq` and `Eq` by casting both sides of the
369 /// comparison to a byte slice and then compares those.
370 ///
371 /// ## Warning
372 ///
373 /// Since this implements a byte wise comparison, the behavior of floating point
374 /// numbers does not match their usual comparison behavior. Additionally other
375 /// custom comparison behaviors of the individual fields are also ignored. This
376 /// also does not implement `StructuralPartialEq` / `StructuralEq` like
377 /// `PartialEq` / `Eq` would. This means you can't pattern match on the values.
378 ///
379 /// ## Example
380 ///
381 /// ```rust
382 /// # use bytemuck_derive::{ByteEq, NoUninit};
383 /// #[derive(Copy, Clone, NoUninit, ByteEq)]
384 /// #[repr(C)]
385 /// struct Test {
386 ///   a: u32,
387 ///   b: char,
388 ///   c: f32,
389 /// }
390 /// ```
391 #[proc_macro_derive(ByteEq)]
derive_byte_eq( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream392 pub fn derive_byte_eq(
393   input: proc_macro::TokenStream,
394 ) -> proc_macro::TokenStream {
395   let input = parse_macro_input!(input as DeriveInput);
396   let ident = input.ident;
397 
398   proc_macro::TokenStream::from(quote! {
399     impl ::core::cmp::PartialEq for #ident {
400       #[inline]
401       #[must_use]
402       fn eq(&self, other: &Self) -> bool {
403         ::bytemuck::bytes_of(self) == ::bytemuck::bytes_of(other)
404       }
405     }
406     impl ::core::cmp::Eq for #ident { }
407   })
408 }
409 
410 /// Derive the `Hash` trait for a type
411 ///
412 /// The macro implements `Hash` by casting the value to a byte slice and hashing
413 /// that.
414 ///
415 /// ## Warning
416 ///
417 /// The hash does not match the standard library's `Hash` derive.
418 ///
419 /// ## Example
420 ///
421 /// ```rust
422 /// # use bytemuck_derive::{ByteHash, NoUninit};
423 /// #[derive(Copy, Clone, NoUninit, ByteHash)]
424 /// #[repr(C)]
425 /// struct Test {
426 ///   a: u32,
427 ///   b: char,
428 ///   c: f32,
429 /// }
430 /// ```
431 #[proc_macro_derive(ByteHash)]
derive_byte_hash( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream432 pub fn derive_byte_hash(
433   input: proc_macro::TokenStream,
434 ) -> proc_macro::TokenStream {
435   let input = parse_macro_input!(input as DeriveInput);
436   let ident = input.ident;
437 
438   proc_macro::TokenStream::from(quote! {
439     impl ::core::hash::Hash for #ident {
440       #[inline]
441       fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
442         ::core::hash::Hash::hash_slice(::bytemuck::bytes_of(self), state)
443       }
444 
445       #[inline]
446       fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
447         ::core::hash::Hash::hash_slice(::bytemuck::cast_slice::<_, u8>(data), state)
448       }
449     }
450   })
451 }
452 
453 /// Basic wrapper for error handling
derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream454 fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
455   derive_marker_trait_inner::<Trait>(input)
456     .unwrap_or_else(|err| err.into_compile_error())
457 }
458 
459 /// Find `#[name(key = "value")]` helper attributes on the struct, and return
460 /// their `"value"`s parsed with `parser`.
461 ///
462 /// Returns an error if any attributes with the given `name` do not match the
463 /// expected format. Returns `Ok([])` if no attributes with `name` are found.
find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>( attributes: &[syn::Attribute], name: &str, key: &str, parser: P, example_value: &str, invalid_value_msg: &str, ) -> Result<Vec<P::Output>>464 fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
465   attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
466   example_value: &str, invalid_value_msg: &str,
467 ) -> Result<Vec<P::Output>> {
468   let invalid_format_msg =
469     format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
470   let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
471     // If a `Path` matches our `name`, return an error, else ignore it.
472     // e.g. `#[zeroable]`
473     syn::Meta::Path(path) => path
474       .is_ident(name)
475       .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
476     // If a `NameValue` matches our `name`, return an error, else ignore it.
477     // e.g. `#[zeroable = "hello"]`
478     syn::Meta::NameValue(namevalue) => {
479       namevalue.path.is_ident(name).then(|| {
480         Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
481       })
482     }
483     // If a `List` matches our `name`, match its contents to our format, else
484     // ignore it. If its contents match our format, return the value, else
485     // return an error.
486     syn::Meta::List(list) => list.path.is_ident(name).then(|| {
487       let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
488         .map_err(|_| {
489           syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
490         })?;
491       if namevalue.path.is_ident(key) {
492         match namevalue.value {
493           syn::Expr::Lit(syn::ExprLit {
494             lit: syn::Lit::Str(strlit), ..
495           }) => Ok(strlit),
496           _ => {
497             Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
498           }
499         }
500       } else {
501         Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
502       }
503     }),
504   });
505   // Parse each value found with the given parser, and return them if no errors
506   // occur.
507   values_to_check
508     .map(|lit| {
509       let lit = lit?;
510       lit.parse_with(parser).map_err(|err| {
511         syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
512       })
513     })
514     .collect()
515 }
516 
derive_marker_trait_inner<Trait: Derivable>( mut input: DeriveInput, ) -> Result<TokenStream>517 fn derive_marker_trait_inner<Trait: Derivable>(
518   mut input: DeriveInput,
519 ) -> Result<TokenStream> {
520   let trait_ = Trait::ident(&input)?;
521   // If this trait allows explicit bounds, and any explicit bounds were given,
522   // then use those explicit bounds. Else, apply the default bounds (bound
523   // each generic type on this trait).
524   if let Some(name) = Trait::explicit_bounds_attribute_name() {
525     // See if any explicit bounds were given in attributes.
526     let explicit_bounds = find_and_parse_helper_attributes(
527       &input.attrs,
528       name,
529       "bound",
530       <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
531       "Type: Trait",
532       "invalid where predicate",
533     )?;
534 
535     if !explicit_bounds.is_empty() {
536       // Explicit bounds were given.
537       // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
538       // bounds for each field's type).
539       let explicit_bounds = explicit_bounds
540         .into_iter()
541         .flatten()
542         .collect::<Vec<syn::WherePredicate>>();
543 
544       let predicates = &mut input.generics.make_where_clause().predicates;
545 
546       predicates.extend(explicit_bounds);
547 
548       let fields = match &input.data {
549         syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
550         syn::Data::Union(_) => {
551           return Err(syn::Error::new_spanned(
552             trait_,
553             &"perfect derive is not supported for unions",
554           ));
555         }
556         syn::Data::Enum(_) => {
557           return Err(syn::Error::new_spanned(
558             trait_,
559             &"perfect derive is not supported for enums",
560           ));
561         }
562       };
563 
564       for field in fields {
565         let ty = field.ty;
566         predicates.push(syn::parse_quote!(
567           #ty: #trait_
568         ));
569       }
570     } else {
571       // No explicit bounds were given.
572       // Enforce trait bound on all type generics.
573       add_trait_marker(&mut input.generics, &trait_);
574     }
575   } else {
576     // This trait does not allow explicit bounds.
577     // Enforce trait bound on all type generics.
578     add_trait_marker(&mut input.generics, &trait_);
579   }
580 
581   let name = &input.ident;
582 
583   let (impl_generics, ty_generics, where_clause) =
584     input.generics.split_for_impl();
585 
586   Trait::check_attributes(&input.data, &input.attrs)?;
587   let asserts = Trait::asserts(&input)?;
588   let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?;
589 
590   let implies_trait = if let Some(implies_trait) = Trait::implies_trait() {
591     quote!(unsafe impl #impl_generics #implies_trait for #name #ty_generics #where_clause {})
592   } else {
593     quote!()
594   };
595 
596   let where_clause =
597     if Trait::requires_where_clause() { where_clause } else { None };
598 
599   Ok(quote! {
600     #asserts
601 
602     #trait_impl_extras
603 
604     unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
605       #trait_impl
606     }
607 
608     #implies_trait
609   })
610 }
611 
612 /// Add a trait marker to the generics if it is not already present
add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path)613 fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
614   // Get each generic type parameter.
615   let type_params = generics
616     .type_params()
617     .map(|param| &param.ident)
618     .map(|param| {
619       syn::parse_quote!(
620         #param: #trait_name
621       )
622     })
623     .collect::<Vec<syn::WherePredicate>>();
624 
625   generics.make_where_clause().predicates.extend(type_params);
626 }
627