• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 extern crate proc_macro;
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::quote;
5 use syn::*;
6 
7 mod container_attributes;
8 mod field_attributes;
9 mod variant_attributes;
10 
11 use container_attributes::ContainerAttributes;
12 use field_attributes::{determine_field_constructor, FieldConstructor};
13 use variant_attributes::not_skipped;
14 
/// Name of the helper attribute (`#[arbitrary(...)]`); used in this file to detect the
/// attribute on enum variants and to name it in error messages.
15 const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Name of the lifetime parameter that the generated impl introduces for
/// `arbitrary::Arbitrary<'arbitrary>`.
16 const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
17 
18 #[proc_macro_derive(Arbitrary, attributes(arbitrary))]
derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream19 pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
20     let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
21     expand_derive_arbitrary(input)
22         .unwrap_or_else(syn::Error::into_compile_error)
23         .into()
24 }
25 
expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream>26 fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
27     let container_attrs = ContainerAttributes::from_derive_input(&input)?;
28 
29     let (lifetime_without_bounds, lifetime_with_bounds) =
30         build_arbitrary_lifetime(input.generics.clone());
31 
32     let recursive_count = syn::Ident::new(
33         &format!("RECURSIVE_COUNT_{}", input.ident),
34         Span::call_site(),
35     );
36 
37     let arbitrary_method =
38         gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
39     let size_hint_method = gen_size_hint_method(&input)?;
40     let name = input.ident;
41 
42     // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
43     let generics = apply_trait_bounds(
44         input.generics,
45         lifetime_without_bounds.clone(),
46         &container_attrs,
47     )?;
48 
49     // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
50     let mut generics_with_lifetime = generics.clone();
51     generics_with_lifetime
52         .params
53         .push(GenericParam::Lifetime(lifetime_with_bounds));
54     let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
55 
56     // Build TypeGenerics and WhereClause without a lifetime
57     let (_, ty_generics, where_clause) = generics.split_for_impl();
58 
59     Ok(quote! {
60         const _: () = {
61             ::std::thread_local! {
62                 #[allow(non_upper_case_globals)]
63                 static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
64             }
65 
66             #[automatically_derived]
67             impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
68                 #arbitrary_method
69                 #size_hint_method
70             }
71         };
72     })
73 }
74 
75 // Returns: (lifetime without bounds, lifetime with bounds)
76 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam)77 fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
78     let lifetime_without_bounds =
79         LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
80     let mut lifetime_with_bounds = lifetime_without_bounds.clone();
81 
82     for param in generics.params.iter() {
83         if let GenericParam::Lifetime(lifetime_def) = param {
84             lifetime_with_bounds
85                 .bounds
86                 .push(lifetime_def.lifetime.clone());
87         }
88     }
89 
90     (lifetime_without_bounds, lifetime_with_bounds)
91 }
92 
apply_trait_bounds( mut generics: Generics, lifetime: LifetimeParam, container_attrs: &ContainerAttributes, ) -> Result<Generics>93 fn apply_trait_bounds(
94     mut generics: Generics,
95     lifetime: LifetimeParam,
96     container_attrs: &ContainerAttributes,
97 ) -> Result<Generics> {
98     // If user-supplied bounds exist, apply them to their matching type parameters.
99     if let Some(config_bounds) = &container_attrs.bounds {
100         let mut config_bounds_applied = 0;
101         for param in generics.params.iter_mut() {
102             if let GenericParam::Type(type_param) = param {
103                 if let Some(replacement) = config_bounds
104                     .iter()
105                     .flatten()
106                     .find(|p| p.ident == type_param.ident)
107                 {
108                     *type_param = replacement.clone();
109                     config_bounds_applied += 1;
110                 } else {
111                     // If no user-supplied bounds exist for this type, delete the original bounds.
112                     // This mimics serde.
113                     type_param.bounds = Default::default();
114                     type_param.default = None;
115                 }
116             }
117         }
118         let config_bounds_supplied = config_bounds
119             .iter()
120             .map(|bounds| bounds.len())
121             .sum::<usize>();
122         if config_bounds_applied != config_bounds_supplied {
123             return Err(Error::new(
124                 Span::call_site(),
125                 format!(
126                     "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
127                     ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
128                 ),
129             ));
130         }
131         Ok(generics)
132     } else {
133         // Otherwise, inject a `T: Arbitrary` bound for every parameter.
134         Ok(add_trait_bounds(generics, lifetime))
135     }
136 }
137 
138 // Add a bound `T: Arbitrary` to every type parameter T.
add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics139 fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
140     for param in generics.params.iter_mut() {
141         if let GenericParam::Type(type_param) = param {
142             type_param
143                 .bounds
144                 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
145         }
146     }
147     generics
148 }
149 
with_recursive_count_guard( recursive_count: &syn::Ident, expr: impl quote::ToTokens, ) -> impl quote::ToTokens150 fn with_recursive_count_guard(
151     recursive_count: &syn::Ident,
152     expr: impl quote::ToTokens,
153 ) -> impl quote::ToTokens {
154     quote! {
155         let guard_against_recursion = u.is_empty();
156         if guard_against_recursion {
157             #recursive_count.with(|count| {
158                 if count.get() > 0 {
159                     return Err(arbitrary::Error::NotEnoughData);
160                 }
161                 count.set(count.get() + 1);
162                 Ok(())
163             })?;
164         }
165 
166         let result = (|| { #expr })();
167 
168         if guard_against_recursion {
169             #recursive_count.with(|count| {
170                 count.set(count.get() - 1);
171             });
172         }
173 
174         result
175     }
176 }
177 
gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, recursive_count: &syn::Ident, ) -> Result<TokenStream>178 fn gen_arbitrary_method(
179     input: &DeriveInput,
180     lifetime: LifetimeParam,
181     recursive_count: &syn::Ident,
182 ) -> Result<TokenStream> {
183     fn arbitrary_structlike(
184         fields: &Fields,
185         ident: &syn::Ident,
186         lifetime: LifetimeParam,
187         recursive_count: &syn::Ident,
188     ) -> Result<TokenStream> {
189         let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
190         let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
191 
192         let arbitrary_take_rest = construct_take_rest(fields)?;
193         let take_rest_body =
194             with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
195 
196         Ok(quote! {
197             fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
198                 #body
199             }
200 
201             fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
202                 #take_rest_body
203             }
204         })
205     }
206 
207     fn arbitrary_variant(
208         index: u64,
209         enum_name: &Ident,
210         variant_name: &Ident,
211         ctor: TokenStream,
212     ) -> TokenStream {
213         quote! { #index => #enum_name::#variant_name #ctor }
214     }
215 
216     fn arbitrary_enum_method(
217         recursive_count: &syn::Ident,
218         unstructured: TokenStream,
219         variants: &[TokenStream],
220     ) -> impl quote::ToTokens {
221         let count = variants.len() as u64;
222         with_recursive_count_guard(
223             recursive_count,
224             quote! {
225                 // Use a multiply + shift to generate a ranged random number
226                 // with slight bias. For details, see:
227                 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
228                 Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
229                     #(#variants,)*
230                     _ => unreachable!()
231                 })
232             },
233         )
234     }
235 
236     fn arbitrary_enum(
237         DataEnum { variants, .. }: &DataEnum,
238         enum_name: &Ident,
239         lifetime: LifetimeParam,
240         recursive_count: &syn::Ident,
241     ) -> Result<TokenStream> {
242         let filtered_variants = variants.iter().filter(not_skipped);
243 
244         // Check attributes of all variants:
245         filtered_variants
246             .clone()
247             .try_for_each(check_variant_attrs)?;
248 
249         // From here on, we can assume that the attributes of all variants were checked.
250         let enumerated_variants = filtered_variants
251             .enumerate()
252             .map(|(index, variant)| (index as u64, variant));
253 
254         // Construct `match`-arms for the `arbitrary` method.
255         let variants = enumerated_variants
256             .clone()
257             .map(|(index, Variant { fields, ident, .. })| {
258                 construct(fields, |_, field| gen_constructor_for_field(field))
259                     .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
260             })
261             .collect::<Result<Vec<TokenStream>>>()?;
262 
263         // Construct `match`-arms for the `arbitrary_take_rest` method.
264         let variants_take_rest = enumerated_variants
265             .map(|(index, Variant { fields, ident, .. })| {
266                 construct_take_rest(fields)
267                     .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
268             })
269             .collect::<Result<Vec<TokenStream>>>()?;
270 
271         // Most of the time, `variants` is not empty (the happy path),
272         //   thus `variants_take_rest` will be used,
273         //   so no need to move this check before constructing `variants_take_rest`.
274         // If `variants` is empty, this will emit a compiler-error.
275         (!variants.is_empty())
276             .then(|| {
277                 // TODO: Improve dealing with `u` vs. `&mut u`.
278                 let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
279                 let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
280 
281                 quote! {
282                     fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
283                         #arbitrary
284                     }
285 
286                     fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
287                         #arbitrary_take_rest
288                     }
289                 }
290             })
291             .ok_or_else(|| Error::new_spanned(
292                 enum_name,
293                 "Enum must have at least one variant, that is not skipped"
294             ))
295     }
296 
297     let ident = &input.ident;
298     match &input.data {
299         Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
300         Data::Union(data) => arbitrary_structlike(
301             &Fields::Named(data.fields.clone()),
302             ident,
303             lifetime,
304             recursive_count,
305         ),
306         Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
307     }
308 }
309 
construct( fields: &Fields, ctor: impl Fn(usize, &Field) -> Result<TokenStream>, ) -> Result<TokenStream>310 fn construct(
311     fields: &Fields,
312     ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
313 ) -> Result<TokenStream> {
314     let output = match fields {
315         Fields::Named(names) => {
316             let names: Vec<TokenStream> = names
317                 .named
318                 .iter()
319                 .enumerate()
320                 .map(|(i, f)| {
321                     let name = f.ident.as_ref().unwrap();
322                     ctor(i, f).map(|ctor| quote! { #name: #ctor })
323                 })
324                 .collect::<Result<_>>()?;
325             quote! { { #(#names,)* } }
326         }
327         Fields::Unnamed(names) => {
328             let names: Vec<TokenStream> = names
329                 .unnamed
330                 .iter()
331                 .enumerate()
332                 .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
333                 .collect::<Result<_>>()?;
334             quote! { ( #(#names),* ) }
335         }
336         Fields::Unit => quote!(),
337     };
338     Ok(output)
339 }
340 
construct_take_rest(fields: &Fields) -> Result<TokenStream>341 fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
342     construct(fields, |idx, field| {
343         determine_field_constructor(field).map(|field_constructor| match field_constructor {
344             FieldConstructor::Default => quote!(::core::default::Default::default()),
345             FieldConstructor::Arbitrary => {
346                 if idx + 1 == fields.len() {
347                     quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
348                 } else {
349                     quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
350                 }
351             }
352             FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
353             FieldConstructor::Value(value) => quote!(#value),
354         })
355     })
356 }
357 
gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream>358 fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
359     let size_hint_fields = |fields: &Fields| {
360         fields
361             .iter()
362             .map(|f| {
363                 let ty = &f.ty;
364                 determine_field_constructor(f).map(|field_constructor| {
365                     match field_constructor {
366                         FieldConstructor::Default | FieldConstructor::Value(_) => {
367                             quote!(Ok((0, Some(0))))
368                         }
369                         FieldConstructor::Arbitrary => {
370                             quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
371                         }
372 
373                         // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
374                         // just an educated guess, although it's gonna be inaccurate for dynamically
375                         // allocated types (Vec, HashMap, etc.).
376                         FieldConstructor::With(_) => {
377                             quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
378                         }
379                     }
380                 })
381             })
382             .collect::<Result<Vec<TokenStream>>>()
383             .map(|hints| {
384                 quote! {
385                     Ok(arbitrary::size_hint::and_all(&[
386                         #( #hints? ),*
387                     ]))
388                 }
389             })
390     };
391     let size_hint_structlike = |fields: &Fields| {
392         size_hint_fields(fields).map(|hint| {
393             quote! {
394                 #[inline]
395                 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
396                     Self::try_size_hint(depth).unwrap_or_default()
397                 }
398 
399                 #[inline]
400                 fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
401                     arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
402                 }
403             }
404         })
405     };
406     match &input.data {
407         Data::Struct(data) => size_hint_structlike(&data.fields),
408         Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
409         Data::Enum(data) => data
410             .variants
411             .iter()
412             .filter(not_skipped)
413             .map(|Variant { fields, .. }| {
414                 // The attributes of all variants are checked in `gen_arbitrary_method` above
415                 //   and can therefore assume that they are valid.
416                 size_hint_fields(fields)
417             })
418             .collect::<Result<Vec<TokenStream>>>()
419             .map(|variants| {
420                 quote! {
421                     fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
422                         Self::try_size_hint(depth).unwrap_or_default()
423                     }
424                     #[inline]
425                     fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
426                         Ok(arbitrary::size_hint::and(
427                             <u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
428                             arbitrary::size_hint::try_recursion_guard(depth, |depth| {
429                                 Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
430                             })?,
431                         ))
432                     }
433                 }
434             }),
435     }
436 }
437 
gen_constructor_for_field(field: &Field) -> Result<TokenStream>438 fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
439     let ctor = match determine_field_constructor(field)? {
440         FieldConstructor::Default => quote!(::core::default::Default::default()),
441         FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
442         FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
443         FieldConstructor::Value(value) => quote!(#value),
444     };
445     Ok(ctor)
446 }
447 
check_variant_attrs(variant: &Variant) -> Result<()>448 fn check_variant_attrs(variant: &Variant) -> Result<()> {
449     for attr in &variant.attrs {
450         if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
451             return Err(Error::new_spanned(
452                 attr,
453                 format!(
454                     "invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
455                     ARBITRARY_ATTRIBUTE_NAME
456                 ),
457             ));
458         }
459     }
460     Ok(())
461 }
462