• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright © 2023 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3 
4 extern crate proc_macro;
5 extern crate proc_macro2;
6 #[macro_use]
7 extern crate quote;
8 extern crate syn;
9 
10 use proc_macro::TokenStream;
11 use proc_macro2::{Span, TokenStream as TokenStream2};
12 use syn::*;
13 
expr_as_usize(expr: &syn::Expr) -> usize14 fn expr_as_usize(expr: &syn::Expr) -> usize {
15     let lit = match expr {
16         syn::Expr::Lit(lit) => lit,
17         _ => panic!("Expected a literal, found an expression"),
18     };
19     let lit_int = match &lit.lit {
20         syn::Lit::Int(i) => i,
21         _ => panic!("Expected a literal integer"),
22     };
23     assert!(lit.attrs.is_empty());
24     lit_int
25         .base10_parse()
26         .expect("Failed to parse integer literal")
27 }
28 
count_type(ty: &Type, search_type: &str) -> usize29 fn count_type(ty: &Type, search_type: &str) -> usize {
30     match ty {
31         syn::Type::Array(a) => {
32             let elems = count_type(a.elem.as_ref(), search_type);
33             if elems > 0 {
34                 elems * expr_as_usize(&a.len)
35             } else {
36                 0
37             }
38         }
39         syn::Type::Path(p) => {
40             if p.qself.is_none() && p.path.is_ident(search_type) {
41                 1
42             } else {
43                 0
44             }
45         }
46         _ => 0,
47     }
48 }
49 
get_src_type(field: &Field) -> Option<String>50 fn get_src_type(field: &Field) -> Option<String> {
51     for attr in &field.attrs {
52         if let Meta::List(ml) = &attr.meta {
53             if ml.path.is_ident("src_type") {
54                 return Some(format!("{}", ml.tokens));
55             }
56         }
57     }
58     None
59 }
60 
derive_as_slice( input: TokenStream, trait_name: &str, func_prefix: &str, search_type: &str, ) -> TokenStream61 fn derive_as_slice(
62     input: TokenStream,
63     trait_name: &str,
64     func_prefix: &str,
65     search_type: &str,
66 ) -> TokenStream {
67     let DeriveInput {
68         attrs, ident, data, ..
69     } = parse_macro_input!(input);
70 
71     let trait_name = Ident::new(trait_name, Span::call_site());
72     let elem_type = Ident::new(search_type, Span::call_site());
73     let as_slice =
74         Ident::new(&format!("{}_as_slice", func_prefix), Span::call_site());
75     let as_mut_slice =
76         Ident::new(&format!("{}_as_mut_slice", func_prefix), Span::call_site());
77 
78     match data {
79         Data::Struct(s) => {
80             let mut has_repr_c = false;
81             for attr in attrs {
82                 match attr.meta {
83                     Meta::List(ml) => {
84                         if ml.path.is_ident("repr")
85                             && format!("{}", ml.tokens) == "C"
86                         {
87                             has_repr_c = true;
88                         }
89                     }
90                     _ => (),
91                 }
92             }
93             assert!(has_repr_c, "Struct must be declared #[repr(C)]");
94 
95             let mut first = None;
96             let mut count = 0_usize;
97             let mut found_last = false;
98             let mut src_types = TokenStream2::new();
99 
100             if let Fields::Named(named) = s.fields {
101                 for f in named.named {
102                     let ty_count = count_type(&f.ty, search_type);
103 
104                     if search_type == "Src" {
105                         let src_type = get_src_type(&f);
106                         if ty_count == 0 && !src_type.is_none() {
107                             panic!(
108                                 "src_type attribute is only allowed on sources"
109                             );
110                         }
111 
112                         let src_type = if let Some(s) = src_type {
113                             let s = syn::parse_str::<Ident>(&s).unwrap();
114                             quote! { SrcType::#s, }
115                         } else {
116                             quote! { SrcType::DEFAULT, }
117                         };
118 
119                         for _ in 0..ty_count {
120                             src_types.extend(src_type.clone());
121                         }
122                     }
123 
124                     if ty_count > 0 {
125                         assert!(
126                             !found_last,
127                             "All fields of type {} must be consecutive",
128                             search_type
129                         );
130                         first.get_or_insert(f.ident);
131                         count += ty_count;
132                     } else {
133                         if !first.is_none() {
134                             found_last = true;
135                         }
136                     }
137                 }
138             } else {
139                 panic!("Fields are not named");
140             }
141 
142             let src_type_func = if search_type == "Src" {
143                 quote! {
144                     fn src_types(&self) -> SrcTypeList {
145                         static SRC_TYPES: [SrcType; #count]  = [#src_types];
146                         SrcTypeList::Array(&SRC_TYPES)
147                     }
148                 }
149             } else {
150                 TokenStream2::new()
151             };
152 
153             if let Some(name) = first {
154                 quote! {
155                     impl #trait_name for #ident {
156                         fn #as_slice(&self) -> &[#elem_type] {
157                             unsafe {
158                                 let first = &self.#name as *const #elem_type;
159                                 std::slice::from_raw_parts(first, #count)
160                             }
161                         }
162 
163                         fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
164                             unsafe {
165                                 let first = &mut self.#name as *mut #elem_type;
166                                 std::slice::from_raw_parts_mut(first, #count)
167                             }
168                         }
169 
170                         #src_type_func
171                     }
172                 }
173             } else {
174                 quote! {
175                     impl #trait_name for #ident {
176                         fn #as_slice(&self) -> &[#elem_type] {
177                             &[]
178                         }
179 
180                         fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
181                             &mut []
182                         }
183 
184                         #src_type_func
185                     }
186                 }
187             }
188             .into()
189         }
190         Data::Enum(e) => {
191             let mut as_slice_cases = TokenStream2::new();
192             let mut as_mut_slice_cases = TokenStream2::new();
193             let mut src_types_cases = TokenStream2::new();
194             for v in e.variants {
195                 let case = v.ident;
196                 as_slice_cases.extend(quote! {
197                     #ident::#case(x) => x.#as_slice(),
198                 });
199                 as_mut_slice_cases.extend(quote! {
200                     #ident::#case(x) => x.#as_mut_slice(),
201                 });
202                 if search_type == "Src" {
203                     src_types_cases.extend(quote! {
204                         #ident::#case(x) => x.src_types(),
205                     });
206                 }
207             }
208             let src_type_func = if search_type == "Src" {
209                 quote! {
210                     fn src_types(&self) -> SrcTypeList {
211                         match self {
212                             #src_types_cases
213                         }
214                     }
215                 }
216             } else {
217                 TokenStream2::new()
218             };
219             quote! {
220                 impl #trait_name for #ident {
221                     fn #as_slice(&self) -> &[#elem_type] {
222                         match self {
223                             #as_slice_cases
224                         }
225                     }
226 
227                     fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
228                         match self {
229                             #as_mut_slice_cases
230                         }
231                     }
232 
233                     #src_type_func
234                 }
235             }
236             .into()
237         }
238         _ => panic!("Not a struct type"),
239     }
240 }
241 
242 #[proc_macro_derive(SrcsAsSlice, attributes(src_type))]
derive_srcs_as_slice(input: TokenStream) -> TokenStream243 pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream {
244     derive_as_slice(input, "SrcsAsSlice", "srcs", "Src")
245 }
246 
247 #[proc_macro_derive(DstsAsSlice)]
derive_dsts_as_slice(input: TokenStream) -> TokenStream248 pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream {
249     derive_as_slice(input, "DstsAsSlice", "dsts", "Dst")
250 }
251 
252 #[proc_macro_derive(DisplayOp)]
enum_derive_display_op(input: TokenStream) -> TokenStream253 pub fn enum_derive_display_op(input: TokenStream) -> TokenStream {
254     let DeriveInput { ident, data, .. } = parse_macro_input!(input);
255 
256     if let Data::Enum(e) = data {
257         let mut fmt_dsts_cases = TokenStream2::new();
258         let mut fmt_op_cases = TokenStream2::new();
259         for v in e.variants {
260             let case = v.ident;
261             fmt_dsts_cases.extend(quote! {
262                 #ident::#case(x) => x.fmt_dsts(f),
263             });
264             fmt_op_cases.extend(quote! {
265                 #ident::#case(x) => x.fmt_op(f),
266             });
267         }
268         quote! {
269             impl DisplayOp for #ident {
270                 fn fmt_dsts(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271                     match self {
272                         #fmt_dsts_cases
273                     }
274                 }
275 
276                 fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277                     match self {
278                         #fmt_op_cases
279                     }
280                 }
281             }
282         }
283         .into()
284     } else {
285         panic!("Not an enum type");
286     }
287 }
288 
289 #[proc_macro_derive(FromVariants)]
derive_from_variants(input: TokenStream) -> TokenStream290 pub fn derive_from_variants(input: TokenStream) -> TokenStream {
291     let DeriveInput { ident, data, .. } = parse_macro_input!(input);
292     let enum_type = ident;
293 
294     let mut impls = TokenStream2::new();
295 
296     if let Data::Enum(e) = data {
297         for v in e.variants {
298             let var_ident = v.ident;
299             let from_type = match v.fields {
300                 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => unnamed,
301                 _ => panic!("Expected Op(OpFoo)"),
302             };
303 
304             let quote = quote! {
305                 impl From<#from_type> for #enum_type {
306                     fn from (op: #from_type) -> #enum_type {
307                         #enum_type::#var_ident(op)
308                     }
309                 }
310             };
311 
312             impls.extend(quote);
313         }
314     }
315 
316     impls.into()
317 }
318