1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
//! Derives a 9P wire format encoding for a struct by recursively calling
//! `WireFormat::byte_size`, `WireFormat::encode`, or `WireFormat::decode` on
//! the fields of the struct. This is only intended to be used from within the
//! `p9` crate.
8
9 #![recursion_limit = "256"]
10
11 extern crate proc_macro;
12 extern crate proc_macro2;
13
14 #[macro_use]
15 extern crate quote;
16
17 #[macro_use]
18 extern crate syn;
19
20 use proc_macro2::Span;
21 use proc_macro2::TokenStream;
22 use syn::spanned::Spanned;
23 use syn::Data;
24 use syn::DeriveInput;
25 use syn::Fields;
26 use syn::Ident;
27
28 /// The function that derives the actual implementation.
29 #[proc_macro_derive(P9WireFormat)]
p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream30 pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 p9_wire_format_inner(input).into()
33 }
34
p9_wire_format_inner(input: DeriveInput) -> TokenStream35 fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
36 if !input.generics.params.is_empty() {
37 return quote! {
38 compile_error!("derive(P9WireFormat) does not support generic parameters");
39 };
40 }
41
42 let container = input.ident;
43
44 let byte_size_impl = byte_size_sum(&input.data);
45 let encode_impl = encode_wire_format(&input.data);
46 let decode_impl = decode_wire_format(&input.data, &container);
47
48 let scope = format!("wire_format_{}", container).to_lowercase();
49 let scope = Ident::new(&scope, Span::call_site());
50 quote! {
51 mod #scope {
52 extern crate std;
53 use self::std::io;
54 use self::std::result::Result::Ok;
55
56 use super::#container;
57
58 use protocol::WireFormat;
59
60 impl WireFormat for #container {
61 fn byte_size(&self) -> u32 {
62 #byte_size_impl
63 }
64
65 fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
66 #encode_impl
67 }
68
69 fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
70 #decode_impl
71 }
72 }
73 }
74 }
75 }
76
77 // Generate code that recursively calls byte_size on every field in the struct.
byte_size_sum(data: &Data) -> TokenStream78 fn byte_size_sum(data: &Data) -> TokenStream {
79 if let Data::Struct(ref data) = *data {
80 if let Fields::Named(ref fields) = data.fields {
81 let fields = fields.named.iter().map(|f| {
82 let field = &f.ident;
83 let span = field.span();
84 quote_spanned! {span=>
85 WireFormat::byte_size(&self.#field)
86 }
87 });
88
89 quote! {
90 0 #(+ #fields)*
91 }
92 } else {
93 unimplemented!();
94 }
95 } else {
96 unimplemented!();
97 }
98 }
99
100 // Generate code that recursively calls encode on every field in the struct.
encode_wire_format(data: &Data) -> TokenStream101 fn encode_wire_format(data: &Data) -> TokenStream {
102 if let Data::Struct(ref data) = *data {
103 if let Fields::Named(ref fields) = data.fields {
104 let fields = fields.named.iter().map(|f| {
105 let field = &f.ident;
106 let span = field.span();
107 quote_spanned! {span=>
108 WireFormat::encode(&self.#field, _writer)?;
109 }
110 });
111
112 quote! {
113 #(#fields)*
114
115 Ok(())
116 }
117 } else {
118 unimplemented!();
119 }
120 } else {
121 unimplemented!();
122 }
123 }
124
125 // Generate code that recursively calls decode on every field in the struct.
decode_wire_format(data: &Data, container: &Ident) -> TokenStream126 fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
127 if let Data::Struct(ref data) = *data {
128 if let Fields::Named(ref fields) = data.fields {
129 let values = fields.named.iter().map(|f| {
130 let field = &f.ident;
131 let span = field.span();
132 quote_spanned! {span=>
133 let #field = WireFormat::decode(_reader)?;
134 }
135 });
136
137 let members = fields.named.iter().map(|f| {
138 let field = &f.ident;
139 quote! {
140 #field: #field,
141 }
142 });
143
144 quote! {
145 #(#values)*
146
147 Ok(#container {
148 #(#members)*
149 })
150 }
151 } else {
152 unimplemented!();
153 }
154 } else {
155 unimplemented!();
156 }
157 }
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared three-field fixture used by the per-method generator tests.
    fn item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    #[test]
    fn byte_size() {
        let expected = quote! {
            0
            + WireFormat::byte_size(&self.ident)
            + WireFormat::byte_size(&self.with_underscores)
            + WireFormat::byte_size(&self.other)
        };

        assert_eq!(byte_size_sum(&item().data).to_string(), expected.to_string());
    }

    #[test]
    fn encode() {
        let expected = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        assert_eq!(
            encode_wire_format(&item().data).to_string(),
            expected.to_string(),
        );
    }

    #[test]
    fn decode() {
        let input = item();

        let expected = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        assert_eq!(
            decode_wire_format(&input.data, &input.ident).to_string(),
            expected.to_string(),
        );
    }

    /// Generic containers must be rejected with a compile error rather than
    /// producing an impl that cannot name the type parameters.
    #[test]
    fn rejects_generics() {
        let input: DeriveInput = parse_quote! {
            struct Wrapper<T> {
                inner: T,
            }
        };

        let output = p9_wire_format_inner(input).to_string();
        assert!(output.contains("compile_error"));
        assert!(output.contains("does not support generic parameters"));
        assert!(!output.contains("impl WireFormat"));
    }

    #[test]
    fn end_to_end() {
        let input: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let expected = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }
                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        assert_eq!(
            p9_wire_format_inner(input).to_string(),
            expected.to_string(),
        );
    }
}
308