1 extern crate proc_macro;
2
3 use proc_macro2::{Span, TokenStream};
4 use quote::quote;
5 use syn::*;
6
/// Name of the lifetime parameter introduced by the derive, used as
/// `arbitrary::Arbitrary<'arbitrary>` and `arbitrary::Unstructured<'arbitrary>` in the
/// generated impl.
///
/// A `const` rather than a `static`: the value is only ever read by value, so no single
/// memory address is required.
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
8
9 #[proc_macro_derive(Arbitrary)]
derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream10 pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
11 let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
12 let (lifetime_without_bounds, lifetime_with_bounds) =
13 build_arbitrary_lifetime(input.generics.clone());
14
15 let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone());
16 let size_hint_method = gen_size_hint_method(&input);
17 let name = input.ident;
18 // Add a bound `T: Arbitrary` to every type parameter T.
19 let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());
20
21 // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
22 let mut generics_with_lifetime = generics.clone();
23 generics_with_lifetime
24 .params
25 .push(GenericParam::Lifetime(lifetime_with_bounds));
26 let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
27
28 // Build TypeGenerics and WhereClause without a lifetime
29 let (_, ty_generics, where_clause) = generics.split_for_impl();
30
31 (quote! {
32 impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
33 #arbitrary_method
34 #size_hint_method
35 }
36 })
37 .into()
38 }
39
40 // Returns: (lifetime without bounds, lifetime with bounds)
41 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef)42 fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) {
43 let lifetime_without_bounds =
44 LifetimeDef::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
45 let mut lifetime_with_bounds = lifetime_without_bounds.clone();
46
47 for param in generics.params.iter() {
48 if let GenericParam::Lifetime(lifetime_def) = param {
49 lifetime_with_bounds
50 .bounds
51 .push(lifetime_def.lifetime.clone());
52 }
53 }
54
55 (lifetime_without_bounds, lifetime_with_bounds)
56 }
57
58 // Add a bound `T: Arbitrary` to every type parameter T.
add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics59 fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
60 for param in generics.params.iter_mut() {
61 if let GenericParam::Type(type_param) = param {
62 type_param
63 .bounds
64 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
65 }
66 }
67 generics
68 }
69
gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream70 fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream {
71 let ident = &input.ident;
72 let arbitrary_structlike = |fields| {
73 let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?));
74 let arbitrary_take_rest = construct_take_rest(fields);
75 quote! {
76 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
77 Ok(#ident #arbitrary)
78 }
79
80 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
81 Ok(#ident #arbitrary_take_rest)
82 }
83 }
84 };
85 match &input.data {
86 Data::Struct(data) => arbitrary_structlike(&data.fields),
87 Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())),
88 Data::Enum(data) => {
89 let variants = data.variants.iter().enumerate().map(|(i, variant)| {
90 let idx = i as u64;
91 let ctor = construct(&variant.fields, |_, _| {
92 quote!(arbitrary::Arbitrary::arbitrary(u)?)
93 });
94 let variant_name = &variant.ident;
95 quote! { #idx => #ident::#variant_name #ctor }
96 });
97 let variants_take_rest = data.variants.iter().enumerate().map(|(i, variant)| {
98 let idx = i as u64;
99 let ctor = construct_take_rest(&variant.fields);
100 let variant_name = &variant.ident;
101 quote! { #idx => #ident::#variant_name #ctor }
102 });
103 let count = data.variants.len() as u64;
104 quote! {
105 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
106 // Use a multiply + shift to generate a ranged random number
107 // with slight bias. For details, see:
108 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
109 Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
110 #(#variants,)*
111 _ => unreachable!()
112 })
113 }
114
115 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
116 // Use a multiply + shift to generate a ranged random number
117 // with slight bias. For details, see:
118 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
119 Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
120 #(#variants_take_rest,)*
121 _ => unreachable!()
122 })
123 }
124 }
125 }
126 }
127 }
128
construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream129 fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream {
130 match fields {
131 Fields::Named(names) => {
132 let names = names.named.iter().enumerate().map(|(i, f)| {
133 let name = f.ident.as_ref().unwrap();
134 let ctor = ctor(i, f);
135 quote! { #name: #ctor }
136 });
137 quote! { { #(#names,)* } }
138 }
139 Fields::Unnamed(names) => {
140 let names = names.unnamed.iter().enumerate().map(|(i, f)| {
141 let ctor = ctor(i, f);
142 quote! { #ctor }
143 });
144 quote! { ( #(#names),* ) }
145 }
146 Fields::Unit => quote!(),
147 }
148 }
149
construct_take_rest(fields: &Fields) -> TokenStream150 fn construct_take_rest(fields: &Fields) -> TokenStream {
151 construct(fields, |idx, _| {
152 if idx + 1 == fields.len() {
153 quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
154 } else {
155 quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
156 }
157 })
158 }
159
gen_size_hint_method(input: &DeriveInput) -> TokenStream160 fn gen_size_hint_method(input: &DeriveInput) -> TokenStream {
161 let size_hint_fields = |fields: &Fields| {
162 let tys = fields.iter().map(|f| &f.ty);
163 quote! {
164 arbitrary::size_hint::and_all(&[
165 #( <#tys as arbitrary::Arbitrary>::size_hint(depth) ),*
166 ])
167 }
168 };
169 let size_hint_structlike = |fields: &Fields| {
170 let hint = size_hint_fields(fields);
171 quote! {
172 #[inline]
173 fn size_hint(depth: usize) -> (usize, Option<usize>) {
174 arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
175 }
176 }
177 };
178 match &input.data {
179 Data::Struct(data) => size_hint_structlike(&data.fields),
180 Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
181 Data::Enum(data) => {
182 let variants = data.variants.iter().map(|v| size_hint_fields(&v.fields));
183 quote! {
184 #[inline]
185 fn size_hint(depth: usize) -> (usize, Option<usize>) {
186 arbitrary::size_hint::and(
187 <u32 as arbitrary::Arbitrary>::size_hint(depth),
188 arbitrary::size_hint::recursion_guard(depth, |depth| {
189 arbitrary::size_hint::or_all(&[ #( #variants ),* ])
190 }),
191 )
192 }
193 }
194 }
195 }
196 }
197