1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Derives a 9P wire format encoding for a struct by recursively calling
6 //! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
7 //! This is only intended to be used from within the `p9` crate.
8
9 #![recursion_limit = "256"]
10
11 extern crate proc_macro;
12 extern crate proc_macro2;
13
14 #[macro_use]
15 extern crate quote;
16
17 #[macro_use]
18 extern crate syn;
19
20 use proc_macro2::{Span, TokenStream};
21 use syn::spanned::Spanned;
22 use syn::{Data, DeriveInput, Fields, Ident};
23
24 /// The function that derives the actual implementation.
25 #[proc_macro_derive(P9WireFormat)]
p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream26 pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 p9_wire_format_inner(input).into()
29 }
30
p9_wire_format_inner(input: DeriveInput) -> TokenStream31 fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
32 if !input.generics.params.is_empty() {
33 return quote! {
34 compile_error!("derive(P9WireFormat) does not support generic parameters");
35 };
36 }
37
38 let container = input.ident;
39
40 let byte_size_impl = byte_size_sum(&input.data);
41 let encode_impl = encode_wire_format(&input.data);
42 let decode_impl = decode_wire_format(&input.data, &container);
43
44 let scope = format!("wire_format_{}", container).to_lowercase();
45 let scope = Ident::new(&scope, Span::call_site());
46 quote! {
47 mod #scope {
48 extern crate std;
49 use self::std::io;
50 use self::std::result::Result::Ok;
51
52 use super::#container;
53
54 use protocol::WireFormat;
55
56 impl WireFormat for #container {
57 fn byte_size(&self) -> u32 {
58 #byte_size_impl
59 }
60
61 fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
62 #encode_impl
63 }
64
65 fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
66 #decode_impl
67 }
68 }
69 }
70 }
71 }
72
73 // Generate code that recursively calls byte_size on every field in the struct.
byte_size_sum(data: &Data) -> TokenStream74 fn byte_size_sum(data: &Data) -> TokenStream {
75 if let Data::Struct(ref data) = *data {
76 if let Fields::Named(ref fields) = data.fields {
77 let fields = fields.named.iter().map(|f| {
78 let field = &f.ident;
79 let span = field.span();
80 quote_spanned! {span=>
81 WireFormat::byte_size(&self.#field)
82 }
83 });
84
85 quote! {
86 0 #(+ #fields)*
87 }
88 } else {
89 unimplemented!();
90 }
91 } else {
92 unimplemented!();
93 }
94 }
95
96 // Generate code that recursively calls encode on every field in the struct.
encode_wire_format(data: &Data) -> TokenStream97 fn encode_wire_format(data: &Data) -> TokenStream {
98 if let Data::Struct(ref data) = *data {
99 if let Fields::Named(ref fields) = data.fields {
100 let fields = fields.named.iter().map(|f| {
101 let field = &f.ident;
102 let span = field.span();
103 quote_spanned! {span=>
104 WireFormat::encode(&self.#field, _writer)?;
105 }
106 });
107
108 quote! {
109 #(#fields)*
110
111 Ok(())
112 }
113 } else {
114 unimplemented!();
115 }
116 } else {
117 unimplemented!();
118 }
119 }
120
121 // Generate code that recursively calls decode on every field in the struct.
decode_wire_format(data: &Data, container: &Ident) -> TokenStream122 fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
123 if let Data::Struct(ref data) = *data {
124 if let Fields::Named(ref fields) = data.fields {
125 let values = fields.named.iter().map(|f| {
126 let field = &f.ident;
127 let span = field.span();
128 quote_spanned! {span=>
129 let #field = WireFormat::decode(_reader)?;
130 }
131 });
132
133 let members = fields.named.iter().map(|f| {
134 let field = &f.ident;
135 quote! {
136 #field: #field,
137 }
138 });
139
140 quote! {
141 #(#values)*
142
143 Ok(#container {
144 #(#members)*
145 })
146 }
147 } else {
148 unimplemented!();
149 }
150 } else {
151 unimplemented!();
152 }
153 }
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-field struct shared by the per-helper tests.
    fn sample_item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    #[test]
    fn byte_size() {
        let input = sample_item();

        let expected = quote! {
            0
            + WireFormat::byte_size(&self.ident)
            + WireFormat::byte_size(&self.with_underscores)
            + WireFormat::byte_size(&self.other)
        };

        assert_eq!(byte_size_sum(&input.data).to_string(), expected.to_string());
    }

    #[test]
    fn encode() {
        let input = sample_item();

        let expected = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        assert_eq!(
            encode_wire_format(&input.data).to_string(),
            expected.to_string(),
        );
    }

    #[test]
    fn decode() {
        let input = sample_item();

        let expected = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        // The struct's own identifier (`Item`) is the container name.
        assert_eq!(
            decode_wire_format(&input.data, &input.ident).to_string(),
            expected.to_string(),
        );
    }

    #[test]
    fn end_to_end() {
        let input: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let expected = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }

                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        assert_eq!(
            p9_wire_format_inner(input).to_string(),
            expected.to_string(),
        );
    }
}
304