1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use crate::backends::rust_legacy::{
16 constraint_to_value, find_constrained_parent_fields, mask_bits, types, ToIdent,
17 ToUpperCamelCase,
18 };
19 use crate::{analyzer, ast};
20 use quote::{format_ident, quote};
21 use std::collections::{BTreeSet, HashMap};
22
size_field_ident(id: &str) -> proc_macro2::Ident23 fn size_field_ident(id: &str) -> proc_macro2::Ident {
24 format_ident!("{}_size", id.trim_matches('_'))
25 }
26
27 /// A single bit-field.
28 struct BitField<'a> {
29 shift: usize, // The shift to apply to this field.
30 field: &'a ast::Field,
31 }
32
33 pub struct FieldParser<'a> {
34 scope: &'a analyzer::Scope<'a>,
35 schema: &'a analyzer::Schema,
36 endianness: ast::EndiannessValue,
37 decl: &'a ast::Decl,
38 packet_name: &'a str,
39 span: &'a proc_macro2::Ident,
40 chunk: Vec<BitField<'a>>,
41 code: Vec<proc_macro2::TokenStream>,
42 shift: usize,
43 offset: usize,
44 }
45
46 impl<'a> FieldParser<'a> {
new( scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, ) -> FieldParser<'a>47 pub fn new(
48 scope: &'a analyzer::Scope<'a>,
49 schema: &'a analyzer::Schema,
50 endianness: ast::EndiannessValue,
51 packet_name: &'a str,
52 span: &'a proc_macro2::Ident,
53 ) -> FieldParser<'a> {
54 FieldParser {
55 scope,
56 schema,
57 endianness,
58 decl: scope.typedef[packet_name],
59 packet_name,
60 span,
61 chunk: Vec::new(),
62 code: Vec::new(),
63 shift: 0,
64 offset: 0,
65 }
66 }
67
add(&mut self, field: &'a ast::Field)68 pub fn add(&mut self, field: &'a ast::Field) {
69 match &field.desc {
70 _ if field.cond.is_some() => self.add_optional_field(field),
71 _ if self.scope.is_bitfield(field) => self.add_bit_field(field),
72 ast::FieldDesc::Padding { .. } => (),
73 ast::FieldDesc::Array { id, width, type_id, size, .. } => self.add_array_field(
74 id,
75 *width,
76 type_id.as_deref(),
77 *size,
78 self.schema.padded_size(field.key),
79 self.scope.get_type_declaration(field),
80 ),
81 ast::FieldDesc::Typedef { id, type_id } => self.add_typedef_field(id, type_id),
82 ast::FieldDesc::Payload { size_modifier, .. } => {
83 self.add_payload_field(size_modifier.as_deref())
84 }
85 ast::FieldDesc::Body { .. } => self.add_payload_field(None),
86 _ => todo!("{field:?}"),
87 }
88 }
89
add_optional_field(&mut self, field: &'a ast::Field)90 fn add_optional_field(&mut self, field: &'a ast::Field) {
91 let cond_id = field.cond.as_ref().unwrap().id.to_ident();
92 let cond_value = syn::parse_str::<syn::LitInt>(&format!(
93 "{}",
94 field.cond.as_ref().unwrap().value.unwrap()
95 ))
96 .unwrap();
97
98 self.code.push(match &field.desc {
99 ast::FieldDesc::Scalar { id, width } => {
100 let id = id.to_ident();
101 let value = types::get_uint(self.endianness, *width, self.span);
102 quote! {
103 let #id = (#cond_id == #cond_value).then(|| #value);
104 }
105 }
106 ast::FieldDesc::Typedef { id, type_id } => match &self.scope.typedef[type_id].desc {
107 ast::DeclDesc::Enum { width, .. } => {
108 let name = id;
109 let type_name = type_id;
110 let id = id.to_ident();
111 let type_id = type_id.to_ident();
112 let decl_id = &self.packet_name;
113 let value = types::get_uint(self.endianness, *width, self.span);
114 quote! {
115 let #id = (#cond_id == #cond_value)
116 .then(||
117 #type_id::try_from(#value).map_err(|unknown_val| {
118 DecodeError::InvalidEnumValueError {
119 obj: #decl_id,
120 field: #name,
121 value: unknown_val as u64,
122 type_: #type_name,
123 }
124 }))
125 .transpose()?;
126 }
127 }
128 ast::DeclDesc::Struct { .. } => {
129 let id = id.to_ident();
130 let type_id = type_id.to_ident();
131 let span = self.span;
132 quote! {
133 let #id = (#cond_id == #cond_value)
134 .then(|| #type_id::parse_inner(&mut #span))
135 .transpose()?;
136 }
137 }
138 _ => unreachable!(),
139 },
140 _ => unreachable!(),
141 })
142 }
143
add_bit_field(&mut self, field: &'a ast::Field)144 fn add_bit_field(&mut self, field: &'a ast::Field) {
145 self.chunk.push(BitField { shift: self.shift, field });
146 self.shift += self.schema.field_size(field.key).static_().unwrap();
147 if self.shift % 8 != 0 {
148 return;
149 }
150
151 let size = self.shift / 8;
152 let end_offset = self.offset + size;
153
154 let wanted = proc_macro2::Literal::usize_unsuffixed(size);
155 self.check_size(self.span, "e!(#wanted));
156
157 let chunk_type = types::Integer::new(self.shift);
158 // TODO(mgeisler): generate Rust variable names which cannot
159 // conflict with PDL field names. An option would be to start
160 // Rust variable names with `_`, but that has a special
161 // semantic in Rust.
162 let chunk_name = format_ident!("chunk");
163
164 let get = types::get_uint(self.endianness, self.shift, self.span);
165 if self.chunk.len() > 1 {
166 // Multiple values: we read into a local variable.
167 self.code.push(quote! {
168 let #chunk_name = #get;
169 });
170 }
171
172 let single_value = self.chunk.len() == 1; // && self.chunk[0].offset == 0;
173 for BitField { shift, field } in self.chunk.drain(..) {
174 let mut v = if single_value {
175 // Single value: read directly.
176 quote! { #get }
177 } else {
178 // Multiple values: read from `chunk_name`.
179 quote! { #chunk_name }
180 };
181
182 if shift > 0 {
183 let shift = proc_macro2::Literal::usize_unsuffixed(shift);
184 v = quote! { (#v >> #shift) }
185 }
186
187 let width = self.schema.field_size(field.key).static_().unwrap();
188 let value_type = types::Integer::new(width);
189 if !single_value && width < value_type.width {
190 // Mask value if we grabbed more than `width` and if
191 // `as #value_type` doesn't already do the masking.
192 let mask = mask_bits(width, "u64");
193 v = quote! { (#v & #mask) };
194 }
195
196 if value_type.width < chunk_type.width {
197 v = quote! { #v as #value_type };
198 }
199
200 self.code.push(match &field.desc {
201 ast::FieldDesc::Scalar { id, .. }
202 | ast::FieldDesc::Flag { id, .. } => {
203 let id = id.to_ident();
204 quote! {
205 let #id = #v;
206 }
207 }
208 ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => {
209 let enum_id = enum_id.to_ident();
210 let tag_id = tag_id.to_upper_camel_case().to_ident();
211 quote! {
212 let fixed_value = #v;
213 if fixed_value != #value_type::from(#enum_id::#tag_id) {
214 return Err(DecodeError::InvalidFixedValue {
215 expected: #value_type::from(#enum_id::#tag_id) as u64,
216 actual: fixed_value as u64,
217 });
218 }
219 }
220 }
221 ast::FieldDesc::FixedScalar { value, .. } => {
222 let value = proc_macro2::Literal::usize_unsuffixed(*value);
223 quote! {
224 let fixed_value = #v;
225 if fixed_value != #value {
226 return Err(DecodeError::InvalidFixedValue {
227 expected: #value,
228 actual: fixed_value as u64,
229 });
230 }
231 }
232 }
233 ast::FieldDesc::Typedef { id, type_id } => {
234 let field_name = id;
235 let type_name = type_id;
236 let packet_name = &self.packet_name;
237 let id = id.to_ident();
238 let type_id = type_id.to_ident();
239 quote! {
240 let #id = #type_id::try_from(#v).map_err(|unknown_val| DecodeError::InvalidEnumValueError {
241 obj: #packet_name,
242 field: #field_name,
243 value: unknown_val as u64,
244 type_: #type_name,
245 })?;
246 }
247 }
248 ast::FieldDesc::Reserved { .. } => {
249 if single_value {
250 let span = self.span;
251 let size = proc_macro2::Literal::usize_unsuffixed(size);
252 quote! {
253 #span.get_mut().advance(#size);
254 }
255 } else {
256 // Otherwise we don't need anything: we will
257 // have advanced past the reserved field when
258 // reading the chunk above.
259 quote! {}
260 }
261 }
262 ast::FieldDesc::Size { field_id, .. } => {
263 let id = size_field_ident(field_id);
264 quote! {
265 let #id = #v as usize;
266 }
267 }
268 ast::FieldDesc::Count { field_id, .. } => {
269 let id = format_ident!("{field_id}_count");
270 quote! {
271 let #id = #v as usize;
272 }
273 }
274 _ => todo!(),
275 });
276 }
277
278 self.offset = end_offset;
279 self.shift = 0;
280 }
281
find_count_field(&self, id: &str) -> Option<proc_macro2::Ident>282 fn find_count_field(&self, id: &str) -> Option<proc_macro2::Ident> {
283 match self.decl.array_size(id)?.desc {
284 ast::FieldDesc::Count { .. } => Some(format_ident!("{id}_count")),
285 _ => None,
286 }
287 }
288
find_size_field(&self, id: &str) -> Option<proc_macro2::Ident>289 fn find_size_field(&self, id: &str) -> Option<proc_macro2::Ident> {
290 match self.decl.array_size(id)?.desc {
291 ast::FieldDesc::Size { .. } => Some(size_field_ident(id)),
292 _ => None,
293 }
294 }
295
payload_field_offset_from_end(&self) -> Option<usize>296 fn payload_field_offset_from_end(&self) -> Option<usize> {
297 let decl = self.scope.typedef[self.packet_name];
298 let mut fields = decl.fields();
299 fields.find(|f| {
300 matches!(f.desc, ast::FieldDesc::Body { .. } | ast::FieldDesc::Payload { .. })
301 })?;
302
303 let mut offset = 0;
304 for field in fields {
305 if let Some(width) =
306 self.schema.padded_size(field.key).or(self.schema.field_size(field.key).static_())
307 {
308 offset += width;
309 } else {
310 return None;
311 }
312 }
313
314 Some(offset)
315 }
316
check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream)317 fn check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream) {
318 let packet_name = &self.packet_name;
319 self.code.push(quote! {
320 if #span.get().remaining() < #wanted {
321 return Err(DecodeError::InvalidLengthError {
322 obj: #packet_name,
323 wanted: #wanted,
324 got: #span.get().remaining(),
325 });
326 }
327 });
328 }
329
add_array_field( &mut self, id: &str, width: Option<usize>, type_id: Option<&str>, size: Option<usize>, padding_size: Option<usize>, decl: Option<&ast::Decl>, )330 fn add_array_field(
331 &mut self,
332 id: &str,
333 // `width`: the width in bits of the array elements (if Some).
334 width: Option<usize>,
335 // `type_id`: the enum type of the array elements (if Some).
336 // Mutually exclusive with `width`.
337 type_id: Option<&str>,
338 // `size`: the size of the array in number of elements (if
339 // known). If None, the array is a Vec with a dynamic size.
340 size: Option<usize>,
341 padding_size: Option<usize>,
342 decl: Option<&ast::Decl>,
343 ) {
344 enum ElementWidth {
345 Static(usize), // Static size in bytes.
346 Unknown,
347 }
348 let element_width =
349 match width.or_else(|| self.schema.total_size(decl.unwrap().key).static_()) {
350 Some(w) => {
351 assert_eq!(w % 8, 0, "Array element size ({w}) is not a multiple of 8");
352 ElementWidth::Static(w / 8)
353 }
354 None => ElementWidth::Unknown,
355 };
356
357 // The "shape" of the array, i.e., the number of elements
358 // given via a static count, a count field, a size field, or
359 // unknown.
360 enum ArrayShape {
361 Static(usize), // Static count
362 CountField(proc_macro2::Ident), // Count based on count field
363 SizeField(proc_macro2::Ident), // Count based on size and field
364 Unknown, // Variable count based on remaining bytes
365 }
366 let array_shape = if let Some(count) = size {
367 ArrayShape::Static(count)
368 } else if let Some(count_field) = self.find_count_field(id) {
369 ArrayShape::CountField(count_field)
370 } else if let Some(size_field) = self.find_size_field(id) {
371 ArrayShape::SizeField(size_field)
372 } else {
373 ArrayShape::Unknown
374 };
375
376 // TODO size modifier
377
378 let span = match padding_size {
379 Some(padding_size) => {
380 let span = self.span;
381 let padding_octets = padding_size / 8;
382 self.check_size(span, "e!(#padding_octets));
383 self.code.push(quote! {
384 let (head, tail) = #span.get().split_at(#padding_octets);
385 let mut head = &mut Cell::new(head);
386 #span.replace(tail);
387 });
388 format_ident!("head")
389 }
390 None => self.span.clone(),
391 };
392
393 let id = id.to_ident();
394
395 let parse_element = self.parse_array_element(&span, width, type_id, decl);
396 match (element_width, &array_shape) {
397 (ElementWidth::Unknown, ArrayShape::SizeField(size_field)) => {
398 // The element width is not known, but the array full
399 // octet size is known by size field. Parse elements
400 // item by item as a vector.
401 self.check_size(&span, "e!(#size_field));
402 let parse_element =
403 self.parse_array_element(&format_ident!("head"), width, type_id, decl);
404 self.code.push(quote! {
405 let (head, tail) = #span.get().split_at(#size_field);
406 let mut head = &mut Cell::new(head);
407 #span.replace(tail);
408 let mut #id = Vec::new();
409 while !head.get().is_empty() {
410 #id.push(#parse_element?);
411 }
412 });
413 }
414 (ElementWidth::Unknown, ArrayShape::Static(count)) => {
415 // The element width is not known, but the array
416 // element count is known statically. Parse elements
417 // item by item as an array.
418 let count = syn::Index::from(*count);
419 self.code.push(quote! {
420 // TODO(mgeisler): use
421 // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
422 // when stabilized.
423 let #id = (0..#count)
424 .map(|_| #parse_element)
425 .collect::<Result<Vec<_>, DecodeError>>()?
426 .try_into()
427 .map_err(|_| DecodeError::InvalidPacketError)?;
428 });
429 }
430 (ElementWidth::Unknown, ArrayShape::CountField(count_field)) => {
431 // The element width is not known, but the array
432 // element count is known by the count field. Parse
433 // elements item by item as a vector.
434 self.code.push(quote! {
435 let #id = (0..#count_field)
436 .map(|_| #parse_element)
437 .collect::<Result<Vec<_>, DecodeError>>()?;
438 });
439 }
440 (ElementWidth::Unknown, ArrayShape::Unknown) => {
441 // Neither the count not size is known, parse elements
442 // until the end of the span.
443 self.code.push(quote! {
444 let mut #id = Vec::new();
445 while !#span.get().is_empty() {
446 #id.push(#parse_element?);
447 }
448 });
449 }
450 (ElementWidth::Static(element_width), ArrayShape::Static(count)) => {
451 // The element width is known, and the array element
452 // count is known statically.
453 let count = syn::Index::from(*count);
454 // This creates a nicely formatted size.
455 let array_size = if element_width == 1 {
456 quote!(#count)
457 } else {
458 let element_width = syn::Index::from(element_width);
459 quote!(#count * #element_width)
460 };
461 self.check_size(&span, "e! { #array_size });
462 self.code.push(quote! {
463 // TODO(mgeisler): use
464 // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
465 // when stabilized.
466 let #id = (0..#count)
467 .map(|_| #parse_element)
468 .collect::<Result<Vec<_>, DecodeError>>()?
469 .try_into()
470 .map_err(|_| DecodeError::InvalidPacketError)?;
471 });
472 }
473 (ElementWidth::Static(element_width), ArrayShape::CountField(count_field)) => {
474 // The element width is known, and the array element
475 // count is known dynamically by the count field.
476 self.check_size(&span, "e!(#count_field * #element_width));
477 self.code.push(quote! {
478 let #id = (0..#count_field)
479 .map(|_| #parse_element)
480 .collect::<Result<Vec<_>, DecodeError>>()?;
481 });
482 }
483 (ElementWidth::Static(element_width), ArrayShape::SizeField(_))
484 | (ElementWidth::Static(element_width), ArrayShape::Unknown) => {
485 // The element width is known, and the array full size
486 // is known by size field, or unknown (in which case
487 // it is the remaining span length).
488 let array_size = if let ArrayShape::SizeField(size_field) = &array_shape {
489 self.check_size(&span, "e!(#size_field));
490 quote!(#size_field)
491 } else {
492 quote!(#span.get().remaining())
493 };
494 let count_field = format_ident!("{id}_count");
495 let array_count = if element_width != 1 {
496 let element_width = syn::Index::from(element_width);
497 self.code.push(quote! {
498 if #array_size % #element_width != 0 {
499 return Err(DecodeError::InvalidArraySize {
500 array: #array_size,
501 element: #element_width,
502 });
503 }
504 let #count_field = #array_size / #element_width;
505 });
506 quote!(#count_field)
507 } else {
508 array_size
509 };
510
511 self.code.push(quote! {
512 let mut #id = Vec::with_capacity(#array_count);
513 for _ in 0..#array_count {
514 #id.push(#parse_element?);
515 }
516 });
517 }
518 }
519 }
520
521 /// Parse typedef fields.
522 ///
523 /// This is only for non-enum fields: enums are parsed via
524 /// add_bit_field.
add_typedef_field(&mut self, id: &str, type_id: &str)525 fn add_typedef_field(&mut self, id: &str, type_id: &str) {
526 assert_eq!(self.shift, 0, "Typedef field does not start on an octet boundary");
527
528 let decl = self.scope.typedef[type_id];
529 if let ast::DeclDesc::Struct { parent_id: Some(_), .. } = &decl.desc {
530 panic!("Derived struct used in typedef field");
531 }
532
533 let span = self.span;
534 let id = id.to_ident();
535 let type_id = type_id.to_ident();
536
537 self.code.push(match self.schema.decl_size(decl.key) {
538 analyzer::Size::Unknown | analyzer::Size::Dynamic => quote! {
539 let #id = #type_id::parse_inner(&mut #span)?;
540 },
541 analyzer::Size::Static(width) => {
542 assert_eq!(width % 8, 0, "Typedef field type size is not a multiple of 8");
543 match &decl.desc {
544 ast::DeclDesc::Checksum { .. } => todo!(),
545 ast::DeclDesc::CustomField { .. } if [8, 16, 32, 64].contains(&width) => {
546 let get_uint = types::get_uint(self.endianness, width, span);
547 quote! {
548 let #id = #get_uint.into();
549 }
550 }
551 ast::DeclDesc::CustomField { .. } => {
552 let get_uint = types::get_uint(self.endianness, width, span);
553 quote! {
554 let #id = (#get_uint)
555 .try_into()
556 .unwrap(); // Value is masked and conversion must succeed.
557 }
558 }
559 ast::DeclDesc::Struct { .. } => {
560 let width = syn::Index::from(width / 8);
561 quote! {
562 let (head, tail) = #span.get().split_at(#width);
563 #span.replace(tail);
564 let #id = #type_id::parse(head)?;
565 }
566 }
567 _ => unreachable!(),
568 }
569 }
570 });
571 }
572
573 /// Parse body and payload fields.
add_payload_field(&mut self, size_modifier: Option<&str>)574 fn add_payload_field(&mut self, size_modifier: Option<&str>) {
575 let span = self.span;
576 let payload_size_field = self.decl.payload_size();
577 let offset_from_end = self.payload_field_offset_from_end();
578
579 if self.shift != 0 {
580 todo!("Unexpected non byte aligned payload");
581 }
582
583 if let Some(ast::FieldDesc::Size { field_id, .. }) = &payload_size_field.map(|f| &f.desc) {
584 // The payload or body has a known size. Consume the
585 // payload and update the span in case fields are placed
586 // after the payload.
587 let size_field = size_field_ident(field_id);
588 if let Some(size_modifier) = size_modifier {
589 let size_modifier = proc_macro2::Literal::usize_unsuffixed(
590 size_modifier.parse::<usize>().expect("failed to parse the size modifier"),
591 );
592 let packet_name = &self.packet_name;
593 // Push code to check that the size is greater than the size
594 // modifier. Required to safely substract the modifier from the
595 // size.
596 self.code.push(quote! {
597 if #size_field < #size_modifier {
598 return Err(DecodeError::InvalidLengthError {
599 obj: #packet_name,
600 wanted: #size_modifier,
601 got: #size_field,
602 });
603 }
604 let #size_field = #size_field - #size_modifier;
605 });
606 }
607 self.check_size(self.span, "e!(#size_field ));
608 self.code.push(quote! {
609 let payload = &#span.get()[..#size_field];
610 #span.get_mut().advance(#size_field);
611 });
612 } else if offset_from_end == Some(0) {
613 // The payload or body is the last field of a packet,
614 // consume the remaining span.
615 self.code.push(quote! {
616 let payload = #span.get();
617 #span.get_mut().advance(payload.len());
618 });
619 } else if let Some(offset_from_end) = offset_from_end {
620 // The payload or body is followed by fields of static
621 // size. Consume the span that is not reserved for the
622 // following fields.
623 assert_eq!(
624 offset_from_end % 8,
625 0,
626 "Payload field offset from end of packet is not a multiple of 8"
627 );
628 let offset_from_end = syn::Index::from(offset_from_end / 8);
629 self.check_size(self.span, "e!(#offset_from_end));
630 self.code.push(quote! {
631 let payload = &#span.get()[..#span.get().len() - #offset_from_end];
632 #span.get_mut().advance(payload.len());
633 });
634 }
635
636 let decl = self.scope.typedef[self.packet_name];
637 if let ast::DeclDesc::Struct { .. } = &decl.desc {
638 self.code.push(quote! {
639 let payload = Vec::from(payload);
640 });
641 }
642 }
643
644 /// Parse a single array field element from `span`.
parse_array_element( &self, span: &proc_macro2::Ident, width: Option<usize>, type_id: Option<&str>, decl: Option<&ast::Decl>, ) -> proc_macro2::TokenStream645 fn parse_array_element(
646 &self,
647 span: &proc_macro2::Ident,
648 width: Option<usize>,
649 type_id: Option<&str>,
650 decl: Option<&ast::Decl>,
651 ) -> proc_macro2::TokenStream {
652 if let Some(width) = width {
653 let get_uint = types::get_uint(self.endianness, width, span);
654 return quote! {
655 Ok::<_, DecodeError>(#get_uint)
656 };
657 }
658
659 if let Some(ast::DeclDesc::Enum { id, width, .. }) = decl.map(|decl| &decl.desc) {
660 let get_uint = types::get_uint(self.endianness, *width, span);
661 let type_id = id.to_ident();
662 let packet_name = &self.packet_name;
663 return quote! {
664 #type_id::try_from(#get_uint).map_err(|unknown_val| DecodeError::InvalidEnumValueError {
665 obj: #packet_name,
666 field: "", // TODO(mgeisler): fill out or remove
667 value: unknown_val as u64,
668 type_: #id,
669 })
670 };
671 }
672
673 let type_id = type_id.unwrap().to_ident();
674 quote! {
675 #type_id::parse_inner(#span)
676 }
677 }
678
done(&mut self)679 pub fn done(&mut self) {
680 let decl = self.scope.typedef[self.packet_name];
681 if let ast::DeclDesc::Struct { .. } = &decl.desc {
682 return; // Structs don't parse the child structs recursively.
683 }
684
685 let children = self.scope.iter_children(decl).collect::<Vec<_>>();
686 if children.is_empty() && self.decl.payload().is_none() {
687 return;
688 }
689
690 let all_fields = HashMap::<String, _>::from_iter(
691 self.scope.iter_fields(decl).filter_map(|f| f.id().map(|id| (id.to_string(), f))),
692 );
693
694 // Gather fields that are constrained in immediate child declarations.
695 // Keep the fields sorted by name.
696 // TODO: fields that are only matched in grand children will not be included.
697 let constrained_fields = children
698 .iter()
699 .flat_map(|child| child.constraints().map(|c| &c.id))
700 .collect::<BTreeSet<_>>();
701
702 let mut match_values = Vec::new();
703 let mut child_parse_args = Vec::new();
704 let mut child_ids_data = Vec::new();
705 let mut child_ids = Vec::new();
706
707 let get_constraint_value = |mut constraints: std::slice::Iter<'_, ast::Constraint>,
708 id: &str|
709 -> Option<proc_macro2::TokenStream> {
710 constraints.find(|c| c.id == id).map(|c| constraint_to_value(&all_fields, c))
711 };
712
713 for child in children.iter() {
714 let tuple_values = constrained_fields
715 .iter()
716 .map(|id| {
717 get_constraint_value(child.constraints(), id).map(|v| vec![v]).unwrap_or_else(
718 || {
719 self.scope
720 .file
721 .iter_children(child)
722 .filter_map(|d| get_constraint_value(d.constraints(), id))
723 .collect()
724 },
725 )
726 })
727 .collect::<Vec<_>>();
728
729 // If no constraint values are found for the tuple just skip the child
730 // packet as it would capture unwanted input packets.
731 if tuple_values.iter().all(|v| v.is_empty()) {
732 continue;
733 }
734
735 let tuple_values = tuple_values
736 .iter()
737 .map(|v| v.is_empty().then_some(quote!(_)).unwrap_or_else(|| quote!( #(#v)|* )))
738 .collect::<Vec<_>>();
739
740 let fields = find_constrained_parent_fields(self.scope, child.id().unwrap())
741 .iter()
742 .map(|field| field.id().unwrap().to_ident())
743 .collect::<Vec<_>>();
744
745 match_values.push(quote!( (#(#tuple_values),*) ));
746 child_parse_args.push(quote!( #(, #fields)*));
747 child_ids_data.push(format_ident!("{}Data", child.id().unwrap()));
748 child_ids.push(child.id().unwrap().to_ident());
749 }
750
751 let constrained_field_idents = constrained_fields.iter().map(|field| field.to_ident());
752 let packet_data_child = format_ident!("{}DataChild", self.packet_name);
753
754 // Parsing of packet children requires having a payload field;
755 // it is allowed to inherit from a packet with empty payload, in this
756 // case generate an empty payload value.
757 if !decl
758 .fields()
759 .any(|f| matches!(&f.desc, ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body))
760 {
761 self.code.push(quote! {
762 let payload: &[u8] = &[];
763 })
764 }
765 self.code.push(quote! {
766 let child = match (#(#constrained_field_idents),*) {
767 #(#match_values if #child_ids_data::conforms(&payload) => {
768 let mut cell = Cell::new(payload);
769 let child_data = #child_ids_data::parse_inner(&mut cell #child_parse_args)?;
770 // TODO(mgeisler): communicate back to user if !cell.get().is_empty()?
771 #packet_data_child::#child_ids(child_data)
772 }),*
773 _ if !payload.is_empty() => {
774 #packet_data_child::Payload(Bytes::copy_from_slice(payload))
775 }
776 _ => #packet_data_child::None,
777 };
778 });
779 }
780 }
781
782 impl quote::ToTokens for FieldParser<'_> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)783 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
784 let code = &self.code;
785 tokens.extend(quote! {
786 #(#code)*
787 });
788 }
789 }
790
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer;
    use crate::ast;
    use crate::parser::parse_inline;

    /// Parse a string fragment as a PDL file.
    ///
    /// # Panics
    ///
    /// Panics on parse errors.
    pub fn parse_str(text: &str) -> ast::File {
        let mut db = ast::SourceDatabase::new();
        let file = parse_inline(&mut db, "stdin", String::from(text)).expect("parse error");
        analyzer::analyze(&file).expect("analyzer error")
    }

    /// Builds a `FieldParser` for the packet `P` declared in `code`
    /// and hands it to `check`.
    ///
    /// The parser borrows values local to this function, hence the
    /// callback instead of a return value.
    fn with_parser(code: &str, check: impl FnOnce(&FieldParser<'_>)) {
        let file = parse_str(code);
        let scope = analyzer::Scope::new(&file).unwrap();
        let schema = analyzer::Schema::new(&file);
        let span = format_ident!("bytes");
        let parser = FieldParser::new(&scope, &schema, file.endianness.value, "P", &span);
        check(&parser);
    }

    #[test]
    fn test_find_fields_static() {
        let code = "
              little_endian_packets
              packet P {
                a: 24[3],
              }
            ";
        with_parser(code, |parser| {
            assert_eq!(parser.find_size_field("a"), None);
            assert_eq!(parser.find_count_field("a"), None);
        });
    }

    #[test]
    fn test_find_fields_dynamic_count() {
        let code = "
              little_endian_packets
              packet P {
                _count_(b): 24,
                b: 16[],
              }
            ";
        with_parser(code, |parser| {
            assert_eq!(parser.find_size_field("b"), None);
            assert_eq!(parser.find_count_field("b"), Some(format_ident!("b_count")));
        });
    }

    #[test]
    fn test_find_fields_dynamic_size() {
        let code = "
              little_endian_packets
              packet P {
                _size_(c): 8,
                c: 24[],
              }
            ";
        with_parser(code, |parser| {
            assert_eq!(parser.find_size_field("c"), Some(format_ident!("c_size")));
            assert_eq!(parser.find_count_field("c"), None);
        });
    }
}
862