• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Derives a 9P wire format encoding for a struct by recursively calling
6 //! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
7 //! This is only intended to be used from within the `p9` crate.
8 
9 #![recursion_limit = "256"]
10 
11 extern crate proc_macro;
12 extern crate proc_macro2;
13 
14 #[macro_use]
15 extern crate quote;
16 
17 #[macro_use]
18 extern crate syn;
19 
20 use proc_macro2::{Span, TokenStream};
21 use syn::spanned::Spanned;
22 use syn::{Data, DeriveInput, Fields, Ident};
23 
24 /// The function that derives the actual implementation.
25 #[proc_macro_derive(P9WireFormat)]
p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream26 pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27     let input = parse_macro_input!(input as DeriveInput);
28     p9_wire_format_inner(input).into()
29 }
30 
p9_wire_format_inner(input: DeriveInput) -> TokenStream31 fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
32     if !input.generics.params.is_empty() {
33         return quote! {
34             compile_error!("derive(P9WireFormat) does not support generic parameters");
35         };
36     }
37 
38     let container = input.ident;
39 
40     let byte_size_impl = byte_size_sum(&input.data);
41     let encode_impl = encode_wire_format(&input.data);
42     let decode_impl = decode_wire_format(&input.data, &container);
43 
44     let scope = format!("wire_format_{}", container).to_lowercase();
45     let scope = Ident::new(&scope, Span::call_site());
46     quote! {
47         mod #scope {
48             extern crate std;
49             use self::std::io;
50             use self::std::result::Result::Ok;
51 
52             use super::#container;
53 
54             use protocol::WireFormat;
55 
56             impl WireFormat for #container {
57                 fn byte_size(&self) -> u32 {
58                     #byte_size_impl
59                 }
60 
61                 fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
62                     #encode_impl
63                 }
64 
65                 fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
66                     #decode_impl
67                 }
68             }
69         }
70     }
71 }
72 
73 // Generate code that recursively calls byte_size on every field in the struct.
byte_size_sum(data: &Data) -> TokenStream74 fn byte_size_sum(data: &Data) -> TokenStream {
75     if let Data::Struct(ref data) = *data {
76         if let Fields::Named(ref fields) = data.fields {
77             let fields = fields.named.iter().map(|f| {
78                 let field = &f.ident;
79                 let span = field.span();
80                 quote_spanned! {span=>
81                     WireFormat::byte_size(&self.#field)
82                 }
83             });
84 
85             quote! {
86                 0 #(+ #fields)*
87             }
88         } else {
89             unimplemented!();
90         }
91     } else {
92         unimplemented!();
93     }
94 }
95 
96 // Generate code that recursively calls encode on every field in the struct.
encode_wire_format(data: &Data) -> TokenStream97 fn encode_wire_format(data: &Data) -> TokenStream {
98     if let Data::Struct(ref data) = *data {
99         if let Fields::Named(ref fields) = data.fields {
100             let fields = fields.named.iter().map(|f| {
101                 let field = &f.ident;
102                 let span = field.span();
103                 quote_spanned! {span=>
104                     WireFormat::encode(&self.#field, _writer)?;
105                 }
106             });
107 
108             quote! {
109                 #(#fields)*
110 
111                 Ok(())
112             }
113         } else {
114             unimplemented!();
115         }
116     } else {
117         unimplemented!();
118     }
119 }
120 
121 // Generate code that recursively calls decode on every field in the struct.
decode_wire_format(data: &Data, container: &Ident) -> TokenStream122 fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
123     if let Data::Struct(ref data) = *data {
124         if let Fields::Named(ref fields) = data.fields {
125             let values = fields.named.iter().map(|f| {
126                 let field = &f.ident;
127                 let span = field.span();
128                 quote_spanned! {span=>
129                     let #field = WireFormat::decode(_reader)?;
130                 }
131             });
132 
133             let members = fields.named.iter().map(|f| {
134                 let field = &f.ident;
135                 quote! {
136                     #field: #field,
137                 }
138             });
139 
140             quote! {
141                 #(#values)*
142 
143                 Ok(#container {
144                     #(#members)*
145                 })
146             }
147         } else {
148             unimplemented!();
149         }
150     } else {
151         unimplemented!();
152     }
153 }
154 
#[cfg(test)]
mod tests {
    use super::*;

    // Parses the three-field `Item` struct shared by the per-helper tests.
    fn item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    // `byte_size_sum` emits `0` plus one `byte_size` term per field.
    #[test]
    fn byte_size() {
        let parsed = item();

        let wanted = quote! {
            0
                + WireFormat::byte_size(&self.ident)
                + WireFormat::byte_size(&self.with_underscores)
                + WireFormat::byte_size(&self.other)
        };

        let generated = byte_size_sum(&parsed.data);
        assert_eq!(generated.to_string(), wanted.to_string());
    }

    // `encode_wire_format` emits one `encode` call per field, then `Ok(())`.
    #[test]
    fn encode() {
        let parsed = item();

        let wanted = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        let generated = encode_wire_format(&parsed.data);
        assert_eq!(generated.to_string(), wanted.to_string());
    }

    // `decode_wire_format` emits one `decode` binding per field, then builds
    // the container from those bindings.
    #[test]
    fn decode() {
        let parsed = item();
        let name = Ident::new("Item", Span::call_site());

        let wanted = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        let generated = decode_wire_format(&parsed.data, &name);
        assert_eq!(generated.to_string(), wanted.to_string());
    }

    // Full expansion, including the lowercased wrapper module name, for a
    // struct whose identifier contains non-ASCII characters.
    #[test]
    fn end_to_end() {
        let parsed: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let wanted = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }
                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        let generated = p9_wire_format_inner(parsed);
        assert_eq!(generated.to_string(), wanted.to_string());
    }
}
304