• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::{cfg, file, lookup};
2 use anyhow::Result;
3 use proc_macro2::{Ident, Span, TokenStream};
4 use quote::{format_ident, quote};
5 use syn_codegen::{Data, Definitions, Node, Type};
6 
/// Destination path of the generated `PartialEq`/`Eq` impls, passed to
/// `file::write` by `generate`.
const EQ_SRC: &str = "src/gen/eq.rs";
8 
always_eq(field_type: &Type) -> bool9 fn always_eq(field_type: &Type) -> bool {
10     match field_type {
11         Type::Ext(ty) => ty == "Span",
12         Type::Token(_) | Type::Group(_) => true,
13         Type::Box(inner) => always_eq(inner),
14         Type::Tuple(inner) => inner.iter().all(always_eq),
15         _ => false,
16     }
17 }
18 
expand_impl_body(defs: &Definitions, node: &Node) -> TokenStream19 fn expand_impl_body(defs: &Definitions, node: &Node) -> TokenStream {
20     let type_name = &node.ident;
21     let ident = Ident::new(type_name, Span::call_site());
22 
23     match &node.data {
24         Data::Enum(variants) if variants.is_empty() => quote!(match *self {}),
25         Data::Enum(variants) => {
26             let arms = variants.iter().map(|(variant_name, fields)| {
27                 let variant = Ident::new(variant_name, Span::call_site());
28                 if fields.is_empty() {
29                     quote! {
30                         (#ident::#variant, #ident::#variant) => true,
31                     }
32                 } else {
33                     let mut this_pats = Vec::new();
34                     let mut other_pats = Vec::new();
35                     let mut comparisons = Vec::new();
36                     for (i, field) in fields.iter().enumerate() {
37                         if always_eq(field) {
38                             this_pats.push(format_ident!("_"));
39                             other_pats.push(format_ident!("_"));
40                             continue;
41                         }
42                         let this = format_ident!("self{}", i);
43                         let other = format_ident!("other{}", i);
44                         comparisons.push(match field {
45                             Type::Ext(ty) if ty == "TokenStream" => {
46                                 quote!(TokenStreamHelper(#this) == TokenStreamHelper(#other))
47                             }
48                             Type::Ext(ty) if ty == "Literal" => {
49                                 quote!(#this.to_string() == #other.to_string())
50                             }
51                             _ => quote!(#this == #other),
52                         });
53                         this_pats.push(this);
54                         other_pats.push(other);
55                     }
56                     if comparisons.is_empty() {
57                         comparisons.push(quote!(true));
58                     }
59                     let mut cfg = None;
60                     if node.ident == "Expr" {
61                         if let Type::Syn(ty) = &fields[0] {
62                             if !lookup::node(defs, ty).features.any.contains("derive") {
63                                 cfg = Some(quote!(#[cfg(feature = "full")]));
64                             }
65                         }
66                     }
67                     quote! {
68                         #cfg
69                         (#ident::#variant(#(#this_pats),*), #ident::#variant(#(#other_pats),*)) => {
70                             #(#comparisons)&&*
71                         }
72                     }
73                 }
74             });
75             let fallthrough = if variants.len() == 1 {
76                 None
77             } else {
78                 Some(quote!(_ => false,))
79             };
80             quote! {
81                 match (self, other) {
82                     #(#arms)*
83                     #fallthrough
84                 }
85             }
86         }
87         Data::Struct(fields) => {
88             let mut comparisons = Vec::new();
89             for (f, ty) in fields {
90                 if always_eq(ty) {
91                     continue;
92                 }
93                 let ident = Ident::new(f, Span::call_site());
94                 comparisons.push(match ty {
95                     Type::Ext(ty) if ty == "TokenStream" => {
96                         quote!(TokenStreamHelper(&self.#ident) == TokenStreamHelper(&other.#ident))
97                     }
98                     _ => quote!(self.#ident == other.#ident),
99                 });
100             }
101             if comparisons.is_empty() {
102                 quote!(true)
103             } else {
104                 quote!(#(#comparisons)&&*)
105             }
106         }
107         Data::Private => unreachable!(),
108     }
109 }
110 
expand_impl(defs: &Definitions, node: &Node) -> TokenStream111 fn expand_impl(defs: &Definitions, node: &Node) -> TokenStream {
112     if node.ident == "Member" || node.ident == "Index" || node.ident == "Lifetime" {
113         return TokenStream::new();
114     }
115 
116     let ident = Ident::new(&node.ident, Span::call_site());
117     let cfg_features = cfg::features(&node.features, "extra-traits");
118 
119     let eq = quote! {
120         #cfg_features
121         impl Eq for #ident {}
122     };
123 
124     let manual_partial_eq = node.data == Data::Private;
125     if manual_partial_eq {
126         return eq;
127     }
128 
129     let body = expand_impl_body(defs, node);
130     let other = match &node.data {
131         Data::Enum(variants) if variants.is_empty() => quote!(_other),
132         Data::Struct(fields) if fields.values().all(always_eq) => quote!(_other),
133         _ => quote!(other),
134     };
135 
136     quote! {
137         #eq
138 
139         #cfg_features
140         impl PartialEq for #ident {
141             fn eq(&self, #other: &Self) -> bool {
142                 #body
143             }
144         }
145     }
146 }
147 
generate(defs: &Definitions) -> Result<()>148 pub fn generate(defs: &Definitions) -> Result<()> {
149     let mut impls = TokenStream::new();
150     for node in &defs.types {
151         impls.extend(expand_impl(defs, node));
152     }
153 
154     file::write(
155         EQ_SRC,
156         quote! {
157             #[cfg(any(feature = "derive", feature = "full"))]
158             use crate::tt::TokenStreamHelper;
159             use crate::*;
160 
161             #impls
162         },
163     )?;
164 
165     Ok(())
166 }
167