1 // Copyright © 2023 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3
4 extern crate proc_macro;
5 extern crate proc_macro2;
6 #[macro_use]
7 extern crate quote;
8 extern crate syn;
9
10 use proc_macro::TokenStream;
11 use proc_macro2::{Span, TokenStream as TokenStream2};
12 use syn::*;
13
expr_as_usize(expr: &syn::Expr) -> usize14 fn expr_as_usize(expr: &syn::Expr) -> usize {
15 let lit = match expr {
16 syn::Expr::Lit(lit) => lit,
17 _ => panic!("Expected a literal, found an expression"),
18 };
19 let lit_int = match &lit.lit {
20 syn::Lit::Int(i) => i,
21 _ => panic!("Expected a literal integer"),
22 };
23 assert!(lit.attrs.is_empty());
24 lit_int
25 .base10_parse()
26 .expect("Failed to parse integer literal")
27 }
28
count_type(ty: &Type, search_type: &str) -> usize29 fn count_type(ty: &Type, search_type: &str) -> usize {
30 match ty {
31 syn::Type::Array(a) => {
32 let elems = count_type(a.elem.as_ref(), search_type);
33 if elems > 0 {
34 elems * expr_as_usize(&a.len)
35 } else {
36 0
37 }
38 }
39 syn::Type::Path(p) => {
40 if p.qself.is_none() && p.path.is_ident(search_type) {
41 1
42 } else {
43 0
44 }
45 }
46 _ => 0,
47 }
48 }
49
get_src_type(field: &Field) -> Option<String>50 fn get_src_type(field: &Field) -> Option<String> {
51 for attr in &field.attrs {
52 if let Meta::List(ml) = &attr.meta {
53 if ml.path.is_ident("src_type") {
54 return Some(format!("{}", ml.tokens));
55 }
56 }
57 }
58 None
59 }
60
derive_as_slice( input: TokenStream, trait_name: &str, func_prefix: &str, search_type: &str, ) -> TokenStream61 fn derive_as_slice(
62 input: TokenStream,
63 trait_name: &str,
64 func_prefix: &str,
65 search_type: &str,
66 ) -> TokenStream {
67 let DeriveInput {
68 attrs, ident, data, ..
69 } = parse_macro_input!(input);
70
71 let trait_name = Ident::new(trait_name, Span::call_site());
72 let elem_type = Ident::new(search_type, Span::call_site());
73 let as_slice =
74 Ident::new(&format!("{}_as_slice", func_prefix), Span::call_site());
75 let as_mut_slice =
76 Ident::new(&format!("{}_as_mut_slice", func_prefix), Span::call_site());
77
78 match data {
79 Data::Struct(s) => {
80 let mut has_repr_c = false;
81 for attr in attrs {
82 match attr.meta {
83 Meta::List(ml) => {
84 if ml.path.is_ident("repr")
85 && format!("{}", ml.tokens) == "C"
86 {
87 has_repr_c = true;
88 }
89 }
90 _ => (),
91 }
92 }
93 assert!(has_repr_c, "Struct must be declared #[repr(C)]");
94
95 let mut first = None;
96 let mut count = 0_usize;
97 let mut found_last = false;
98 let mut src_types = TokenStream2::new();
99
100 if let Fields::Named(named) = s.fields {
101 for f in named.named {
102 let ty_count = count_type(&f.ty, search_type);
103
104 if search_type == "Src" {
105 let src_type = get_src_type(&f);
106 if ty_count == 0 && !src_type.is_none() {
107 panic!(
108 "src_type attribute is only allowed on sources"
109 );
110 }
111
112 let src_type = if let Some(s) = src_type {
113 let s = syn::parse_str::<Ident>(&s).unwrap();
114 quote! { SrcType::#s, }
115 } else {
116 quote! { SrcType::DEFAULT, }
117 };
118
119 for _ in 0..ty_count {
120 src_types.extend(src_type.clone());
121 }
122 }
123
124 if ty_count > 0 {
125 assert!(
126 !found_last,
127 "All fields of type {} must be consecutive",
128 search_type
129 );
130 first.get_or_insert(f.ident);
131 count += ty_count;
132 } else {
133 if !first.is_none() {
134 found_last = true;
135 }
136 }
137 }
138 } else {
139 panic!("Fields are not named");
140 }
141
142 let src_type_func = if search_type == "Src" {
143 quote! {
144 fn src_types(&self) -> SrcTypeList {
145 static SRC_TYPES: [SrcType; #count] = [#src_types];
146 SrcTypeList::Array(&SRC_TYPES)
147 }
148 }
149 } else {
150 TokenStream2::new()
151 };
152
153 if let Some(name) = first {
154 quote! {
155 impl #trait_name for #ident {
156 fn #as_slice(&self) -> &[#elem_type] {
157 unsafe {
158 let first = &self.#name as *const #elem_type;
159 std::slice::from_raw_parts(first, #count)
160 }
161 }
162
163 fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
164 unsafe {
165 let first = &mut self.#name as *mut #elem_type;
166 std::slice::from_raw_parts_mut(first, #count)
167 }
168 }
169
170 #src_type_func
171 }
172 }
173 } else {
174 quote! {
175 impl #trait_name for #ident {
176 fn #as_slice(&self) -> &[#elem_type] {
177 &[]
178 }
179
180 fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
181 &mut []
182 }
183
184 #src_type_func
185 }
186 }
187 }
188 .into()
189 }
190 Data::Enum(e) => {
191 let mut as_slice_cases = TokenStream2::new();
192 let mut as_mut_slice_cases = TokenStream2::new();
193 let mut src_types_cases = TokenStream2::new();
194 for v in e.variants {
195 let case = v.ident;
196 as_slice_cases.extend(quote! {
197 #ident::#case(x) => x.#as_slice(),
198 });
199 as_mut_slice_cases.extend(quote! {
200 #ident::#case(x) => x.#as_mut_slice(),
201 });
202 if search_type == "Src" {
203 src_types_cases.extend(quote! {
204 #ident::#case(x) => x.src_types(),
205 });
206 }
207 }
208 let src_type_func = if search_type == "Src" {
209 quote! {
210 fn src_types(&self) -> SrcTypeList {
211 match self {
212 #src_types_cases
213 }
214 }
215 }
216 } else {
217 TokenStream2::new()
218 };
219 quote! {
220 impl #trait_name for #ident {
221 fn #as_slice(&self) -> &[#elem_type] {
222 match self {
223 #as_slice_cases
224 }
225 }
226
227 fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
228 match self {
229 #as_mut_slice_cases
230 }
231 }
232
233 #src_type_func
234 }
235 }
236 .into()
237 }
238 _ => panic!("Not a struct type"),
239 }
240 }
241
242 #[proc_macro_derive(SrcsAsSlice, attributes(src_type))]
derive_srcs_as_slice(input: TokenStream) -> TokenStream243 pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream {
244 derive_as_slice(input, "SrcsAsSlice", "srcs", "Src")
245 }
246
247 #[proc_macro_derive(DstsAsSlice)]
derive_dsts_as_slice(input: TokenStream) -> TokenStream248 pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream {
249 derive_as_slice(input, "DstsAsSlice", "dsts", "Dst")
250 }
251
252 #[proc_macro_derive(DisplayOp)]
enum_derive_display_op(input: TokenStream) -> TokenStream253 pub fn enum_derive_display_op(input: TokenStream) -> TokenStream {
254 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
255
256 if let Data::Enum(e) = data {
257 let mut fmt_dsts_cases = TokenStream2::new();
258 let mut fmt_op_cases = TokenStream2::new();
259 for v in e.variants {
260 let case = v.ident;
261 fmt_dsts_cases.extend(quote! {
262 #ident::#case(x) => x.fmt_dsts(f),
263 });
264 fmt_op_cases.extend(quote! {
265 #ident::#case(x) => x.fmt_op(f),
266 });
267 }
268 quote! {
269 impl DisplayOp for #ident {
270 fn fmt_dsts(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271 match self {
272 #fmt_dsts_cases
273 }
274 }
275
276 fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277 match self {
278 #fmt_op_cases
279 }
280 }
281 }
282 }
283 .into()
284 } else {
285 panic!("Not an enum type");
286 }
287 }
288
289 #[proc_macro_derive(FromVariants)]
derive_from_variants(input: TokenStream) -> TokenStream290 pub fn derive_from_variants(input: TokenStream) -> TokenStream {
291 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
292 let enum_type = ident;
293
294 let mut impls = TokenStream2::new();
295
296 if let Data::Enum(e) = data {
297 for v in e.variants {
298 let var_ident = v.ident;
299 let from_type = match v.fields {
300 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => unnamed,
301 _ => panic!("Expected Op(OpFoo)"),
302 };
303
304 let quote = quote! {
305 impl From<#from_type> for #enum_type {
306 fn from (op: #from_type) -> #enum_type {
307 #enum_type::#var_ident(op)
308 }
309 }
310 };
311
312 impls.extend(quote);
313 }
314 }
315
316 impls.into()
317 }
318