1 extern crate proc_macro;
2
3 use proc_macro2::{Span, TokenStream};
4 use quote::quote;
5 use syn::*;
6
7 mod container_attributes;
8 mod field_attributes;
9 mod variant_attributes;
10
11 use container_attributes::ContainerAttributes;
12 use field_attributes::{determine_field_constructor, FieldConstructor};
13 use variant_attributes::not_skipped;
14
/// Path of the helper attribute understood by this derive, i.e. `#[arbitrary(...)]`.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Lifetime parameter added to the generated impl (`impl<'arbitrary, ...> Arbitrary<'arbitrary>`).
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
17
18 #[proc_macro_derive(Arbitrary, attributes(arbitrary))]
derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream19 pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
20 let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
21 expand_derive_arbitrary(input)
22 .unwrap_or_else(syn::Error::into_compile_error)
23 .into()
24 }
25
expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream>26 fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
27 let container_attrs = ContainerAttributes::from_derive_input(&input)?;
28
29 let (lifetime_without_bounds, lifetime_with_bounds) =
30 build_arbitrary_lifetime(input.generics.clone());
31
32 let recursive_count = syn::Ident::new(
33 &format!("RECURSIVE_COUNT_{}", input.ident),
34 Span::call_site(),
35 );
36
37 let arbitrary_method =
38 gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
39 let size_hint_method = gen_size_hint_method(&input)?;
40 let name = input.ident;
41
42 // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
43 let generics = apply_trait_bounds(
44 input.generics,
45 lifetime_without_bounds.clone(),
46 &container_attrs,
47 )?;
48
49 // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
50 let mut generics_with_lifetime = generics.clone();
51 generics_with_lifetime
52 .params
53 .push(GenericParam::Lifetime(lifetime_with_bounds));
54 let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
55
56 // Build TypeGenerics and WhereClause without a lifetime
57 let (_, ty_generics, where_clause) = generics.split_for_impl();
58
59 Ok(quote! {
60 const _: () = {
61 ::std::thread_local! {
62 #[allow(non_upper_case_globals)]
63 static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
64 }
65
66 #[automatically_derived]
67 impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
68 #arbitrary_method
69 #size_hint_method
70 }
71 };
72 })
73 }
74
75 // Returns: (lifetime without bounds, lifetime with bounds)
76 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam)77 fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
78 let lifetime_without_bounds =
79 LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
80 let mut lifetime_with_bounds = lifetime_without_bounds.clone();
81
82 for param in generics.params.iter() {
83 if let GenericParam::Lifetime(lifetime_def) = param {
84 lifetime_with_bounds
85 .bounds
86 .push(lifetime_def.lifetime.clone());
87 }
88 }
89
90 (lifetime_without_bounds, lifetime_with_bounds)
91 }
92
apply_trait_bounds( mut generics: Generics, lifetime: LifetimeParam, container_attrs: &ContainerAttributes, ) -> Result<Generics>93 fn apply_trait_bounds(
94 mut generics: Generics,
95 lifetime: LifetimeParam,
96 container_attrs: &ContainerAttributes,
97 ) -> Result<Generics> {
98 // If user-supplied bounds exist, apply them to their matching type parameters.
99 if let Some(config_bounds) = &container_attrs.bounds {
100 let mut config_bounds_applied = 0;
101 for param in generics.params.iter_mut() {
102 if let GenericParam::Type(type_param) = param {
103 if let Some(replacement) = config_bounds
104 .iter()
105 .flatten()
106 .find(|p| p.ident == type_param.ident)
107 {
108 *type_param = replacement.clone();
109 config_bounds_applied += 1;
110 } else {
111 // If no user-supplied bounds exist for this type, delete the original bounds.
112 // This mimics serde.
113 type_param.bounds = Default::default();
114 type_param.default = None;
115 }
116 }
117 }
118 let config_bounds_supplied = config_bounds
119 .iter()
120 .map(|bounds| bounds.len())
121 .sum::<usize>();
122 if config_bounds_applied != config_bounds_supplied {
123 return Err(Error::new(
124 Span::call_site(),
125 format!(
126 "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
127 ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
128 ),
129 ));
130 }
131 Ok(generics)
132 } else {
133 // Otherwise, inject a `T: Arbitrary` bound for every parameter.
134 Ok(add_trait_bounds(generics, lifetime))
135 }
136 }
137
138 // Add a bound `T: Arbitrary` to every type parameter T.
add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics139 fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
140 for param in generics.params.iter_mut() {
141 if let GenericParam::Type(type_param) = param {
142 type_param
143 .bounds
144 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
145 }
146 }
147 generics
148 }
149
with_recursive_count_guard( recursive_count: &syn::Ident, expr: impl quote::ToTokens, ) -> impl quote::ToTokens150 fn with_recursive_count_guard(
151 recursive_count: &syn::Ident,
152 expr: impl quote::ToTokens,
153 ) -> impl quote::ToTokens {
154 quote! {
155 let guard_against_recursion = u.is_empty();
156 if guard_against_recursion {
157 #recursive_count.with(|count| {
158 if count.get() > 0 {
159 return Err(arbitrary::Error::NotEnoughData);
160 }
161 count.set(count.get() + 1);
162 Ok(())
163 })?;
164 }
165
166 let result = (|| { #expr })();
167
168 if guard_against_recursion {
169 #recursive_count.with(|count| {
170 count.set(count.get() - 1);
171 });
172 }
173
174 result
175 }
176 }
177
gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, recursive_count: &syn::Ident, ) -> Result<TokenStream>178 fn gen_arbitrary_method(
179 input: &DeriveInput,
180 lifetime: LifetimeParam,
181 recursive_count: &syn::Ident,
182 ) -> Result<TokenStream> {
183 fn arbitrary_structlike(
184 fields: &Fields,
185 ident: &syn::Ident,
186 lifetime: LifetimeParam,
187 recursive_count: &syn::Ident,
188 ) -> Result<TokenStream> {
189 let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
190 let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
191
192 let arbitrary_take_rest = construct_take_rest(fields)?;
193 let take_rest_body =
194 with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
195
196 Ok(quote! {
197 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
198 #body
199 }
200
201 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
202 #take_rest_body
203 }
204 })
205 }
206
207 fn arbitrary_variant(
208 index: u64,
209 enum_name: &Ident,
210 variant_name: &Ident,
211 ctor: TokenStream,
212 ) -> TokenStream {
213 quote! { #index => #enum_name::#variant_name #ctor }
214 }
215
216 fn arbitrary_enum_method(
217 recursive_count: &syn::Ident,
218 unstructured: TokenStream,
219 variants: &[TokenStream],
220 ) -> impl quote::ToTokens {
221 let count = variants.len() as u64;
222 with_recursive_count_guard(
223 recursive_count,
224 quote! {
225 // Use a multiply + shift to generate a ranged random number
226 // with slight bias. For details, see:
227 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
228 Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
229 #(#variants,)*
230 _ => unreachable!()
231 })
232 },
233 )
234 }
235
236 fn arbitrary_enum(
237 DataEnum { variants, .. }: &DataEnum,
238 enum_name: &Ident,
239 lifetime: LifetimeParam,
240 recursive_count: &syn::Ident,
241 ) -> Result<TokenStream> {
242 let filtered_variants = variants.iter().filter(not_skipped);
243
244 // Check attributes of all variants:
245 filtered_variants
246 .clone()
247 .try_for_each(check_variant_attrs)?;
248
249 // From here on, we can assume that the attributes of all variants were checked.
250 let enumerated_variants = filtered_variants
251 .enumerate()
252 .map(|(index, variant)| (index as u64, variant));
253
254 // Construct `match`-arms for the `arbitrary` method.
255 let variants = enumerated_variants
256 .clone()
257 .map(|(index, Variant { fields, ident, .. })| {
258 construct(fields, |_, field| gen_constructor_for_field(field))
259 .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
260 })
261 .collect::<Result<Vec<TokenStream>>>()?;
262
263 // Construct `match`-arms for the `arbitrary_take_rest` method.
264 let variants_take_rest = enumerated_variants
265 .map(|(index, Variant { fields, ident, .. })| {
266 construct_take_rest(fields)
267 .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
268 })
269 .collect::<Result<Vec<TokenStream>>>()?;
270
271 // Most of the time, `variants` is not empty (the happy path),
272 // thus `variants_take_rest` will be used,
273 // so no need to move this check before constructing `variants_take_rest`.
274 // If `variants` is empty, this will emit a compiler-error.
275 (!variants.is_empty())
276 .then(|| {
277 // TODO: Improve dealing with `u` vs. `&mut u`.
278 let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
279 let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
280
281 quote! {
282 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
283 #arbitrary
284 }
285
286 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
287 #arbitrary_take_rest
288 }
289 }
290 })
291 .ok_or_else(|| Error::new_spanned(
292 enum_name,
293 "Enum must have at least one variant, that is not skipped"
294 ))
295 }
296
297 let ident = &input.ident;
298 match &input.data {
299 Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
300 Data::Union(data) => arbitrary_structlike(
301 &Fields::Named(data.fields.clone()),
302 ident,
303 lifetime,
304 recursive_count,
305 ),
306 Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
307 }
308 }
309
construct( fields: &Fields, ctor: impl Fn(usize, &Field) -> Result<TokenStream>, ) -> Result<TokenStream>310 fn construct(
311 fields: &Fields,
312 ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
313 ) -> Result<TokenStream> {
314 let output = match fields {
315 Fields::Named(names) => {
316 let names: Vec<TokenStream> = names
317 .named
318 .iter()
319 .enumerate()
320 .map(|(i, f)| {
321 let name = f.ident.as_ref().unwrap();
322 ctor(i, f).map(|ctor| quote! { #name: #ctor })
323 })
324 .collect::<Result<_>>()?;
325 quote! { { #(#names,)* } }
326 }
327 Fields::Unnamed(names) => {
328 let names: Vec<TokenStream> = names
329 .unnamed
330 .iter()
331 .enumerate()
332 .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
333 .collect::<Result<_>>()?;
334 quote! { ( #(#names),* ) }
335 }
336 Fields::Unit => quote!(),
337 };
338 Ok(output)
339 }
340
construct_take_rest(fields: &Fields) -> Result<TokenStream>341 fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
342 construct(fields, |idx, field| {
343 determine_field_constructor(field).map(|field_constructor| match field_constructor {
344 FieldConstructor::Default => quote!(::core::default::Default::default()),
345 FieldConstructor::Arbitrary => {
346 if idx + 1 == fields.len() {
347 quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
348 } else {
349 quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
350 }
351 }
352 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
353 FieldConstructor::Value(value) => quote!(#value),
354 })
355 })
356 }
357
gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream>358 fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
359 let size_hint_fields = |fields: &Fields| {
360 fields
361 .iter()
362 .map(|f| {
363 let ty = &f.ty;
364 determine_field_constructor(f).map(|field_constructor| {
365 match field_constructor {
366 FieldConstructor::Default | FieldConstructor::Value(_) => {
367 quote!(Ok((0, Some(0))))
368 }
369 FieldConstructor::Arbitrary => {
370 quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
371 }
372
373 // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
374 // just an educated guess, although it's gonna be inaccurate for dynamically
375 // allocated types (Vec, HashMap, etc.).
376 FieldConstructor::With(_) => {
377 quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
378 }
379 }
380 })
381 })
382 .collect::<Result<Vec<TokenStream>>>()
383 .map(|hints| {
384 quote! {
385 Ok(arbitrary::size_hint::and_all(&[
386 #( #hints? ),*
387 ]))
388 }
389 })
390 };
391 let size_hint_structlike = |fields: &Fields| {
392 size_hint_fields(fields).map(|hint| {
393 quote! {
394 #[inline]
395 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
396 Self::try_size_hint(depth).unwrap_or_default()
397 }
398
399 #[inline]
400 fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
401 arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
402 }
403 }
404 })
405 };
406 match &input.data {
407 Data::Struct(data) => size_hint_structlike(&data.fields),
408 Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
409 Data::Enum(data) => data
410 .variants
411 .iter()
412 .filter(not_skipped)
413 .map(|Variant { fields, .. }| {
414 // The attributes of all variants are checked in `gen_arbitrary_method` above
415 // and can therefore assume that they are valid.
416 size_hint_fields(fields)
417 })
418 .collect::<Result<Vec<TokenStream>>>()
419 .map(|variants| {
420 quote! {
421 fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
422 Self::try_size_hint(depth).unwrap_or_default()
423 }
424 #[inline]
425 fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
426 Ok(arbitrary::size_hint::and(
427 <u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
428 arbitrary::size_hint::try_recursion_guard(depth, |depth| {
429 Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
430 })?,
431 ))
432 }
433 }
434 }),
435 }
436 }
437
gen_constructor_for_field(field: &Field) -> Result<TokenStream>438 fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
439 let ctor = match determine_field_constructor(field)? {
440 FieldConstructor::Default => quote!(::core::default::Default::default()),
441 FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
442 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
443 FieldConstructor::Value(value) => quote!(#value),
444 };
445 Ok(ctor)
446 }
447
check_variant_attrs(variant: &Variant) -> Result<()>448 fn check_variant_attrs(variant: &Variant) -> Result<()> {
449 for attr in &variant.attrs {
450 if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
451 return Err(Error::new_spanned(
452 attr,
453 format!(
454 "invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
455 ARBITRARY_ATTRIBUTE_NAME
456 ),
457 ));
458 }
459 }
460 Ok(())
461 }
462