1 //! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits.
2
3 extern crate proc_macro;
4
5 mod traits;
6
7 use proc_macro2::TokenStream;
8 use quote::quote;
9 use syn::{parse_macro_input, DeriveInput, Result};
10
11 use crate::traits::{
12 AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, NoUninit, Pod,
13 TransparentWrapper, Zeroable,
14 };
15
16 /// Derive the `Pod` trait for a struct
17 ///
/// The macro ensures that the struct follows all the safety requirements
19 /// for the `Pod` trait.
20 ///
21 /// The following constraints need to be satisfied for the macro to succeed
22 ///
23 /// - All fields in the struct must implement `Pod`
24 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
25 /// - The struct must not contain any padding bytes
/// - The struct must not contain any generic parameters, unless it is
///   `#[repr(transparent)]`
28 ///
29 /// ## Examples
30 ///
31 /// ```rust
32 /// # use std::marker::PhantomData;
33 /// # use bytemuck_derive::{Pod, Zeroable};
34 /// #[derive(Copy, Clone, Pod, Zeroable)]
35 /// #[repr(C)]
36 /// struct Test {
37 /// a: u16,
38 /// b: u16,
39 /// }
40 ///
41 /// #[derive(Copy, Clone, Pod, Zeroable)]
42 /// #[repr(transparent)]
43 /// struct Generic<A, B> {
44 /// a: A,
45 /// b: PhantomData<B>,
46 /// }
47 /// ```
48 ///
49 /// If the struct is generic, it must be `#[repr(transparent)]` also.
50 ///
51 /// ```compile_fail
52 /// # use bytemuck::{Pod, Zeroable};
53 /// # use std::marker::PhantomData;
54 /// #[derive(Copy, Clone, Pod, Zeroable)]
55 /// #[repr(C)] // must be `#[repr(transparent)]`
56 /// struct Generic<A> {
57 /// a: A,
58 /// }
59 /// ```
60 ///
61 /// If the struct is generic and `#[repr(transparent)]`, then it is only `Pod`
62 /// when all of its generics are `Pod`, not just its fields.
63 ///
64 /// ```
65 /// # use bytemuck::{Pod, Zeroable};
66 /// # use std::marker::PhantomData;
67 /// #[derive(Copy, Clone, Pod, Zeroable)]
68 /// #[repr(transparent)]
69 /// struct Generic<A, B> {
70 /// a: A,
71 /// b: PhantomData<B>,
72 /// }
73 ///
74 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<u32> });
75 /// ```
76 ///
77 /// ```compile_fail
78 /// # use bytemuck::{Pod, Zeroable};
79 /// # use std::marker::PhantomData;
80 /// # #[derive(Copy, Clone, Pod, Zeroable)]
81 /// # #[repr(transparent)]
82 /// # struct Generic<A, B> {
83 /// # a: A,
84 /// # b: PhantomData<B>,
85 /// # }
86 /// struct NotPod;
87 ///
88 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<NotPod> });
89 /// ```
90 #[proc_macro_derive(Pod)]
derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream91 pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92 let expanded =
93 derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
94
95 proc_macro::TokenStream::from(expanded)
96 }
97
98 /// Derive the `AnyBitPattern` trait for a struct
99 ///
/// The macro ensures that the struct follows all the safety requirements
101 /// for the `AnyBitPattern` trait.
102 ///
103 /// The following constraints need to be satisfied for the macro to succeed
104 ///
/// - All fields in the struct must implement `AnyBitPattern`
106 #[proc_macro_derive(AnyBitPattern)]
derive_anybitpattern( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream107 pub fn derive_anybitpattern(
108 input: proc_macro::TokenStream,
109 ) -> proc_macro::TokenStream {
110 let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
111 input as DeriveInput
112 ));
113
114 proc_macro::TokenStream::from(expanded)
115 }
116
117 /// Derive the `Zeroable` trait for a struct
118 ///
/// The macro ensures that the struct follows all the safety requirements
120 /// for the `Zeroable` trait.
121 ///
122 /// The following constraints need to be satisfied for the macro to succeed
123 ///
/// - All fields in the struct must implement `Zeroable`
125 ///
126 /// ## Example
127 ///
128 /// ```rust
129 /// # use bytemuck_derive::{Zeroable};
130 /// #[derive(Copy, Clone, Zeroable)]
131 /// #[repr(C)]
132 /// struct Test {
133 /// a: u16,
134 /// b: u16,
135 /// }
136 /// ```
137 ///
138 /// # Custom bounds
139 ///
140 /// Custom bounds for the derived `Zeroable` impl can be given using the
141 /// `#[zeroable(bound = "")]` helper attribute.
142 ///
143 /// Using this attribute additionally opts-in to "perfect derive" semantics,
144 /// where instead of adding bounds for each generic type parameter, bounds are
145 /// added for each field's type.
146 ///
147 /// ## Examples
148 ///
149 /// ```rust
150 /// # use bytemuck::Zeroable;
151 /// # use std::marker::PhantomData;
152 /// #[derive(Clone, Zeroable)]
153 /// #[zeroable(bound = "")]
154 /// struct AlwaysZeroable<T> {
155 /// a: PhantomData<T>,
156 /// }
157 ///
158 /// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
159 /// ```
160 ///
161 /// ```rust,compile_fail
162 /// # use bytemuck::Zeroable;
163 /// # use std::marker::PhantomData;
164 /// #[derive(Clone, Zeroable)]
165 /// #[zeroable(bound = "T: Copy")]
166 /// struct ZeroableWhenTIsCopy<T> {
167 /// a: PhantomData<T>,
168 /// }
169 ///
170 /// ZeroableWhenTIsCopy::<String>::zeroed();
171 /// ```
172 ///
173 /// The restriction that all fields must be Zeroable is still applied, and this
174 /// is enforced using the mentioned "perfect derive" semantics.
175 ///
176 /// ```rust
177 /// # use bytemuck::Zeroable;
178 /// #[derive(Clone, Zeroable)]
179 /// #[zeroable(bound = "")]
180 /// struct ZeroableWhenTIsZeroable<T> {
181 /// a: T,
182 /// }
183 /// ZeroableWhenTIsZeroable::<u32>::zeroed();
184 /// ```
185 ///
186 /// ```rust,compile_fail
187 /// # use bytemuck::Zeroable;
188 /// # #[derive(Clone, Zeroable)]
189 /// # #[zeroable(bound = "")]
190 /// # struct ZeroableWhenTIsZeroable<T> {
191 /// # a: T,
192 /// # }
193 /// ZeroableWhenTIsZeroable::<String>::zeroed();
194 /// ```
195 #[proc_macro_derive(Zeroable, attributes(zeroable))]
derive_zeroable( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream196 pub fn derive_zeroable(
197 input: proc_macro::TokenStream,
198 ) -> proc_macro::TokenStream {
199 let expanded =
200 derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
201
202 proc_macro::TokenStream::from(expanded)
203 }
204
205 /// Derive the `NoUninit` trait for a struct or enum
206 ///
/// The macro ensures that the type follows all the safety requirements
208 /// for the `NoUninit` trait.
209 ///
210 /// The following constraints need to be satisfied for the macro to succeed
211 /// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
212 /// bounds, i.e. the type must be `Sized + Copy + 'static`):
213 ///
214 /// If applied to a struct:
215 /// - All fields in the struct must implement `NoUninit`
216 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
217 /// - The struct must not contain any padding bytes
218 /// - The struct must contain no generic parameters
219 ///
220 /// If applied to an enum:
/// - The enum must have an explicit `#[repr(Int)]`
222 /// - All variants must be fieldless
223 /// - The enum must contain no generic parameters
224 #[proc_macro_derive(NoUninit)]
derive_no_uninit( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream225 pub fn derive_no_uninit(
226 input: proc_macro::TokenStream,
227 ) -> proc_macro::TokenStream {
228 let expanded =
229 derive_marker_trait::<NoUninit>(parse_macro_input!(input as DeriveInput));
230
231 proc_macro::TokenStream::from(expanded)
232 }
233
234 /// Derive the `CheckedBitPattern` trait for a struct or enum.
235 ///
/// The macro ensures that the type follows all the safety requirements
237 /// for the `CheckedBitPattern` trait and derives the required `Bits` type
238 /// definition and `is_valid_bit_pattern` method for the type automatically.
239 ///
240 /// The following constraints need to be satisfied for the macro to succeed
241 /// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
242 /// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
243 /// trait which `CheckedBitPattern` is a subtrait of):
244 ///
245 /// If applied to a struct:
246 /// - All fields must implement `CheckedBitPattern`
247 ///
248 /// If applied to an enum:
/// - All requirements are already checked by `NoUninit`, so the trait is simply implemented
250 #[proc_macro_derive(CheckedBitPattern)]
derive_maybe_pod( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream251 pub fn derive_maybe_pod(
252 input: proc_macro::TokenStream,
253 ) -> proc_macro::TokenStream {
254 let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
255 input as DeriveInput
256 ));
257
258 proc_macro::TokenStream::from(expanded)
259 }
260
261 /// Derive the `TransparentWrapper` trait for a struct
262 ///
/// The macro ensures that the struct follows all the safety requirements
264 /// for the `TransparentWrapper` trait.
265 ///
266 /// The following constraints need to be satisfied for the macro to succeed
267 ///
268 /// - The struct must be `#[repr(transparent)]`
269 /// - The struct must contain the `Wrapped` type
270 /// - Any ZST fields must be [`Zeroable`][derive@Zeroable].
271 ///
272 /// If the struct only contains a single field, the `Wrapped` type will
/// automatically be determined. If there is more than one field in the struct,
274 /// you need to specify the `Wrapped` type using `#[transparent(T)]`
275 ///
276 /// ## Examples
277 ///
278 /// ```rust
279 /// # use bytemuck_derive::TransparentWrapper;
280 /// # use std::marker::PhantomData;
281 /// #[derive(Copy, Clone, TransparentWrapper)]
282 /// #[repr(transparent)]
283 /// #[transparent(u16)]
284 /// struct Test<T> {
285 /// inner: u16,
286 /// extra: PhantomData<T>,
287 /// }
288 /// ```
289 ///
290 /// If the struct contains more than one field, the `Wrapped` type must be
291 /// explicitly specified.
292 ///
293 /// ```rust,compile_fail
294 /// # use bytemuck_derive::TransparentWrapper;
295 /// # use std::marker::PhantomData;
296 /// #[derive(Copy, Clone, TransparentWrapper)]
297 /// #[repr(transparent)]
298 /// // missing `#[transparent(u16)]`
299 /// struct Test<T> {
300 /// inner: u16,
301 /// extra: PhantomData<T>,
302 /// }
303 /// ```
304 ///
305 /// Any ZST fields must be `Zeroable`.
306 ///
307 /// ```rust,compile_fail
308 /// # use bytemuck_derive::TransparentWrapper;
309 /// # use std::marker::PhantomData;
310 /// struct NonTransparentSafeZST;
311 ///
312 /// #[derive(TransparentWrapper)]
313 /// #[repr(transparent)]
314 /// #[transparent(u16)]
315 /// struct Test<T> {
316 /// inner: u16,
317 /// extra: PhantomData<T>,
318 /// another_extra: NonTransparentSafeZST, // not `Zeroable`
319 /// }
320 /// ```
321 #[proc_macro_derive(TransparentWrapper, attributes(transparent))]
derive_transparent( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream322 pub fn derive_transparent(
323 input: proc_macro::TokenStream,
324 ) -> proc_macro::TokenStream {
325 let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
326 input as DeriveInput
327 ));
328
329 proc_macro::TokenStream::from(expanded)
330 }
331
332 /// Derive the `Contiguous` trait for an enum
333 ///
/// The macro ensures that the enum follows all the safety requirements
335 /// for the `Contiguous` trait.
336 ///
337 /// The following constraints need to be satisfied for the macro to succeed
338 ///
339 /// - The enum must be `#[repr(Int)]`
340 /// - The enum must be fieldless
341 /// - The enum discriminants must form a contiguous range
342 ///
343 /// ## Example
344 ///
345 /// ```rust
346 /// # use bytemuck_derive::{Contiguous};
347 ///
348 /// #[derive(Copy, Clone, Contiguous)]
349 /// #[repr(u8)]
350 /// enum Test {
351 /// A = 0,
352 /// B = 1,
353 /// C = 2,
354 /// }
355 /// ```
356 #[proc_macro_derive(Contiguous)]
derive_contiguous( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream357 pub fn derive_contiguous(
358 input: proc_macro::TokenStream,
359 ) -> proc_macro::TokenStream {
360 let expanded =
361 derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
362
363 proc_macro::TokenStream::from(expanded)
364 }
365
/// Derive the `PartialEq` and `Eq` traits for a type
367 ///
/// The macro implements `PartialEq` and `Eq` by casting both sides of the
/// comparison to a byte slice and then comparing those.
370 ///
371 /// ## Warning
372 ///
/// Since this implements a byte-wise comparison, the behavior of floating point
374 /// numbers does not match their usual comparison behavior. Additionally other
375 /// custom comparison behaviors of the individual fields are also ignored. This
376 /// also does not implement `StructuralPartialEq` / `StructuralEq` like
377 /// `PartialEq` / `Eq` would. This means you can't pattern match on the values.
378 ///
379 /// ## Example
380 ///
381 /// ```rust
382 /// # use bytemuck_derive::{ByteEq, NoUninit};
383 /// #[derive(Copy, Clone, NoUninit, ByteEq)]
384 /// #[repr(C)]
385 /// struct Test {
386 /// a: u32,
387 /// b: char,
388 /// c: f32,
389 /// }
390 /// ```
391 #[proc_macro_derive(ByteEq)]
derive_byte_eq( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream392 pub fn derive_byte_eq(
393 input: proc_macro::TokenStream,
394 ) -> proc_macro::TokenStream {
395 let input = parse_macro_input!(input as DeriveInput);
396 let ident = input.ident;
397
398 proc_macro::TokenStream::from(quote! {
399 impl ::core::cmp::PartialEq for #ident {
400 #[inline]
401 #[must_use]
402 fn eq(&self, other: &Self) -> bool {
403 ::bytemuck::bytes_of(self) == ::bytemuck::bytes_of(other)
404 }
405 }
406 impl ::core::cmp::Eq for #ident { }
407 })
408 }
409
410 /// Derive the `Hash` trait for a type
411 ///
412 /// The macro implements `Hash` by casting the value to a byte slice and hashing
413 /// that.
414 ///
415 /// ## Warning
416 ///
417 /// The hash does not match the standard library's `Hash` derive.
418 ///
419 /// ## Example
420 ///
421 /// ```rust
422 /// # use bytemuck_derive::{ByteHash, NoUninit};
423 /// #[derive(Copy, Clone, NoUninit, ByteHash)]
424 /// #[repr(C)]
425 /// struct Test {
426 /// a: u32,
427 /// b: char,
428 /// c: f32,
429 /// }
430 /// ```
431 #[proc_macro_derive(ByteHash)]
derive_byte_hash( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream432 pub fn derive_byte_hash(
433 input: proc_macro::TokenStream,
434 ) -> proc_macro::TokenStream {
435 let input = parse_macro_input!(input as DeriveInput);
436 let ident = input.ident;
437
438 proc_macro::TokenStream::from(quote! {
439 impl ::core::hash::Hash for #ident {
440 #[inline]
441 fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
442 ::core::hash::Hash::hash_slice(::bytemuck::bytes_of(self), state)
443 }
444
445 #[inline]
446 fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
447 ::core::hash::Hash::hash_slice(::bytemuck::cast_slice::<_, u8>(data), state)
448 }
449 }
450 })
451 }
452
453 /// Basic wrapper for error handling
derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream454 fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
455 derive_marker_trait_inner::<Trait>(input)
456 .unwrap_or_else(|err| err.into_compile_error())
457 }
458
459 /// Find `#[name(key = "value")]` helper attributes on the struct, and return
460 /// their `"value"`s parsed with `parser`.
461 ///
462 /// Returns an error if any attributes with the given `name` do not match the
463 /// expected format. Returns `Ok([])` if no attributes with `name` are found.
find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>( attributes: &[syn::Attribute], name: &str, key: &str, parser: P, example_value: &str, invalid_value_msg: &str, ) -> Result<Vec<P::Output>>464 fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
465 attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
466 example_value: &str, invalid_value_msg: &str,
467 ) -> Result<Vec<P::Output>> {
468 let invalid_format_msg =
469 format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
470 let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
471 // If a `Path` matches our `name`, return an error, else ignore it.
472 // e.g. `#[zeroable]`
473 syn::Meta::Path(path) => path
474 .is_ident(name)
475 .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
476 // If a `NameValue` matches our `name`, return an error, else ignore it.
477 // e.g. `#[zeroable = "hello"]`
478 syn::Meta::NameValue(namevalue) => {
479 namevalue.path.is_ident(name).then(|| {
480 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
481 })
482 }
483 // If a `List` matches our `name`, match its contents to our format, else
484 // ignore it. If its contents match our format, return the value, else
485 // return an error.
486 syn::Meta::List(list) => list.path.is_ident(name).then(|| {
487 let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
488 .map_err(|_| {
489 syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
490 })?;
491 if namevalue.path.is_ident(key) {
492 match namevalue.value {
493 syn::Expr::Lit(syn::ExprLit {
494 lit: syn::Lit::Str(strlit), ..
495 }) => Ok(strlit),
496 _ => {
497 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
498 }
499 }
500 } else {
501 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
502 }
503 }),
504 });
505 // Parse each value found with the given parser, and return them if no errors
506 // occur.
507 values_to_check
508 .map(|lit| {
509 let lit = lit?;
510 lit.parse_with(parser).map_err(|err| {
511 syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
512 })
513 })
514 .collect()
515 }
516
derive_marker_trait_inner<Trait: Derivable>( mut input: DeriveInput, ) -> Result<TokenStream>517 fn derive_marker_trait_inner<Trait: Derivable>(
518 mut input: DeriveInput,
519 ) -> Result<TokenStream> {
520 let trait_ = Trait::ident(&input)?;
521 // If this trait allows explicit bounds, and any explicit bounds were given,
522 // then use those explicit bounds. Else, apply the default bounds (bound
523 // each generic type on this trait).
524 if let Some(name) = Trait::explicit_bounds_attribute_name() {
525 // See if any explicit bounds were given in attributes.
526 let explicit_bounds = find_and_parse_helper_attributes(
527 &input.attrs,
528 name,
529 "bound",
530 <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
531 "Type: Trait",
532 "invalid where predicate",
533 )?;
534
535 if !explicit_bounds.is_empty() {
536 // Explicit bounds were given.
537 // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
538 // bounds for each field's type).
539 let explicit_bounds = explicit_bounds
540 .into_iter()
541 .flatten()
542 .collect::<Vec<syn::WherePredicate>>();
543
544 let predicates = &mut input.generics.make_where_clause().predicates;
545
546 predicates.extend(explicit_bounds);
547
548 let fields = match &input.data {
549 syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
550 syn::Data::Union(_) => {
551 return Err(syn::Error::new_spanned(
552 trait_,
553 &"perfect derive is not supported for unions",
554 ));
555 }
556 syn::Data::Enum(_) => {
557 return Err(syn::Error::new_spanned(
558 trait_,
559 &"perfect derive is not supported for enums",
560 ));
561 }
562 };
563
564 for field in fields {
565 let ty = field.ty;
566 predicates.push(syn::parse_quote!(
567 #ty: #trait_
568 ));
569 }
570 } else {
571 // No explicit bounds were given.
572 // Enforce trait bound on all type generics.
573 add_trait_marker(&mut input.generics, &trait_);
574 }
575 } else {
576 // This trait does not allow explicit bounds.
577 // Enforce trait bound on all type generics.
578 add_trait_marker(&mut input.generics, &trait_);
579 }
580
581 let name = &input.ident;
582
583 let (impl_generics, ty_generics, where_clause) =
584 input.generics.split_for_impl();
585
586 Trait::check_attributes(&input.data, &input.attrs)?;
587 let asserts = Trait::asserts(&input)?;
588 let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?;
589
590 let implies_trait = if let Some(implies_trait) = Trait::implies_trait() {
591 quote!(unsafe impl #impl_generics #implies_trait for #name #ty_generics #where_clause {})
592 } else {
593 quote!()
594 };
595
596 let where_clause =
597 if Trait::requires_where_clause() { where_clause } else { None };
598
599 Ok(quote! {
600 #asserts
601
602 #trait_impl_extras
603
604 unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
605 #trait_impl
606 }
607
608 #implies_trait
609 })
610 }
611
612 /// Add a trait marker to the generics if it is not already present
add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path)613 fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
614 // Get each generic type parameter.
615 let type_params = generics
616 .type_params()
617 .map(|param| ¶m.ident)
618 .map(|param| {
619 syn::parse_quote!(
620 #param: #trait_name
621 )
622 })
623 .collect::<Vec<syn::WherePredicate>>();
624
625 generics.make_where_clause().predicates.extend(type_params);
626 }
627