// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 use crate::backends::rust_legacy::{
16     constraint_to_value, find_constrained_parent_fields, mask_bits, types, ToIdent,
17     ToUpperCamelCase,
18 };
19 use crate::{analyzer, ast};
20 use quote::{format_ident, quote};
21 use std::collections::{BTreeSet, HashMap};
22 
size_field_ident(id: &str) -> proc_macro2::Ident23 fn size_field_ident(id: &str) -> proc_macro2::Ident {
24     format_ident!("{}_size", id.trim_matches('_'))
25 }
26 
27 /// A single bit-field.
28 struct BitField<'a> {
29     shift: usize, // The shift to apply to this field.
30     field: &'a ast::Field,
31 }
32 
33 pub struct FieldParser<'a> {
34     scope: &'a analyzer::Scope<'a>,
35     schema: &'a analyzer::Schema,
36     endianness: ast::EndiannessValue,
37     decl: &'a ast::Decl,
38     packet_name: &'a str,
39     span: &'a proc_macro2::Ident,
40     chunk: Vec<BitField<'a>>,
41     code: Vec<proc_macro2::TokenStream>,
42     shift: usize,
43     offset: usize,
44 }
45 
46 impl<'a> FieldParser<'a> {
new( scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, ) -> FieldParser<'a>47     pub fn new(
48         scope: &'a analyzer::Scope<'a>,
49         schema: &'a analyzer::Schema,
50         endianness: ast::EndiannessValue,
51         packet_name: &'a str,
52         span: &'a proc_macro2::Ident,
53     ) -> FieldParser<'a> {
54         FieldParser {
55             scope,
56             schema,
57             endianness,
58             decl: scope.typedef[packet_name],
59             packet_name,
60             span,
61             chunk: Vec::new(),
62             code: Vec::new(),
63             shift: 0,
64             offset: 0,
65         }
66     }
67 
add(&mut self, field: &'a ast::Field)68     pub fn add(&mut self, field: &'a ast::Field) {
69         match &field.desc {
70             _ if field.cond.is_some() => self.add_optional_field(field),
71             _ if self.scope.is_bitfield(field) => self.add_bit_field(field),
72             ast::FieldDesc::Padding { .. } => (),
73             ast::FieldDesc::Array { id, width, type_id, size, .. } => self.add_array_field(
74                 id,
75                 *width,
76                 type_id.as_deref(),
77                 *size,
78                 self.schema.padded_size(field.key),
79                 self.scope.get_type_declaration(field),
80             ),
81             ast::FieldDesc::Typedef { id, type_id } => self.add_typedef_field(id, type_id),
82             ast::FieldDesc::Payload { size_modifier, .. } => {
83                 self.add_payload_field(size_modifier.as_deref())
84             }
85             ast::FieldDesc::Body { .. } => self.add_payload_field(None),
86             _ => todo!("{field:?}"),
87         }
88     }
89 
add_optional_field(&mut self, field: &'a ast::Field)90     fn add_optional_field(&mut self, field: &'a ast::Field) {
91         let cond_id = field.cond.as_ref().unwrap().id.to_ident();
92         let cond_value = syn::parse_str::<syn::LitInt>(&format!(
93             "{}",
94             field.cond.as_ref().unwrap().value.unwrap()
95         ))
96         .unwrap();
97 
98         self.code.push(match &field.desc {
99             ast::FieldDesc::Scalar { id, width } => {
100                 let id = id.to_ident();
101                 let value = types::get_uint(self.endianness, *width, self.span);
102                 quote! {
103                     let #id = (#cond_id == #cond_value).then(|| #value);
104                 }
105             }
106             ast::FieldDesc::Typedef { id, type_id } => match &self.scope.typedef[type_id].desc {
107                 ast::DeclDesc::Enum { width, .. } => {
108                     let name = id;
109                     let type_name = type_id;
110                     let id = id.to_ident();
111                     let type_id = type_id.to_ident();
112                     let decl_id = &self.packet_name;
113                     let value = types::get_uint(self.endianness, *width, self.span);
114                     quote! {
115                         let #id = (#cond_id == #cond_value)
116                             .then(||
117                                 #type_id::try_from(#value).map_err(|unknown_val| {
118                                     DecodeError::InvalidEnumValueError {
119                                         obj: #decl_id,
120                                         field: #name,
121                                         value: unknown_val as u64,
122                                         type_: #type_name,
123                                     }
124                                 }))
125                             .transpose()?;
126                     }
127                 }
128                 ast::DeclDesc::Struct { .. } => {
129                     let id = id.to_ident();
130                     let type_id = type_id.to_ident();
131                     let span = self.span;
132                     quote! {
133                         let #id = (#cond_id == #cond_value)
134                             .then(|| #type_id::parse_inner(&mut #span))
135                             .transpose()?;
136                     }
137                 }
138                 _ => unreachable!(),
139             },
140             _ => unreachable!(),
141         })
142     }
143 
add_bit_field(&mut self, field: &'a ast::Field)144     fn add_bit_field(&mut self, field: &'a ast::Field) {
145         self.chunk.push(BitField { shift: self.shift, field });
146         self.shift += self.schema.field_size(field.key).static_().unwrap();
147         if self.shift % 8 != 0 {
148             return;
149         }
150 
151         let size = self.shift / 8;
152         let end_offset = self.offset + size;
153 
154         let wanted = proc_macro2::Literal::usize_unsuffixed(size);
155         self.check_size(self.span, &quote!(#wanted));
156 
157         let chunk_type = types::Integer::new(self.shift);
158         // TODO(mgeisler): generate Rust variable names which cannot
159         // conflict with PDL field names. An option would be to start
160         // Rust variable names with `_`, but that has a special
161         // semantic in Rust.
162         let chunk_name = format_ident!("chunk");
163 
164         let get = types::get_uint(self.endianness, self.shift, self.span);
165         if self.chunk.len() > 1 {
166             // Multiple values: we read into a local variable.
167             self.code.push(quote! {
168                 let #chunk_name = #get;
169             });
170         }
171 
172         let single_value = self.chunk.len() == 1; // && self.chunk[0].offset == 0;
173         for BitField { shift, field } in self.chunk.drain(..) {
174             let mut v = if single_value {
175                 // Single value: read directly.
176                 quote! { #get }
177             } else {
178                 // Multiple values: read from `chunk_name`.
179                 quote! { #chunk_name }
180             };
181 
182             if shift > 0 {
183                 let shift = proc_macro2::Literal::usize_unsuffixed(shift);
184                 v = quote! { (#v >> #shift) }
185             }
186 
187             let width = self.schema.field_size(field.key).static_().unwrap();
188             let value_type = types::Integer::new(width);
189             if !single_value && width < value_type.width {
190                 // Mask value if we grabbed more than `width` and if
191                 // `as #value_type` doesn't already do the masking.
192                 let mask = mask_bits(width, "u64");
193                 v = quote! { (#v & #mask) };
194             }
195 
196             if value_type.width < chunk_type.width {
197                 v = quote! { #v as #value_type };
198             }
199 
200             self.code.push(match &field.desc {
201                 ast::FieldDesc::Scalar { id, .. }
202                 | ast::FieldDesc::Flag { id, .. } => {
203                     let id = id.to_ident();
204                     quote! {
205                         let #id = #v;
206                     }
207                 }
208                 ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => {
209                     let enum_id = enum_id.to_ident();
210                     let tag_id = tag_id.to_upper_camel_case().to_ident();
211                     quote! {
212                         let fixed_value = #v;
213                         if fixed_value != #value_type::from(#enum_id::#tag_id)  {
214                             return Err(DecodeError::InvalidFixedValue {
215                                 expected: #value_type::from(#enum_id::#tag_id) as u64,
216                                 actual: fixed_value as u64,
217                             });
218                         }
219                     }
220                 }
221                 ast::FieldDesc::FixedScalar { value, .. } => {
222                     let value = proc_macro2::Literal::usize_unsuffixed(*value);
223                     quote! {
224                         let fixed_value = #v;
225                         if fixed_value != #value {
226                             return Err(DecodeError::InvalidFixedValue {
227                                 expected: #value,
228                                 actual: fixed_value as u64,
229                             });
230                         }
231                     }
232                 }
233                 ast::FieldDesc::Typedef { id, type_id } => {
234                     let field_name = id;
235                     let type_name = type_id;
236                     let packet_name = &self.packet_name;
237                     let id = id.to_ident();
238                     let type_id = type_id.to_ident();
239                     quote! {
240                         let #id = #type_id::try_from(#v).map_err(|unknown_val| DecodeError::InvalidEnumValueError {
241                             obj: #packet_name,
242                             field: #field_name,
243                             value: unknown_val as u64,
244                             type_: #type_name,
245                         })?;
246                     }
247                 }
248                 ast::FieldDesc::Reserved { .. } => {
249                     if single_value {
250                         let span = self.span;
251                         let size = proc_macro2::Literal::usize_unsuffixed(size);
252                         quote! {
253                             #span.get_mut().advance(#size);
254                         }
255                     } else {
256                         //  Otherwise we don't need anything: we will
257                         //  have advanced past the reserved field when
258                         //  reading the chunk above.
259                         quote! {}
260                     }
261                 }
262                 ast::FieldDesc::Size { field_id, .. } => {
263                     let id = size_field_ident(field_id);
264                     quote! {
265                         let #id = #v as usize;
266                     }
267                 }
268                 ast::FieldDesc::Count { field_id, .. } => {
269                     let id = format_ident!("{field_id}_count");
270                     quote! {
271                         let #id = #v as usize;
272                     }
273                 }
274                 _ => todo!(),
275             });
276         }
277 
278         self.offset = end_offset;
279         self.shift = 0;
280     }
281 
find_count_field(&self, id: &str) -> Option<proc_macro2::Ident>282     fn find_count_field(&self, id: &str) -> Option<proc_macro2::Ident> {
283         match self.decl.array_size(id)?.desc {
284             ast::FieldDesc::Count { .. } => Some(format_ident!("{id}_count")),
285             _ => None,
286         }
287     }
288 
find_size_field(&self, id: &str) -> Option<proc_macro2::Ident>289     fn find_size_field(&self, id: &str) -> Option<proc_macro2::Ident> {
290         match self.decl.array_size(id)?.desc {
291             ast::FieldDesc::Size { .. } => Some(size_field_ident(id)),
292             _ => None,
293         }
294     }
295 
payload_field_offset_from_end(&self) -> Option<usize>296     fn payload_field_offset_from_end(&self) -> Option<usize> {
297         let decl = self.scope.typedef[self.packet_name];
298         let mut fields = decl.fields();
299         fields.find(|f| {
300             matches!(f.desc, ast::FieldDesc::Body { .. } | ast::FieldDesc::Payload { .. })
301         })?;
302 
303         let mut offset = 0;
304         for field in fields {
305             if let Some(width) =
306                 self.schema.padded_size(field.key).or(self.schema.field_size(field.key).static_())
307             {
308                 offset += width;
309             } else {
310                 return None;
311             }
312         }
313 
314         Some(offset)
315     }
316 
check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream)317     fn check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream) {
318         let packet_name = &self.packet_name;
319         self.code.push(quote! {
320             if #span.get().remaining() < #wanted {
321                 return Err(DecodeError::InvalidLengthError {
322                     obj: #packet_name,
323                     wanted: #wanted,
324                     got: #span.get().remaining(),
325                 });
326             }
327         });
328     }
329 
add_array_field( &mut self, id: &str, width: Option<usize>, type_id: Option<&str>, size: Option<usize>, padding_size: Option<usize>, decl: Option<&ast::Decl>, )330     fn add_array_field(
331         &mut self,
332         id: &str,
333         // `width`: the width in bits of the array elements (if Some).
334         width: Option<usize>,
335         // `type_id`: the enum type of the array elements (if Some).
336         // Mutually exclusive with `width`.
337         type_id: Option<&str>,
338         // `size`: the size of the array in number of elements (if
339         // known). If None, the array is a Vec with a dynamic size.
340         size: Option<usize>,
341         padding_size: Option<usize>,
342         decl: Option<&ast::Decl>,
343     ) {
344         enum ElementWidth {
345             Static(usize), // Static size in bytes.
346             Unknown,
347         }
348         let element_width =
349             match width.or_else(|| self.schema.total_size(decl.unwrap().key).static_()) {
350                 Some(w) => {
351                     assert_eq!(w % 8, 0, "Array element size ({w}) is not a multiple of 8");
352                     ElementWidth::Static(w / 8)
353                 }
354                 None => ElementWidth::Unknown,
355             };
356 
357         // The "shape" of the array, i.e., the number of elements
358         // given via a static count, a count field, a size field, or
359         // unknown.
360         enum ArrayShape {
361             Static(usize),                  // Static count
362             CountField(proc_macro2::Ident), // Count based on count field
363             SizeField(proc_macro2::Ident),  // Count based on size and field
364             Unknown,                        // Variable count based on remaining bytes
365         }
366         let array_shape = if let Some(count) = size {
367             ArrayShape::Static(count)
368         } else if let Some(count_field) = self.find_count_field(id) {
369             ArrayShape::CountField(count_field)
370         } else if let Some(size_field) = self.find_size_field(id) {
371             ArrayShape::SizeField(size_field)
372         } else {
373             ArrayShape::Unknown
374         };
375 
376         // TODO size modifier
377 
378         let span = match padding_size {
379             Some(padding_size) => {
380                 let span = self.span;
381                 let padding_octets = padding_size / 8;
382                 self.check_size(span, &quote!(#padding_octets));
383                 self.code.push(quote! {
384                     let (head, tail) = #span.get().split_at(#padding_octets);
385                     let mut head = &mut Cell::new(head);
386                     #span.replace(tail);
387                 });
388                 format_ident!("head")
389             }
390             None => self.span.clone(),
391         };
392 
393         let id = id.to_ident();
394 
395         let parse_element = self.parse_array_element(&span, width, type_id, decl);
396         match (element_width, &array_shape) {
397             (ElementWidth::Unknown, ArrayShape::SizeField(size_field)) => {
398                 // The element width is not known, but the array full
399                 // octet size is known by size field. Parse elements
400                 // item by item as a vector.
401                 self.check_size(&span, &quote!(#size_field));
402                 let parse_element =
403                     self.parse_array_element(&format_ident!("head"), width, type_id, decl);
404                 self.code.push(quote! {
405                     let (head, tail) = #span.get().split_at(#size_field);
406                     let mut head = &mut Cell::new(head);
407                     #span.replace(tail);
408                     let mut #id = Vec::new();
409                     while !head.get().is_empty() {
410                         #id.push(#parse_element?);
411                     }
412                 });
413             }
414             (ElementWidth::Unknown, ArrayShape::Static(count)) => {
415                 // The element width is not known, but the array
416                 // element count is known statically. Parse elements
417                 // item by item as an array.
418                 let count = syn::Index::from(*count);
419                 self.code.push(quote! {
420                     // TODO(mgeisler): use
421                     // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
422                     // when stabilized.
423                     let #id = (0..#count)
424                         .map(|_| #parse_element)
425                         .collect::<Result<Vec<_>, DecodeError>>()?
426                         .try_into()
427                         .map_err(|_| DecodeError::InvalidPacketError)?;
428                 });
429             }
430             (ElementWidth::Unknown, ArrayShape::CountField(count_field)) => {
431                 // The element width is not known, but the array
432                 // element count is known by the count field. Parse
433                 // elements item by item as a vector.
434                 self.code.push(quote! {
435                     let #id = (0..#count_field)
436                         .map(|_| #parse_element)
437                         .collect::<Result<Vec<_>, DecodeError>>()?;
438                 });
439             }
440             (ElementWidth::Unknown, ArrayShape::Unknown) => {
441                 // Neither the count not size is known, parse elements
442                 // until the end of the span.
443                 self.code.push(quote! {
444                     let mut #id = Vec::new();
445                     while !#span.get().is_empty() {
446                         #id.push(#parse_element?);
447                     }
448                 });
449             }
450             (ElementWidth::Static(element_width), ArrayShape::Static(count)) => {
451                 // The element width is known, and the array element
452                 // count is known statically.
453                 let count = syn::Index::from(*count);
454                 // This creates a nicely formatted size.
455                 let array_size = if element_width == 1 {
456                     quote!(#count)
457                 } else {
458                     let element_width = syn::Index::from(element_width);
459                     quote!(#count * #element_width)
460                 };
461                 self.check_size(&span, &quote! { #array_size });
462                 self.code.push(quote! {
463                     // TODO(mgeisler): use
464                     // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
465                     // when stabilized.
466                     let #id = (0..#count)
467                         .map(|_| #parse_element)
468                         .collect::<Result<Vec<_>, DecodeError>>()?
469                         .try_into()
470                         .map_err(|_| DecodeError::InvalidPacketError)?;
471                 });
472             }
473             (ElementWidth::Static(element_width), ArrayShape::CountField(count_field)) => {
474                 // The element width is known, and the array element
475                 // count is known dynamically by the count field.
476                 self.check_size(&span, &quote!(#count_field * #element_width));
477                 self.code.push(quote! {
478                     let #id = (0..#count_field)
479                         .map(|_| #parse_element)
480                         .collect::<Result<Vec<_>, DecodeError>>()?;
481                 });
482             }
483             (ElementWidth::Static(element_width), ArrayShape::SizeField(_))
484             | (ElementWidth::Static(element_width), ArrayShape::Unknown) => {
485                 // The element width is known, and the array full size
486                 // is known by size field, or unknown (in which case
487                 // it is the remaining span length).
488                 let array_size = if let ArrayShape::SizeField(size_field) = &array_shape {
489                     self.check_size(&span, &quote!(#size_field));
490                     quote!(#size_field)
491                 } else {
492                     quote!(#span.get().remaining())
493                 };
494                 let count_field = format_ident!("{id}_count");
495                 let array_count = if element_width != 1 {
496                     let element_width = syn::Index::from(element_width);
497                     self.code.push(quote! {
498                         if #array_size % #element_width != 0 {
499                             return Err(DecodeError::InvalidArraySize {
500                                 array: #array_size,
501                                 element: #element_width,
502                             });
503                         }
504                         let #count_field = #array_size / #element_width;
505                     });
506                     quote!(#count_field)
507                 } else {
508                     array_size
509                 };
510 
511                 self.code.push(quote! {
512                     let mut #id = Vec::with_capacity(#array_count);
513                     for _ in 0..#array_count {
514                         #id.push(#parse_element?);
515                     }
516                 });
517             }
518         }
519     }
520 
521     /// Parse typedef fields.
522     ///
523     /// This is only for non-enum fields: enums are parsed via
524     /// add_bit_field.
add_typedef_field(&mut self, id: &str, type_id: &str)525     fn add_typedef_field(&mut self, id: &str, type_id: &str) {
526         assert_eq!(self.shift, 0, "Typedef field does not start on an octet boundary");
527 
528         let decl = self.scope.typedef[type_id];
529         if let ast::DeclDesc::Struct { parent_id: Some(_), .. } = &decl.desc {
530             panic!("Derived struct used in typedef field");
531         }
532 
533         let span = self.span;
534         let id = id.to_ident();
535         let type_id = type_id.to_ident();
536 
537         self.code.push(match self.schema.decl_size(decl.key) {
538             analyzer::Size::Unknown | analyzer::Size::Dynamic => quote! {
539                 let #id = #type_id::parse_inner(&mut #span)?;
540             },
541             analyzer::Size::Static(width) => {
542                 assert_eq!(width % 8, 0, "Typedef field type size is not a multiple of 8");
543                 match &decl.desc {
544                     ast::DeclDesc::Checksum { .. } => todo!(),
545                     ast::DeclDesc::CustomField { .. } if [8, 16, 32, 64].contains(&width) => {
546                         let get_uint = types::get_uint(self.endianness, width, span);
547                         quote! {
548                             let #id = #get_uint.into();
549                         }
550                     }
551                     ast::DeclDesc::CustomField { .. } => {
552                         let get_uint = types::get_uint(self.endianness, width, span);
553                         quote! {
554                             let #id = (#get_uint)
555                                 .try_into()
556                                 .unwrap(); // Value is masked and conversion must succeed.
557                         }
558                     }
559                     ast::DeclDesc::Struct { .. } => {
560                         let width = syn::Index::from(width / 8);
561                         quote! {
562                             let (head, tail) = #span.get().split_at(#width);
563                             #span.replace(tail);
564                             let #id = #type_id::parse(head)?;
565                         }
566                     }
567                     _ => unreachable!(),
568                 }
569             }
570         });
571     }
572 
573     /// Parse body and payload fields.
add_payload_field(&mut self, size_modifier: Option<&str>)574     fn add_payload_field(&mut self, size_modifier: Option<&str>) {
575         let span = self.span;
576         let payload_size_field = self.decl.payload_size();
577         let offset_from_end = self.payload_field_offset_from_end();
578 
579         if self.shift != 0 {
580             todo!("Unexpected non byte aligned payload");
581         }
582 
583         if let Some(ast::FieldDesc::Size { field_id, .. }) = &payload_size_field.map(|f| &f.desc) {
584             // The payload or body has a known size. Consume the
585             // payload and update the span in case fields are placed
586             // after the payload.
587             let size_field = size_field_ident(field_id);
588             if let Some(size_modifier) = size_modifier {
589                 let size_modifier = proc_macro2::Literal::usize_unsuffixed(
590                     size_modifier.parse::<usize>().expect("failed to parse the size modifier"),
591                 );
592                 let packet_name = &self.packet_name;
593                 // Push code to check that the size is greater than the size
594                 // modifier. Required to safely substract the modifier from the
595                 // size.
596                 self.code.push(quote! {
597                     if #size_field < #size_modifier {
598                         return Err(DecodeError::InvalidLengthError {
599                             obj: #packet_name,
600                             wanted: #size_modifier,
601                             got: #size_field,
602                         });
603                     }
604                     let #size_field = #size_field - #size_modifier;
605                 });
606             }
607             self.check_size(self.span, &quote!(#size_field ));
608             self.code.push(quote! {
609                 let payload = &#span.get()[..#size_field];
610                 #span.get_mut().advance(#size_field);
611             });
612         } else if offset_from_end == Some(0) {
613             // The payload or body is the last field of a packet,
614             // consume the remaining span.
615             self.code.push(quote! {
616                 let payload = #span.get();
617                 #span.get_mut().advance(payload.len());
618             });
619         } else if let Some(offset_from_end) = offset_from_end {
620             // The payload or body is followed by fields of static
621             // size. Consume the span that is not reserved for the
622             // following fields.
623             assert_eq!(
624                 offset_from_end % 8,
625                 0,
626                 "Payload field offset from end of packet is not a multiple of 8"
627             );
628             let offset_from_end = syn::Index::from(offset_from_end / 8);
629             self.check_size(self.span, &quote!(#offset_from_end));
630             self.code.push(quote! {
631                 let payload = &#span.get()[..#span.get().len() - #offset_from_end];
632                 #span.get_mut().advance(payload.len());
633             });
634         }
635 
636         let decl = self.scope.typedef[self.packet_name];
637         if let ast::DeclDesc::Struct { .. } = &decl.desc {
638             self.code.push(quote! {
639                 let payload = Vec::from(payload);
640             });
641         }
642     }
643 
644     /// Parse a single array field element from `span`.
parse_array_element( &self, span: &proc_macro2::Ident, width: Option<usize>, type_id: Option<&str>, decl: Option<&ast::Decl>, ) -> proc_macro2::TokenStream645     fn parse_array_element(
646         &self,
647         span: &proc_macro2::Ident,
648         width: Option<usize>,
649         type_id: Option<&str>,
650         decl: Option<&ast::Decl>,
651     ) -> proc_macro2::TokenStream {
652         if let Some(width) = width {
653             let get_uint = types::get_uint(self.endianness, width, span);
654             return quote! {
655                 Ok::<_, DecodeError>(#get_uint)
656             };
657         }
658 
659         if let Some(ast::DeclDesc::Enum { id, width, .. }) = decl.map(|decl| &decl.desc) {
660             let get_uint = types::get_uint(self.endianness, *width, span);
661             let type_id = id.to_ident();
662             let packet_name = &self.packet_name;
663             return quote! {
664                 #type_id::try_from(#get_uint).map_err(|unknown_val| DecodeError::InvalidEnumValueError {
665                     obj: #packet_name,
666                     field: "", // TODO(mgeisler): fill out or remove
667                     value: unknown_val as u64,
668                     type_: #id,
669                 })
670             };
671         }
672 
673         let type_id = type_id.unwrap().to_ident();
674         quote! {
675             #type_id::parse_inner(#span)
676         }
677     }
678 
done(&mut self)679     pub fn done(&mut self) {
680         let decl = self.scope.typedef[self.packet_name];
681         if let ast::DeclDesc::Struct { .. } = &decl.desc {
682             return; // Structs don't parse the child structs recursively.
683         }
684 
685         let children = self.scope.iter_children(decl).collect::<Vec<_>>();
686         if children.is_empty() && self.decl.payload().is_none() {
687             return;
688         }
689 
690         let all_fields = HashMap::<String, _>::from_iter(
691             self.scope.iter_fields(decl).filter_map(|f| f.id().map(|id| (id.to_string(), f))),
692         );
693 
694         // Gather fields that are constrained in immediate child declarations.
695         // Keep the fields sorted by name.
696         // TODO: fields that are only matched in grand children will not be included.
697         let constrained_fields = children
698             .iter()
699             .flat_map(|child| child.constraints().map(|c| &c.id))
700             .collect::<BTreeSet<_>>();
701 
702         let mut match_values = Vec::new();
703         let mut child_parse_args = Vec::new();
704         let mut child_ids_data = Vec::new();
705         let mut child_ids = Vec::new();
706 
707         let get_constraint_value = |mut constraints: std::slice::Iter<'_, ast::Constraint>,
708                                     id: &str|
709          -> Option<proc_macro2::TokenStream> {
710             constraints.find(|c| c.id == id).map(|c| constraint_to_value(&all_fields, c))
711         };
712 
713         for child in children.iter() {
714             let tuple_values = constrained_fields
715                 .iter()
716                 .map(|id| {
717                     get_constraint_value(child.constraints(), id).map(|v| vec![v]).unwrap_or_else(
718                         || {
719                             self.scope
720                                 .file
721                                 .iter_children(child)
722                                 .filter_map(|d| get_constraint_value(d.constraints(), id))
723                                 .collect()
724                         },
725                     )
726                 })
727                 .collect::<Vec<_>>();
728 
729             // If no constraint values are found for the tuple just skip the child
730             // packet as it would capture unwanted input packets.
731             if tuple_values.iter().all(|v| v.is_empty()) {
732                 continue;
733             }
734 
735             let tuple_values = tuple_values
736                 .iter()
737                 .map(|v| v.is_empty().then_some(quote!(_)).unwrap_or_else(|| quote!( #(#v)|* )))
738                 .collect::<Vec<_>>();
739 
740             let fields = find_constrained_parent_fields(self.scope, child.id().unwrap())
741                 .iter()
742                 .map(|field| field.id().unwrap().to_ident())
743                 .collect::<Vec<_>>();
744 
745             match_values.push(quote!( (#(#tuple_values),*) ));
746             child_parse_args.push(quote!( #(, #fields)*));
747             child_ids_data.push(format_ident!("{}Data", child.id().unwrap()));
748             child_ids.push(child.id().unwrap().to_ident());
749         }
750 
751         let constrained_field_idents = constrained_fields.iter().map(|field| field.to_ident());
752         let packet_data_child = format_ident!("{}DataChild", self.packet_name);
753 
754         // Parsing of packet children requires having a payload field;
755         // it is allowed to inherit from a packet with empty payload, in this
756         // case generate an empty payload value.
757         if !decl
758             .fields()
759             .any(|f| matches!(&f.desc, ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body))
760         {
761             self.code.push(quote! {
762                 let payload: &[u8] = &[];
763             })
764         }
765         self.code.push(quote! {
766             let child = match (#(#constrained_field_idents),*) {
767                 #(#match_values if #child_ids_data::conforms(&payload) => {
768                     let mut cell = Cell::new(payload);
769                     let child_data = #child_ids_data::parse_inner(&mut cell #child_parse_args)?;
770                     // TODO(mgeisler): communicate back to user if !cell.get().is_empty()?
771                     #packet_data_child::#child_ids(child_data)
772                 }),*
773                 _ if !payload.is_empty() => {
774                     #packet_data_child::Payload(Bytes::copy_from_slice(payload))
775                 }
776                 _ => #packet_data_child::None,
777             };
778         });
779     }
780 }
781 
782 impl quote::ToTokens for FieldParser<'_> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)783     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
784         let code = &self.code;
785         tokens.extend(quote! {
786             #(#code)*
787         });
788     }
789 }
790 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer;
    use crate::ast;
    use crate::parser::parse_inline;

    /// Parses and analyzes a PDL source fragment.
    ///
    /// # Panics
    ///
    /// Panics if the text fails to parse or if the analyzer rejects it.
    pub fn parse_str(text: &str) -> ast::File {
        let mut sources = ast::SourceDatabase::new();
        let parsed = parse_inline(&mut sources, "stdin", text.to_owned()).expect("parse error");
        analyzer::analyze(&parsed).expect("analyzer error")
    }

    /// Builds a `FieldParser` for the packet named `P` declared in `code`
    /// and passes it to `check`.
    fn check_packet_p(code: &str, check: impl FnOnce(&FieldParser<'_>)) {
        let file = parse_str(code);
        let scope = analyzer::Scope::new(&file).unwrap();
        let schema = analyzer::Schema::new(&file);
        let span = format_ident!("bytes");
        let parser = FieldParser::new(&scope, &schema, file.endianness.value, "P", &span);
        check(&parser);
    }

    /// An array with a static element count has neither a size nor a count field.
    #[test]
    fn test_find_fields_static() {
        let code = "
              little_endian_packets
              packet P {
                a: 24[3],
              }
            ";
        check_packet_p(code, |parser| {
            assert_eq!(parser.find_size_field("a"), None);
            assert_eq!(parser.find_count_field("a"), None);
        });
    }

    /// A `_count_` field is reported as the count field of its array.
    #[test]
    fn test_find_fields_dynamic_count() {
        let code = "
              little_endian_packets
              packet P {
                _count_(b): 24,
                b: 16[],
              }
            ";
        check_packet_p(code, |parser| {
            assert_eq!(parser.find_size_field("b"), None);
            assert_eq!(parser.find_count_field("b"), Some(format_ident!("b_count")));
        });
    }

    /// A `_size_` field is reported as the size field of its array.
    #[test]
    fn test_find_fields_dynamic_size() {
        let code = "
              little_endian_packets
              packet P {
                _size_(c): 8,
                c: 24[],
              }
            ";
        check_packet_p(code, |parser| {
            assert_eq!(parser.find_size_field("c"), Some(format_ident!("c_size")));
            assert_eq!(parser.find_count_field("c"), None);
        });
    }
}
862