• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::ARBITRARY_ATTRIBUTE_NAME;
2 use proc_macro2::{Group, Span, TokenStream, TokenTree};
3 use quote::quote;
4 use syn::{spanned::Spanned, *};
5 
6 /// Determines how a value for a field should be constructed.
7 #[cfg_attr(test, derive(Debug))]
8 pub enum FieldConstructor {
9     /// Assume that Arbitrary is defined for the type of this field and use it (default)
10     Arbitrary,
11 
12     /// Places `Default::default()` as a field value.
13     Default,
14 
15     /// Use custom function or closure to generate a value for a field.
16     With(TokenStream),
17 
18     /// Set a field always to the given value.
19     Value(TokenStream),
20 }
21 
determine_field_constructor(field: &Field) -> Result<FieldConstructor>22 pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> {
23     let opt_attr = fetch_attr_from_field(field)?;
24     let ctor = match opt_attr {
25         Some(attr) => parse_attribute(attr)?,
26         None => FieldConstructor::Arbitrary,
27     };
28     Ok(ctor)
29 }
30 
fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>>31 fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> {
32     let found_attributes: Vec<_> = field
33         .attrs
34         .iter()
35         .filter(|a| {
36             let path = &a.path;
37             let name = quote!(#path).to_string();
38             name == ARBITRARY_ATTRIBUTE_NAME
39         })
40         .collect();
41     if found_attributes.len() > 1 {
42         let name = field.ident.as_ref().unwrap();
43         let msg = format!(
44             "Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`"
45         );
46         return Err(syn::Error::new(field.span(), msg));
47     }
48     Ok(found_attributes.into_iter().next())
49 }
50 
parse_attribute(attr: &Attribute) -> Result<FieldConstructor>51 fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> {
52     let group = {
53         let mut tokens_iter = attr.clone().tokens.into_iter();
54         let token = tokens_iter.next().ok_or_else(|| {
55             let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
56             syn::Error::new(attr.span(), msg)
57         })?;
58         match token {
59             TokenTree::Group(g) => g,
60             t => {
61                 let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] must contain a group, got: {t})");
62                 return Err(syn::Error::new(attr.span(), msg));
63             }
64         }
65     };
66     parse_attribute_internals(group)
67 }
68 
parse_attribute_internals(group: Group) -> Result<FieldConstructor>69 fn parse_attribute_internals(group: Group) -> Result<FieldConstructor> {
70     let stream = group.stream();
71     let mut tokens_iter = stream.into_iter();
72     let token = tokens_iter.next().ok_or_else(|| {
73         let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
74         syn::Error::new(group.span(), msg)
75     })?;
76     match token.to_string().as_ref() {
77         "default" => Ok(FieldConstructor::Default),
78         "with" => {
79             let func_path = parse_assigned_value("with", tokens_iter, group.span())?;
80             Ok(FieldConstructor::With(func_path))
81         }
82         "value" => {
83             let value = parse_assigned_value("value", tokens_iter, group.span())?;
84             Ok(FieldConstructor::Value(value))
85         }
86         _ => {
87             let msg = format!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`");
88             Err(syn::Error::new(token.span(), msg))
89         }
90     }
91 }
92 
93 // Input:
94 //     = 2 + 2
95 // Output:
96 //     2 + 2
parse_assigned_value( opt_name: &str, mut tokens_iter: impl Iterator<Item = TokenTree>, default_span: Span, ) -> Result<TokenStream>97 fn parse_assigned_value(
98     opt_name: &str,
99     mut tokens_iter: impl Iterator<Item = TokenTree>,
100     default_span: Span,
101 ) -> Result<TokenStream> {
102     let eq_sign = tokens_iter.next().ok_or_else(|| {
103         let msg = format!(
104             "Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing assignment."
105         );
106         syn::Error::new(default_span, msg)
107     })?;
108 
109     if eq_sign.to_string() == "=" {
110         Ok(tokens_iter.collect())
111     } else {
112         let msg = format!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`");
113         Err(syn::Error::new(eq_sign.span(), msg))
114     }
115 }
116