• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Derives a 9P wire format encoding for a struct with named fields by recursively calling
//! `WireFormat::byte_size`, `WireFormat::encode`, or `WireFormat::decode` on each field.
//! This is only intended to be used from within the `p9` crate.
8 
9 #![recursion_limit = "256"]
10 
11 extern crate proc_macro;
12 extern crate proc_macro2;
13 
14 #[macro_use]
15 extern crate quote;
16 
17 #[macro_use]
18 extern crate syn;
19 
20 use proc_macro2::Span;
21 use proc_macro2::TokenStream;
22 use syn::spanned::Spanned;
23 use syn::Data;
24 use syn::DeriveInput;
25 use syn::Fields;
26 use syn::Ident;
27 
28 /// The function that derives the actual implementation.
29 #[proc_macro_derive(P9WireFormat)]
p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream30 pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31     let input = parse_macro_input!(input as DeriveInput);
32     p9_wire_format_inner(input).into()
33 }
34 
p9_wire_format_inner(input: DeriveInput) -> TokenStream35 fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
36     if !input.generics.params.is_empty() {
37         return quote! {
38             compile_error!("derive(P9WireFormat) does not support generic parameters");
39         };
40     }
41 
42     let container = input.ident;
43 
44     let byte_size_impl = byte_size_sum(&input.data);
45     let encode_impl = encode_wire_format(&input.data);
46     let decode_impl = decode_wire_format(&input.data, &container);
47 
48     let scope = format!("wire_format_{}", container).to_lowercase();
49     let scope = Ident::new(&scope, Span::call_site());
50     quote! {
51         mod #scope {
52             extern crate std;
53             use self::std::io;
54             use self::std::result::Result::Ok;
55 
56             use super::#container;
57 
58             use protocol::WireFormat;
59 
60             impl WireFormat for #container {
61                 fn byte_size(&self) -> u32 {
62                     #byte_size_impl
63                 }
64 
65                 fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
66                     #encode_impl
67                 }
68 
69                 fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
70                     #decode_impl
71                 }
72             }
73         }
74     }
75 }
76 
77 // Generate code that recursively calls byte_size on every field in the struct.
byte_size_sum(data: &Data) -> TokenStream78 fn byte_size_sum(data: &Data) -> TokenStream {
79     if let Data::Struct(ref data) = *data {
80         if let Fields::Named(ref fields) = data.fields {
81             let fields = fields.named.iter().map(|f| {
82                 let field = &f.ident;
83                 let span = field.span();
84                 quote_spanned! {span=>
85                     WireFormat::byte_size(&self.#field)
86                 }
87             });
88 
89             quote! {
90                 0 #(+ #fields)*
91             }
92         } else {
93             unimplemented!();
94         }
95     } else {
96         unimplemented!();
97     }
98 }
99 
100 // Generate code that recursively calls encode on every field in the struct.
encode_wire_format(data: &Data) -> TokenStream101 fn encode_wire_format(data: &Data) -> TokenStream {
102     if let Data::Struct(ref data) = *data {
103         if let Fields::Named(ref fields) = data.fields {
104             let fields = fields.named.iter().map(|f| {
105                 let field = &f.ident;
106                 let span = field.span();
107                 quote_spanned! {span=>
108                     WireFormat::encode(&self.#field, _writer)?;
109                 }
110             });
111 
112             quote! {
113                 #(#fields)*
114 
115                 Ok(())
116             }
117         } else {
118             unimplemented!();
119         }
120     } else {
121         unimplemented!();
122     }
123 }
124 
125 // Generate code that recursively calls decode on every field in the struct.
decode_wire_format(data: &Data, container: &Ident) -> TokenStream126 fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
127     if let Data::Struct(ref data) = *data {
128         if let Fields::Named(ref fields) = data.fields {
129             let values = fields.named.iter().map(|f| {
130                 let field = &f.ident;
131                 let span = field.span();
132                 quote_spanned! {span=>
133                     let #field = WireFormat::decode(_reader)?;
134                 }
135             });
136 
137             let members = fields.named.iter().map(|f| {
138                 let field = &f.ident;
139                 quote! {
140                     #field: #field,
141                 }
142             });
143 
144             quote! {
145                 #(#values)*
146 
147                 Ok(#container {
148                     #(#members)*
149                 })
150             }
151         } else {
152             unimplemented!();
153         }
154     } else {
155         unimplemented!();
156     }
157 }
158 
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture shared by the per-helper tests: three named fields, one of which
    // contains underscores.
    fn item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    // A braced struct with zero fields. It is still `Fields::Named`, so the derive
    // must accept it.
    fn empty() -> DeriveInput {
        parse_quote! {
            struct Empty {}
        }
    }

    // Asserts that the derive refuses `input` with a compile error and emits no impl.
    fn assert_rejected(input: DeriveInput) {
        let output = p9_wire_format_inner(input).to_string();
        assert!(output.contains("compile_error"), "{}", output);
        assert!(!output.contains("impl"), "{}", output);
    }

    #[test]
    fn byte_size() {
        let expected = quote! {
            0
                + WireFormat::byte_size(&self.ident)
                + WireFormat::byte_size(&self.with_underscores)
                + WireFormat::byte_size(&self.other)
        };

        assert_eq!(byte_size_sum(&item().data).to_string(), expected.to_string());
    }

    #[test]
    fn encode() {
        let expected = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        assert_eq!(
            encode_wire_format(&item().data).to_string(),
            expected.to_string(),
        );
    }

    #[test]
    fn decode() {
        let container = Ident::new("Item", Span::call_site());
        let expected = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        assert_eq!(
            decode_wire_format(&item().data, &container).to_string(),
            expected.to_string(),
        );
    }

    #[test]
    fn empty_struct() {
        let input = empty();
        let container = Ident::new("Empty", Span::call_site());

        assert_eq!(byte_size_sum(&input.data).to_string(), quote!(0).to_string());
        assert_eq!(
            encode_wire_format(&input.data).to_string(),
            quote!(Ok(())).to_string(),
        );
        assert_eq!(
            decode_wire_format(&input.data, &container).to_string(),
            quote!(Ok(Empty {})).to_string(),
        );
    }

    #[test]
    fn rejects_generic_parameters() {
        let with_type: DeriveInput = parse_quote! {
            struct Wrapper<T> {
                inner: T,
            }
        };
        let with_lifetime: DeriveInput = parse_quote! {
            struct Borrowed<'a> {
                name: &'a str,
            }
        };

        assert_rejected(with_type);
        assert_rejected(with_lifetime);
    }

    #[test]
    fn end_to_end() {
        let input: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let expected = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }
                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        assert_eq!(
            p9_wire_format_inner(input).to_string(),
            expected.to_string(),
        );
    }
}
308