1 //! Derive macro for `AsCborValue`.
2 use proc_macro2::TokenStream;
3 use quote::{format_ident, quote, quote_spanned};
4 use syn::{
5     parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
6     Generics, Ident, Index,
7 };
8 
9 /// Derive macro that implements the `AsCborValue` trait.  Using this macro requires
10 /// that `AsCborValue`, `CborError` and `cbor_type_error` are locally `use`d.
11 #[proc_macro_derive(AsCborValue)]
derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream12 pub fn derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13     let input = parse_macro_input!(input as DeriveInput);
14     derive_as_cbor_value_internal(&input)
15 }
16 
derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream17 fn derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream {
18     let name = &input.ident;
19 
20     // Add a bound `T: AsCborValue` for every type parameter `T`.
21     let generics = add_trait_bounds(&input.generics);
22     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
23 
24     let from_val = from_val_struct(&input.data);
25     let to_val = to_val_struct(&input.data);
26     let cddl = cddl_struct(name, &input.data);
27 
28     let expanded = quote! {
29         // The generated impl
30         impl #impl_generics AsCborValue for #name #ty_generics #where_clause {
31             fn from_cbor_value(value: ciborium::value::Value) -> Result<Self, CborError> {
32                 #from_val
33             }
34             fn to_cbor_value(self) -> Result<ciborium::value::Value, CborError> {
35                 #to_val
36             }
37             fn cddl_typename() -> Option<String> {
38                 Some(stringify!(#name).to_string())
39             }
40             fn cddl_schema() -> Option<String> {
41                 #cddl
42             }
43         }
44     };
45 
46     expanded.into()
47 }
48 
49 /// Add a bound `T: AsCborValue` for every type parameter `T`.
add_trait_bounds(generics: &Generics) -> Generics50 fn add_trait_bounds(generics: &Generics) -> Generics {
51     let mut generics = generics.clone();
52     for param in &mut generics.params {
53         if let GenericParam::Type(ref mut type_param) = *param {
54             type_param.bounds.push(parse_quote!(AsCborValue));
55         }
56     }
57     generics
58 }
59 
60 /// Generate an expression to convert an instance of a compound type to `ciborium::value::Value`
to_val_struct(data: &Data) -> TokenStream61 fn to_val_struct(data: &Data) -> TokenStream {
62     match *data {
63         Data::Struct(ref data) => {
64             match data.fields {
65                 Fields::Named(ref fields) => {
66                     // Expands to an expression like
67                     //
68                     //     {
69                     //         let mut v = Vec::new();
70                     //         v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
71                     //         v.push(AsCborValue::to_cbor_value(self.x)?);
72                     //         v.push(AsCborValue::to_cbor_value(self.y)?);
73                     //         v.push(AsCborValue::to_cbor_value(self.z)?);
74                     //         Ok(ciborium::value::Value::Array(v))
75                     //     }
76                     let nfields = fields.named.len();
77                     let recurse = fields.named.iter().map(|f| {
78                         let name = &f.ident;
79                         quote_spanned! {f.span()=>
80                             v.push(AsCborValue::to_cbor_value(self.#name)?)
81                         }
82                     });
83                     quote! {
84                         {
85                             let mut v = Vec::new();
86                             v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
87                             #(#recurse; )*
88                             Ok(ciborium::value::Value::Array(v))
89                         }
90                     }
91                 }
92                 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
93                     // For a newtype, expands to an expression
94                     //
95                     //     self.0.to_cbor_value()
96                     quote! {
97                         self.0.to_cbor_value()
98                     }
99                 }
100                 Fields::Unnamed(ref fields) => {
101                     // Expands to an expression like
102                     //
103                     //
104                     //     {
105                     //         let mut v = Vec::new();
106                     //         v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
107                     //         v.push(AsCborValue::to_cbor_value(self.0)?);
108                     //         v.push(AsCborValue::to_cbor_value(self.1)?);
109                     //         v.push(AsCborValue::to_cbor_value(self.2)?);
110                     //         Ok(ciborium::value::Value::Array(v))
111                     //     }
112                     let nfields = fields.unnamed.len();
113                     let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
114                         let index = Index::from(i);
115                         quote_spanned! {f.span()=>
116                             v.push(AsCborValue::to_cbor_value(self.#index)?)
117                         }
118                     });
119                     quote! {
120                         {
121                             let mut v = Vec::new();
122                             v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
123                             #(#recurse; )*
124                             Ok(ciborium::value::Value::Array(v))
125                         }
126                     }
127                 }
128                 Fields::Unit => unimplemented!(),
129             }
130         }
131         Data::Enum(_) => {
132             quote! {
133                 let v: ciborium::value::Integer = (self as i32).into();
134                 Ok(ciborium::value::Value::Integer(v))
135             }
136         }
137         Data::Union(_) => unimplemented!(),
138     }
139 }
140 
141 /// Generate an expression to convert a `ciborium::value::Value` into an instance of a compound
142 /// type.
from_val_struct(data: &Data) -> TokenStream143 fn from_val_struct(data: &Data) -> TokenStream {
144     match data {
145         Data::Struct(ref data) => {
146             match data.fields {
147                 Fields::Named(ref fields) => {
148                     // Expands to an expression like
149                     //
150                     //     let mut a = match value {
151                     //         ciborium::value::Value::Array(a) => a,
152                     //         _ => return cbor_type_error(&value, "arr"),
153                     //     };
154                     //     if a.len() != 3 {
155                     //         return Err(CborError::UnexpectedItem("arr", "arr len 3"));
156                     //     }
157                     //     // Fields specified in reverse order to reduce shifting.
158                     //     Ok(Self {
159                     //         z: <ZType>::from_cbor_value(a.remove(2))?,
160                     //         y: <YType>::from_cbor_value(a.remove(1))?,
161                     //         x: <XType>::from_cbor_value(a.remove(0))?,
162                     //     })
163                     //
164                     // but using fully qualified function call syntax.
165                     let nfields = fields.named.len();
166                     let recurse = fields.named.iter().enumerate().rev().map(|(i, f)| {
167                         let name = &f.ident;
168                         let index = Index::from(i);
169                         let typ = &f.ty;
170                         quote_spanned! {f.span()=>
171                                         #name: <#typ>::from_cbor_value(a.remove(#index))?
172                         }
173                     });
174                     quote! {
175                         let mut a = match value {
176                             ciborium::value::Value::Array(a) => a,
177                             _ => return cbor_type_error(&value, "arr"),
178                         };
179                         if a.len() != #nfields {
180                             return Err(CborError::UnexpectedItem(
181                                 "arr",
182                                 concat!("arr len ", stringify!(#nfields)),
183                             ));
184                         }
185                         // Fields specified in reverse order to reduce shifting.
186                         Ok(Self {
187                             #(#recurse, )*
188                         })
189                     }
190                 }
191                 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
192                     // For a newtype, expands to an expression like
193                     //
194                     //     Ok(Self(<InnerType>::from_cbor_value(value)?))
195                     let inner = fields.unnamed.first().unwrap();
196                     let typ = &inner.ty;
197                     quote! {
198                         Ok(Self(<#typ>::from_cbor_value(value)?))
199                     }
200                 }
201                 Fields::Unnamed(ref fields) => {
202                     // Expands to an expression like
203                     //
204                     //     let mut a = match value {
205                     //         ciborium::value::Value::Array(a) => a,
206                     //         _ => return cbor_type_error(&value, "arr"),
207                     //     };
208                     //     if a.len() != 3 {
209                     //         return Err(CborError::UnexpectedItem("arr", "arr len 3"));
210                     //     }
211                     //     // Fields specified in reverse order to reduce shifting.
212                     //     let field_2 = <Type2>::from_cbor_value(a.remove(2))?;
213                     //     let field_1 = <Type1>::from_cbor_value(a.remove(1))?;
214                     //     let field_0 = <Type0>::from_cbor_value(a.remove(0))?;
215                     //     Ok(Self(field_0, field_1, field_2))
216                     let nfields = fields.unnamed.len();
217                     let recurse1 = fields.unnamed.iter().enumerate().rev().map(|(i, f)| {
218                         let typ = &f.ty;
219                         let varname = format_ident!("field_{}", i);
220                         quote_spanned! {f.span()=>
221                                         let #varname = <#typ>::from_cbor_value(a.remove(#i))?;
222                         }
223                     });
224                     let recurse2 = fields.unnamed.iter().enumerate().map(|(i, f)| {
225                         let varname = format_ident!("field_{}", i);
226                         quote_spanned! {f.span()=>
227                                         #varname
228                         }
229                     });
230                     quote! {
231                         let mut a = match value {
232                             ciborium::value::Value::Array(a) => a,
233                             _ => return cbor_type_error(&value, "arr"),
234                         };
235                         if a.len() != #nfields {
236                             return Err(CborError::UnexpectedItem("arr",
237                                                                  concat!("arr len ",
238                                                                          stringify!(#nfields))));
239                         }
240                         // Fields specified in reverse order to reduce shifting.
241                         #(#recurse1)*
242 
243                         Ok(Self( #(#recurse2, )* ))
244                     }
245                 }
246                 Fields::Unit => unimplemented!(),
247             }
248         }
249         Data::Enum(enum_data) => {
250             // This only copes with variants with no fields.
251             // Expands to an expression like:
252             //
253             //     use core::convert::TryInto;
254             //     let v: i32 = match value {
255             //         ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
256             //             CborError::OutOfRangeIntegerValue
257             //         })?,
258             //         v => return cbor_type_error(&v, &"int"),
259             //     };
260             //     match v {
261             //         x if x == Self::Variant1 as i32 => Ok(Self::Variant1),
262             //         x if x == Self::Variant2 as i32 => Ok(Self::Variant2),
263             //         x if x == Self::Variant3 as i32 => Ok(Self::Variant3),
264             //         _ => Err( CborError::OutOfRangeIntegerValue),
265             //     }
266             let recurse = enum_data.variants.iter().map(|variant| {
267                 let vname = &variant.ident;
268                 quote_spanned! {variant.span()=>
269                                 x if x == Self::#vname as i32 => Ok(Self::#vname),
270                 }
271             });
272 
273             quote! {
274                 use core::convert::TryInto;
275                 // First get the int value as an `i32`.
276                 let v: i32 = match value {
277                     ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
278                         CborError::OutOfRangeIntegerValue
279                     })?,
280                     v => return cbor_type_error(&v, &"int"),
281                 };
282                 // Now match against enum possibilities.
283                 match v {
284                     #(#recurse)*
285                     _ => Err(
286                         CborError::OutOfRangeIntegerValue
287                     ),
288                 }
289             }
290         }
291         Data::Union(_) => unimplemented!(),
292     }
293 }
294 
295 /// Generate an expression that expresses the CDDL schema for the type.
cddl_struct(name: &Ident, data: &Data) -> TokenStream296 fn cddl_struct(name: &Ident, data: &Data) -> TokenStream {
297     match *data {
298         Data::Struct(ref data) => {
299             match data.fields {
300                 Fields::Named(ref fields) => {
301                     if fields.named.iter().next().is_none() {
302                         return quote! {
303                             Some(format!("[]"))
304                         };
305                     }
306                     // Expands to an expression like
307                     //
308                     //     format!("[
309                     //         x: {},
310                     //         y: {},
311                     //         z: {},
312                     //     ]",
313                     //         <TypeX>::cddl_ref(),
314                     //         <TypeY>::cddl_ref(),
315                     //         <TypeZ>::cddl_ref(),
316                     //     )
317                     let fmt_recurse = fields.named.iter().map(|f| {
318                         let name = &f.ident;
319                         quote_spanned! {f.span()=>
320                                         concat!("    ", stringify!(#name), ": {},\n")
321                         }
322                     });
323                     let fmt = quote! {
324                         concat!("[\n",
325                                 #(#fmt_recurse, )*
326                                 "]")
327                     };
328                     let recurse = fields.named.iter().map(|f| {
329                         let typ = &f.ty;
330                         quote_spanned! {f.span()=>
331                                         <#typ>::cddl_ref()
332                         }
333                     });
334                     quote! {
335                         Some(format!(
336                             #fmt,
337                             #(#recurse, )*
338                         ))
339                     }
340                 }
341                 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
342                     let inner = fields.unnamed.first().unwrap();
343                     let typ = &inner.ty;
344                     quote! {
345                         Some(<#typ>::cddl_ref())
346                     }
347                 }
348                 Fields::Unnamed(ref fields) => {
349                     if fields.unnamed.iter().next().is_none() {
350                         return quote! {
351                             Some(format!("()"))
352                         };
353                     }
354                     // Expands to an expression like
355                     //
356                     //     format!("[
357                     //         {},
358                     //         {},
359                     //         {},
360                     //     ]",
361                     //         <TypeX>::cddl_ref(),
362                     //         <TypeY>::cddl_ref(),
363                     //         <TypeZ>::cddl_ref(),
364                     //     )
365                     //
366                     let fmt_recurse = fields.unnamed.iter().map(|f| {
367                         quote_spanned! {f.span()=>
368                                         "    {},\n"
369                         }
370                     });
371                     let fmt = quote! {
372                         concat!("[\n",
373                                  #(#fmt_recurse, )*
374                                  "]")
375                     };
376                     let recurse = fields.unnamed.iter().map(|f| {
377                         let typ = &f.ty;
378                         quote_spanned! {f.span()=>
379                                         <#typ>::cddl_ref()
380                         }
381                     });
382                     quote! {
383                         Some(format!(
384                             #fmt,
385                             #(#recurse, )*
386                         ))
387                     }
388                 }
389                 Fields::Unit => unimplemented!(),
390             }
391         }
392         Data::Enum(ref enum_data) => {
393             // This only copes with variants with no fields.
394             // Expands to an expression like:
395             //
396             //     format!("&(
397             //         EnumName_Variant1: {},
398             //         EnumName_Variant2: {},
399             //         EnumName_Variant3: {},
400             //     )",
401             //         Self::Variant1 as i32,
402             //         Self::Variant2 as i32,
403             //         Self::Variant3 as i32,
404             //     )
405             //
406             let fmt_recurse = enum_data.variants.iter().map(|variant| {
407                 let vname = &variant.ident;
408                 quote_spanned! {variant.span()=>
409                                 concat!("    ",
410                                         stringify!(#name),
411                                         "_",
412                                         stringify!(#vname),
413                                         ": {},\n")
414                 }
415             });
416             let fmt = quote! {
417                 concat!("&(\n",
418                          #(#fmt_recurse, )*
419                          ")")
420             };
421             let recurse = enum_data.variants.iter().map(|variant| {
422                 let vname = &variant.ident;
423                 quote_spanned! {variant.span()=>
424                                 Self::#vname as i32
425                 }
426             });
427             quote! {
428                 Some(format!(
429                     #fmt,
430                     #(#recurse, )*
431                 ))
432             }
433         }
434         Data::Union(_) => unimplemented!(),
435     }
436 }
437 
438 /// Derive macro that implements a `from_raw_tag_value` method for the `Tag` enum.
439 #[proc_macro_derive(FromRawTag)]
derive_from_raw_tag(input: proc_macro::TokenStream) -> proc_macro::TokenStream440 pub fn derive_from_raw_tag(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
441     let input = parse_macro_input!(input as DeriveInput);
442     derive_from_raw_tag_internal(&input)
443 }
444 
derive_from_raw_tag_internal(input: &DeriveInput) -> proc_macro::TokenStream445 fn derive_from_raw_tag_internal(input: &DeriveInput) -> proc_macro::TokenStream {
446     let name = &input.ident;
447     let from_val = from_raw_tag(name, &input.data);
448     let expanded = quote! {
449         pub fn from_raw_tag_value(raw_tag: u32) -> #name {
450             #from_val
451         }
452     };
453     expanded.into()
454 }
455 
456 /// Generate an expression to convert a `u32` into an instance of an fieldless enum.
457 /// Assumes the existence of an `Invalid` variant as a fallback, and assumes that a
458 /// `raw_tag_value` function is in scope.
from_raw_tag(name: &Ident, data: &Data) -> TokenStream459 fn from_raw_tag(name: &Ident, data: &Data) -> TokenStream {
460     match data {
461         Data::Enum(enum_data) => {
462             let recurse = enum_data.variants.iter().map(|variant| {
463                 let vname = &variant.ident;
464                 quote_spanned! {variant.span()=>
465                                 x if x == raw_tag_value(#name::#vname) => #name::#vname,
466                 }
467             });
468 
469             quote! {
470                 match raw_tag {
471                     #(#recurse)*
472                     _ => #name::Invalid,
473                 }
474             }
475         }
476         _ => unimplemented!(),
477     }
478 }
479 
480 /// Derive macro that implements the `legacy::InnerSerialize` trait.  Using this macro requires
481 /// that `InnerSerialize` and `Error` from `kmr_wire::legacy` be locally `use`d.
482 #[proc_macro_derive(LegacySerialize)]
derive_legacy_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream483 pub fn derive_legacy_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
484     let input = parse_macro_input!(input as DeriveInput);
485     derive_legacy_serialize_internal(&input)
486 }
487 
derive_legacy_serialize_internal(input: &DeriveInput) -> proc_macro::TokenStream488 fn derive_legacy_serialize_internal(input: &DeriveInput) -> proc_macro::TokenStream {
489     let name = &input.ident;
490 
491     let deserialize_val = deserialize_struct(&input.data);
492     let serialize_val = serialize_struct(&input.data);
493 
494     let expanded = quote! {
495         impl InnerSerialize for #name {
496             fn deserialize(data: &[u8]) -> Result<(Self, &[u8]), Error> {
497                 #deserialize_val
498             }
499             fn serialize_into(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
500                 #serialize_val
501             }
502         }
503     };
504 
505     expanded.into()
506 }
507 
deserialize_struct(data: &Data) -> TokenStream508 fn deserialize_struct(data: &Data) -> TokenStream {
509     match data {
510         Data::Struct(ref data) => {
511             match data.fields {
512                 Fields::Named(ref fields) => {
513                     // Expands to an expression like
514                     //
515                     //     let (x, data) = <XType>::deserialize(data)?;
516                     //     let (y, data) = <YType>::deserialize(data)?;
517                     //     let (z, data) = <ZType>::deserialize(data)?;
518                     //     Ok((Self {
519                     //             x,
520                     //             y,
521                     //             z,
522                     //     }, data))
523                     //
524                     let recurse1 = fields.named.iter().map(|f| {
525                         let name = &f.ident;
526                         let typ = &f.ty;
527                         quote_spanned! {f.span()=>
528                                         let (#name, data) = <#typ>::deserialize(data)?;
529                         }
530                     });
531                     let recurse2 = fields.named.iter().map(|f| {
532                         let name = &f.ident;
533                         quote_spanned! {f.span()=>
534                                         #name
535                         }
536                     });
537                     quote! {
538                         #(#recurse1)*
539                         Ok((Self {
540                             #(#recurse2, )*
541                         }, data))
542                     }
543                 }
544                 Fields::Unnamed(_) => unimplemented!(),
545                 Fields::Unit => unimplemented!(),
546             }
547         }
548         Data::Enum(_) => unimplemented!(),
549         Data::Union(_) => unimplemented!(),
550     }
551 }
552 
serialize_struct(data: &Data) -> TokenStream553 fn serialize_struct(data: &Data) -> TokenStream {
554     match data {
555         Data::Struct(ref data) => {
556             match data.fields {
557                 Fields::Named(ref fields) => {
558                     // Expands to an expression like
559                     //
560                     //     self.x.serialize_into(buf)?;
561                     //     self.y.serialize_into(buf)?;
562                     //     self.z.serialize_into(buf)?;
563                     //     Ok(())
564                     //
565                     let recurse = fields.named.iter().map(|f| {
566                         let name = &f.ident;
567                         quote_spanned! {f.span()=>
568                                         self.#name.serialize_into(buf)?;
569                         }
570                     });
571                     quote! {
572                         #(#recurse)*
573                         Ok(())
574                     }
575                 }
576                 Fields::Unnamed(_) => unimplemented!(),
577                 Fields::Unit => unimplemented!(),
578             }
579         }
580         Data::Enum(_) => unimplemented!(),
581         Data::Union(_) => unimplemented!(),
582     }
583 }
584