#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core
from pdl.utils import indent, to_pascal_case


def mask(width: int) -> str:
    return hex((1 << width) - 1)


def deref(var: Optional[str], id: str) -> str:
    return f'{var}.{id}' if var else id


def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type."""
    for n in [8, 16, 32, 64]:
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types.
    assert False

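# For illustration, these helpers map widths and masks as follows (values
# derived directly from the functions above):
#   get_cxx_scalar_type(3)  -> 'uint8_t'
#   get_cxx_scalar_type(24) -> 'uint32_t'
#   mask(5)                 -> '0x1f'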

@dataclass
class FieldParser:
    byteorder: str
    offset: int = 0
    shift: int = 0
    extract_arrays: bool = field(default=False)
    chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: [])
    chunk_nr: int = 0
    unchecked_code: List[str] = field(default_factory=lambda: [])
    code: List[str] = field(default_factory=lambda: [])

    def unchecked_append_(self, line: str):
        """Append unchecked field parsing code.
        The function check_size_ must be called to generate a size guard
        after parsing is completed."""
        self.unchecked_code.append(line)

    def append_(self, line: str):
        """Append field parsing code.
        There must be no unchecked code left before this function is called."""
        assert len(self.unchecked_code) == 0
        self.code.append(line)

    def check_size_(self, size: str):
        """Generate a check of the current span size."""
        self.append_(f"if (span.size() < {size}) {{")
        self.append_("    return false;")
        self.append_("}")

    def check_code_(self):
        """Generate a size check for pending field parsing."""
        if len(self.unchecked_code) > 0:
            assert len(self.chunk) == 0
            unchecked_code = self.unchecked_code
            self.unchecked_code = []
            self.check_size_(str(self.offset))
            self.code.extend(unchecked_code)
            self.offset = 0

    def parse_bit_field_(self, field: ast.Field):
        """Parse the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are extracted together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        self.chunk.append((self.shift, width, field))
        self.shift += width

        # Wait for more fields if not on a byte boundary.
        if (self.shift % 8) != 0:
            return

        # Parse the backing integer using the configured endianness,
        # extract field values.
        size = int(self.shift / 8)
        backing_type = get_cxx_scalar_type(self.shift)

        # Special case when no field is actually used from
        # the chunk.
        should_skip_value = all(isinstance(field, ast.ReservedField) for (_, _, field) in self.chunk)
        if should_skip_value:
            self.unchecked_append_(f"span.skip({size}); // skip reserved fields")
            self.offset += size
            self.shift = 0
            self.chunk = []
            return

        if len(self.chunk) > 1:
            value = f"chunk{self.chunk_nr}"
            self.unchecked_append_(f"{backing_type} {value} = span.read_{self.byteorder}<{backing_type}, {size}>();")
            self.chunk_nr += 1
        else:
            value = f"span.read_{self.byteorder}<{backing_type}, {size}>()"

        for shift, width, field in self.chunk:
            v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

            if isinstance(field, ast.ScalarField):
                self.unchecked_append_(f"{field.id}_ = {v};")
            elif isinstance(field, ast.FixedField) and field.enum_id:
                self.unchecked_append_(f"if ({field.enum_id}({v}) != {field.enum_id}::{field.tag_id}) {{")
                self.unchecked_append_("    return false;")
                self.unchecked_append_("}")
            elif isinstance(field, ast.FixedField):
                self.unchecked_append_(f"if (({v}) != {hex(field.value)}) {{")
                self.unchecked_append_("    return false;")
                self.unchecked_append_("}")
            elif isinstance(field, ast.TypedefField):
                self.unchecked_append_(f"{field.id}_ = {field.type_id}({v});")
            elif isinstance(field, ast.SizeField):
                self.unchecked_append_(f"{field.field_id}_size = {v};")
            elif isinstance(field, ast.CountField):
                self.unchecked_append_(f"{field.field_id}_count = {v};")
            elif isinstance(field, ast.ReservedField):
                pass
            else:
                raise Exception(f'Unsupported bit field type {field.kind}')

        # Reset state.
        self.offset += size
        self.shift = 0
        self.chunk = []

    def parse_typedef_field_(self, field: ast.TypedefField):
        """Parse a typedef field, to the exclusion of Enum fields."""
        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')

        self.check_code_()
        self.append_(
            dedent("""\
            if (!{field_type}::Parse(span, &{field_id}_)) {{
                return false;
            }}""".format(field_type=field.type.id, field_id=field.id)))

    def parse_array_field_lite_(self, field: ast.ArrayField):
        """Parse the selected array field.
        This function does not attempt to parse all elements but just to
        identify the span of the array."""
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        padded_size = field.padded_size

        if element_width:
            element_width = int(element_width / 8)

        if isinstance(array_size, int):
            size = None
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
            count = None
        elif isinstance(array_size, ast.CountField):
            size = None
            count = f'{field.id}_count'
        else:
            size = None
            count = None

        # Shift the span to reset the offset to 0.
        self.check_code_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier};")

        # Compute the array size if the count and element width are known.
        if count is not None and element_width is not None:
            size = f"{count} * {element_width}"

        # Parse from the padded array if padding is present.
        if padded_size:
            self.check_size_(padded_size)
            self.append_("{")
            self.append_(
                f"pdl::packet::slice remaining_span = span.subrange({padded_size}, span.size() - {padded_size});")
            self.append_(f"span = span.subrange(0, {padded_size});")

        # The array size is known in bytes.
        if size is not None:
            self.check_size_(size)
            self.append_(f"{field.id}_ = span.subrange(0, {size});")
            self.append_(f"span.skip({size});")

        # The array count is known. The element width is dynamic.
        # Parse each element iteratively and derive the array span.
        elif count is not None:
            self.append_("{")
            self.append_("pdl::packet::slice temp_span = span;")
            self.append_(f"for (size_t n = 0; n < {count}; n++) {{")
            self.append_(f"    {field.type_id} element;")
            self.append_(f"    if (!{field.type_id}::Parse(temp_span, &element)) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_("}")
            self.append_(f"{field.id}_ = span.subrange(0, span.size() - temp_span.size());")
            self.append_(f"span.skip({field.id}_.size());")
            self.append_("}")

        # The array size is not known, assume the array takes the
        # full remaining space. TODO support having fixed sized fields
        # following the array.
        else:
            self.append_(f"{field.id}_ = span;")
            self.append_("span.clear();")

        if padded_size:
            self.append_(f"span = remaining_span;")
            self.append_("}")

    def parse_array_field_full_(self, field: ast.ArrayField):
        """Parse the selected array field.
        This function parses all the elements of the array and stores
        them in the backing member vector."""
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        element_type = field.type_id or get_cxx_scalar_type(field.width)
        padded_size = field.padded_size

        if element_width:
            element_width = int(element_width / 8)

        if isinstance(array_size, int):
            size = None
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
            count = None
        elif isinstance(array_size, ast.CountField):
            size = None
            count = f'{field.id}_count'
        else:
            size = None
            count = None

        # Shift the span to reset the offset to 0.
        self.check_code_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier};")

        # Compute the array size if the count and element width are known.
        if count is not None and element_width is not None:
            size = f"{count} * {element_width}"

        # Parse from the padded array if padding is present.
        if padded_size:
            self.check_size_(padded_size)
            self.append_("{")
            self.append_(
                f"pdl::packet::slice remaining_span = span.subrange({padded_size}, span.size() - {padded_size});")
            self.append_(f"span = span.subrange(0, {padded_size});")

        # The array size is known in bytes.
        if size is not None:
            self.check_size_(size)
            self.append_("{")
            self.append_(f"pdl::packet::slice temp_span = span.subrange(0, {size});")
            self.append_(f"span.skip({size});")
            self.append_(f"while (temp_span.size() > 0) {{")
            if field.width:
                element_size = int(field.width / 8)
                self.append_(f"    if (temp_span.size() < {element_size}) {{")
                self.append_(f"        return false;")
                self.append_("    }")
                self.append_(
                    f"    {field.id}_.push_back(temp_span.read_{self.byteorder}<{element_type}, {element_size}>());")
            elif isinstance(field.type, ast.EnumDeclaration):
                backing_type = get_cxx_scalar_type(field.type.width)
                element_size = int(field.type.width / 8)
                self.append_(f"    if (temp_span.size() < {element_size}) {{")
                self.append_(f"        return false;")
                self.append_("    }")
                self.append_(
                    f"    {field.id}_.push_back({element_type}(temp_span.read_{self.byteorder}<{backing_type}, {element_size}>()));"
                )
            else:
                self.append_(f"    {element_type} element;")
                self.append_(f"    if (!{element_type}::Parse(temp_span, &element)) {{")
                self.append_(f"        return false;")
                self.append_("    }")
                self.append_(f"    {field.id}_.emplace_back(std::move(element));")
            self.append_("}")
            self.append_("}")

        # The array count is known. The element width is dynamic.
        # Parse each element iteratively and derive the array span.
        elif count is not None:
            self.append_(f"for (size_t n = 0; n < {count}; n++) {{")
            self.append_(f"    {element_type} element;")
            self.append_(f"    if (!{field.type_id}::Parse(span, &element)) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_(f"    {field.id}_.emplace_back(std::move(element));")
            self.append_("}")

        # The array size is not known, assume the array takes the
        # full remaining space. TODO support having fixed sized fields
        # following the array.
        elif field.width:
            element_size = int(field.width / 8)
            self.append_(f"while (span.size() > 0) {{")
            self.append_(f"    if (span.size() < {element_size}) {{")
            self.append_(f"        return false;")
            self.append_("    }")
            self.append_(f"    {field.id}_.push_back(span.read_{self.byteorder}<{element_type}, {element_size}>());")
            self.append_("}")
        elif isinstance(field.type, ast.EnumDeclaration):
            element_size = int(field.type.width / 8)
            backing_type = get_cxx_scalar_type(field.type.width)
            self.append_(f"while (span.size() > 0) {{")
            self.append_(f"    if (span.size() < {element_size}) {{")
            self.append_(f"        return false;")
            self.append_("    }")
            self.append_(
                f"    {field.id}_.push_back({element_type}(span.read_{self.byteorder}<{backing_type}, {element_size}>()));"
            )
            self.append_("}")
        else:
            self.append_(f"while (span.size() > 0) {{")
            self.append_(f"    {element_type} element;")
            self.append_(f"    if (!{element_type}::Parse(span, &element)) {{")
            self.append_(f"        return false;")
            self.append_("    }")
            self.append_(f"    {field.id}_.emplace_back(std::move(element));")
            self.append_("}")

        if padded_size:
            self.append_(f"span = remaining_span;")
            self.append_("}")

    def parse_payload_field_lite_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields."""
        if self.shift != 0:
            raise Exception('Payload field does not start on an octet boundary')

        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)
        self.check_code_()

        if payload_size and getattr(field, 'size_modifier', None):
            self.append_(f"{field.id}_size -= {field.size_modifier};")

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            self.check_size_(f"{field.id}_size")
            self.append_(f"payload_ = span.subrange(0, {field.id}_size);")
            self.append_(f"span.skip({field.id}_size);")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_(f"payload_ = span;")
            self.append_(f"span.clear();")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end:
            if (offset_from_end % 8) != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end = int(offset_from_end / 8)
            self.check_size_(f'{offset_from_end}')
            self.append_(f"payload_ = span.subrange(0, span.size() - {offset_from_end});")
            self.append_(f"span.skip(payload_.size());")

    def parse_payload_field_full_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields."""
        if self.shift != 0:
            raise Exception('Payload field does not start on an octet boundary')

        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)
        self.check_code_()

        if payload_size and getattr(field, 'size_modifier', None):
            self.append_(f"{field.id}_size -= {field.size_modifier};")

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            self.check_size_(f"{field.id}_size")
            self.append_(f"for (size_t n = 0; n < {field.id}_size; n++) {{")
            self.append_(f"    payload_.push_back(span.read_{self.byteorder}<uint8_t>());")
            self.append_("}")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_("while (span.size() > 0) {")
            self.append_(f"    payload_.push_back(span.read_{self.byteorder}<uint8_t>());")
            self.append_("}")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end is not None:
            if (offset_from_end % 8) != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end = int(offset_from_end / 8)
            self.check_size_(f'{offset_from_end}')
            self.append_(f"while (span.size() > {offset_from_end}) {{")
            self.append_(f"    payload_.push_back(span.read_{self.byteorder}<uint8_t>());")
            self.append_("}")

    def parse(self, field: ast.Field):
        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        if core.is_bit_field(field):
            self.parse_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField) and self.extract_arrays:
            self.parse_array_field_full_(field)

        elif isinstance(field, ast.ArrayField) and not self.extract_arrays:
            self.parse_array_field_lite_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.parse_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)) and self.extract_arrays:
            self.parse_payload_field_full_(field)

        elif isinstance(field, (ast.PayloadField, ast.BodyField)) and not self.extract_arrays:
            self.parse_payload_field_lite_(field)

        else:
            raise Exception(f'Unsupported field type {field.kind}')

    def done(self):
        self.check_code_()

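# Sketch of how FieldParser is driven (it mirrors the usage in
# generate_packet_view_field_parsers and generate_struct_field_parsers below;
# the field names and byteorder are hypothetical):
#
#   parser = FieldParser(extract_arrays=False, byteorder='le')
#   for f in packet.fields:
#       parser.parse(f)
#   parser.done()
#   # parser.code now holds the C++ parsing statements. For two 4-bit scalar
#   # fields `a` and `b` sharing one octet the emitted code is roughly:
#   #   if (span.size() < 1) {
#   #       return false;
#   #   }
#   #   uint8_t chunk0 = span.read_le<uint8_t, 1>();
#   #   a_ = (chunk0 >> 0) & 0xf;
#   #   b_ = (chunk0 >> 4) & 0xf;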

@dataclass
class FieldSerializer:
    byteorder: str
    shift: int = 0
    value: List[Tuple[str, int]] = field(default_factory=lambda: [])
    code: List[str] = field(default_factory=lambda: [])
    indent: int = 0

    def indent_(self):
        self.indent += 1

    def unindent_(self):
        self.indent -= 1

    def append_(self, line: str):
        """Append field serializing code."""
        lines = line.split('\n')
        self.code.extend(['    ' * self.indent + line for line in lines])

    def get_payload_field_size(self, var: Optional[str], payload: ast.PayloadField, decl: ast.Declaration) -> str:
        """Compute the size of the selected payload field, using the builder
        information for the selected declaration. The payload field can be the
        payload of any of the parent declarations, or of the current declaration."""

        if payload.parent.id == decl.id:
            return deref(var, 'payload_.size()')

        # Get the child packet declaration that will match the current
        # declaration further down.
        child = decl
        while child.parent_id != payload.parent.id:
            child = child.parent

        # The payload is the result of serializing the children fields.
        constant_width = 0
        variable_width = []
        for f in child.fields:
            field_size = core.get_field_size(f)
            if field_size is not None:
                constant_width += field_size
            elif isinstance(f, (ast.PayloadField, ast.BodyField)):
                variable_width.append(self.get_payload_field_size(var, f, decl))
            elif isinstance(f, ast.TypedefField):
                variable_width.append(f"{f.id}_.GetSize()")
            elif isinstance(f, ast.ArrayField):
                variable_width.append(f"Get{to_pascal_case(f.id)}Size()")
            else:
                raise Exception("Unsupported field type")

        constant_width = int(constant_width / 8)
        if constant_width and not variable_width:
            return str(constant_width)

        temp_var = f'{payload.parent.id.lower()}_payload_size'
        self.append_(f"size_t {temp_var} = {constant_width};")
        for dyn in variable_width:
            self.append_(f"{temp_var} += {dyn};")
        return temp_var

    def serialize_array_element_(self, field: ast.ArrayField, var: str):
        """Serialize a single array field element."""
        if field.width:
            backing_type = get_cxx_scalar_type(field.width)
            element_size = int(field.width / 8)
            self.append_(
                f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {element_size}>(output, {var});")
        elif isinstance(field.type, ast.EnumDeclaration):
            backing_type = get_cxx_scalar_type(field.type.width)
            element_size = int(field.type.width / 8)
            self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {element_size}>(" +
                         f"output, static_cast<{backing_type}>({var}));")
        else:
            self.append_(f"{var}.Serialize(output);")

    def serialize_array_field_(self, field: ast.ArrayField, var: str):
        """Serialize the selected array field."""
        if field.padded_size:
            self.append_(f"size_t {field.id}_end = output.size() + {field.padded_size};")

        if field.width == 8:
            self.append_(f"output.insert(output.end(), {var}.begin(), {var}.end());")
        else:
            self.append_(f"for (size_t n = 0; n < {var}.size(); n++) {{")
            self.indent_()
            self.serialize_array_element_(field, f'{var}[n]')
            self.unindent_()
            self.append_("}")

        if field.padded_size:
            self.append_(f"while (output.size() < {field.id}_end) {{")
            self.append_("    output.push_back(0);")
            self.append_("}")

    def serialize_bit_field_(self, field: ast.Field, parent_var: Optional[str], var: Optional[str],
                             decl: ast.Declaration):
        """Serialize the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are serialized together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        shift = self.shift

        if isinstance(field, ast.ScalarField):
            self.value.append((f"{var} & {mask(field.width)}", shift))
        elif isinstance(field, ast.FixedField) and field.enum_id:
            self.value.append((f"{field.enum_id}::{field.tag_id}", shift))
        elif isinstance(field, ast.FixedField):
            self.value.append((f"{field.value}", shift))
        elif isinstance(field, ast.TypedefField):
            self.value.append((f"{var}", shift))

        elif isinstance(field, ast.SizeField):
            max_size = (1 << field.width) - 1
            value_field = core.get_packet_field(field.parent, field.field_id)
            size_modifier = ''

            if getattr(value_field, 'size_modifier', None):
                size_modifier = f' + {value_field.size_modifier}'

            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
                array_size = self.get_payload_field_size(var, value_field, decl) + size_modifier

            elif isinstance(value_field, ast.ArrayField):
                accessor_name = to_pascal_case(field.field_id)
                array_size = deref(var, f'Get{accessor_name}Size()') + size_modifier

            self.value.append((f"{array_size}", shift))

        elif isinstance(field, ast.CountField):
            max_count = (1 << field.width) - 1
            self.value.append((f"{field.field_id}_.size()", shift))

        elif isinstance(field, ast.ReservedField):
            pass
        else:
            raise Exception(f'Unsupported bit field type {field.kind}')

        # Check if a byte boundary is reached.
        self.shift += width
        if (self.shift % 8) == 0:
            self.pack_bit_fields_()

    def pack_bit_fields_(self):
        """Pack serialized bit fields."""

        # Should have an integral number of bytes now.
        assert (self.shift % 8) == 0

        # Generate the backing integer, and serialize it
        # using the configured endianness.
        size = int(self.shift / 8)
        backing_type = get_cxx_scalar_type(self.shift)
        value = [f"(static_cast<{backing_type}>({v[0]}) << {v[1]})" for v in self.value]

        if len(value) == 0:
            self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, 0);")
        elif len(value) == 1:
            self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, {value[0]});")
        else:
            self.append_(
                f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, {' | '.join(value)});")

        # Reset state.
        self.shift = 0
        self.value = []

    def serialize_typedef_field_(self, field: ast.TypedefField, var: str):
        """Serialize a typedef field, to the exclusion of Enum fields."""

        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        self.append_(f"{var}.Serialize(output);")

    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField], var: str):
        """Serialize body and payload fields."""

        if self.shift != 0:
            raise Exception('Payload field does not start on an octet boundary')

        self.append_(f"output.insert(output.end(), {var}.begin(), {var}.end());")

    def serialize(self, field: ast.Field, decl: ast.Declaration, var: Optional[str] = None):
        field_var = deref(var, f'{field.id}_') if hasattr(field, 'id') else None

        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        if core.is_bit_field(field):
            self.serialize_bit_field_(field, var, field_var, decl)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField):
            self.serialize_array_field_(field, field_var)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.serialize_typedef_field_(field, field_var)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.serialize_payload_field_(field, deref(var, 'payload_'))

        else:
            raise Exception(f'Unimplemented field type {field.kind}')

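# Simplified sketch of how FieldSerializer is driven (it mirrors
# generate_packet_field_serializers below, without the constraint handling;
# the field names and byteorder are hypothetical):
#
#   serializer = FieldSerializer(byteorder='le')
#   for f in core.get_packet_fields(packet):
#       serializer.serialize(f, packet)
#   # serializer.code now holds the C++ serialization statements. For the same
#   # two 4-bit scalar fields `a` and `b` the packed output is roughly:
#   #   pdl::packet::Builder::write_le<uint8_t, 1>(output,
#   #       (static_cast<uint8_t>(a_ & 0xf) << 0) |
#   #       (static_cast<uint8_t>(b_ & 0xf) << 4));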

def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Generate the implementation of an enum type."""

    enum_name = decl.id
    enum_type = get_cxx_scalar_type(decl.width)
    tag_decls = []
    for t in decl.tags:
        tag_decls.append(f"{t.id} = {hex(t.value)},")

    return dedent("""\

        enum class {enum_name} : {enum_type} {{
            {tag_decls}
        }};
        """).format(enum_name=enum_name, enum_type=enum_type, tag_decls=indent(tag_decls, 1))

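# For a hypothetical 8-bit enum `Color` with tags RED = 1 and GREEN = 2, the
# function above emits roughly:
#
#   enum class Color : uint8_t {
#       RED = 0x1,
#       GREEN = 0x2,
#   };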

def generate_enum_to_text(decl: ast.EnumDeclaration) -> str:
    """Generate the helper function that will convert an enum tag to string."""

    enum_name = decl.id
    tag_cases = []
    for t in decl.tags:
        tag_cases.append(f"case {enum_name}::{t.id}: return \"{t.id}\";")

    return dedent("""\

        inline std::string {enum_name}Text({enum_name} tag) {{
            switch (tag) {{
                {tag_cases}
                default:
                    return std::string("Unknown {enum_name}: " +
                           std::to_string(static_cast<uint64_t>(tag)));
            }}
        }}
        """).format(enum_name=enum_name, tag_cases=indent(tag_cases, 2))


def generate_packet_field_members(decl: ast.Declaration, view: bool) -> List[str]:
    """Return the declarations of the fields that are backed in the
    generated view or builder class.

    Backed fields include all named fields that do not have a constrained
    value in the selected declaration and its parents.

    :param decl: target declaration
    :param view: if true the payload and array fields are generated as slices"""

    fields = core.get_unconstrained_parent_fields(decl) + decl.fields
    members = []
    for field in fields:
        if isinstance(field, (ast.PayloadField, ast.BodyField)) and view:
            members.append("pdl::packet::slice payload_;")
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            members.append("std::vector<uint8_t> payload_;")
        elif isinstance(field, ast.ArrayField) and view:
            members.append(f"pdl::packet::slice {field.id}_;")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            members.append(f"std::vector<{element_type}> {field.id}_;")
        elif isinstance(field, ast.ScalarField):
            members.append(f"{get_cxx_scalar_type(field.width)} {field.id}_{{0}};")
        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            members.append(f"{field.type_id} {field.id}_{{{field.type_id}::{field.type.tags[0].id}}};")
        elif isinstance(field, ast.TypedefField):
            members.append(f"{field.type_id} {field.id}_;")

    return members

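# For a hypothetical packet declaring an 8-bit scalar `id` followed by a
# payload, the function above produces:
#   view=True:  ['uint8_t id_{0};', 'pdl::packet::slice payload_;']
#   view=False: ['uint8_t id_{0};', 'std::vector<uint8_t> payload_;']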

def generate_packet_field_serializers(packet: ast.Declaration) -> List[str]:
    """Generate the code to serialize the fields of a packet builder or struct."""
    serializer = FieldSerializer(byteorder=packet.file.byteorder_short)
    constraints = core.get_parent_constraints(packet)
    constraints = dict([(c.id, c) for c in constraints])
    for field in core.get_packet_fields(packet):
        field_id = getattr(field, 'id', None)
        constraint = constraints.get(field_id, None)
        fixed_field = None
        if constraint and constraint.tag_id:
            fixed_field = ast.FixedField(enum_id=field.type_id,
                                         tag_id=constraint.tag_id,
                                         loc=field.loc,
                                         kind='fixed_field')
            fixed_field.parent = field.parent
        elif constraint:
            fixed_field = ast.FixedField(width=field.width, value=constraint.value, loc=field.loc, kind='fixed_field')
            fixed_field.parent = field.parent
        serializer.serialize(fixed_field or field, packet)
    return serializer.code

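# Note: when a parent constraint fixes a field value (for example a
# hypothetical constraint `opcode = Opcode::COMMAND`), the loop above swaps the
# field for an ast.FixedField so that the constant tag or value is serialized
# instead of a member variable.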

def generate_scalar_array_field_accessor(field: ast.ArrayField) -> str:
    """Parse the selected scalar array field."""
    element_size = int(field.width / 8)
    backing_type = get_cxx_scalar_type(field.width)
    byteorder = field.parent.file.byteorder_short
    return dedent("""\
        pdl::packet::slice span = {field_id}_;
        std::vector<{backing_type}> elements;
        while (span.size() >= {element_size}) {{
            elements.push_back(span.read_{byteorder}<{backing_type}, {element_size}>());
        }}
        return elements;""").format(field_id=field.id,
                                    backing_type=backing_type,
                                    element_size=element_size,
                                    byteorder=byteorder)


def generate_enum_array_field_accessor(field: ast.ArrayField) -> str:
    """Parse the selected enum array field."""
    element_size = int(field.type.width / 8)
    backing_type = get_cxx_scalar_type(field.type.width)
    byteorder = field.parent.file.byteorder_short
    return dedent("""\
        pdl::packet::slice span = {field_id}_;
        std::vector<{enum_type}> elements;
        while (span.size() >= {element_size}) {{
            elements.push_back({enum_type}(span.read_{byteorder}<{backing_type}, {element_size}>()));
        }}
        return elements;""").format(field_id=field.id,
                                    enum_type=field.type_id,
                                    backing_type=backing_type,
                                    element_size=element_size,
                                    byteorder=byteorder)


def generate_typedef_array_field_accessor(field: ast.ArrayField) -> str:
    """Parse the selected typedef array field."""
    return dedent("""\
        pdl::packet::slice span = {field_id}_;
        std::vector<{struct_type}> elements;
        for (;;) {{
            {struct_type} element;
            if (!{struct_type}::Parse(span, &element)) {{
                break;
            }}
            elements.emplace_back(std::move(element));
        }}
        return elements;""").format(field_id=field.id, struct_type=field.type_id)


def generate_array_field_accessor(field: ast.ArrayField):
    """Parse the selected array field."""

    if field.width is not None:
        return generate_scalar_array_field_accessor(field)
    elif isinstance(field.type, ast.EnumDeclaration):
        return generate_enum_array_field_accessor(field)
    else:
        return generate_typedef_array_field_accessor(field)


def generate_array_field_size_getters(decl: ast.Declaration) -> str:
    """Generate size getters for array fields. Produces the serialized
    size of the array in bytes."""

    getters = []
    fields = core.get_unconstrained_parent_fields(decl) + decl.fields
    for field in fields:
        if not isinstance(field, ast.ArrayField):
            continue

        element_width = field.width or core.get_declaration_size(field.type)
        size = None

        if element_width and field.size:
            size = int(element_width * field.size / 8)
        elif element_width:
            size = f"{field.id}_.size() * {int(element_width / 8)}"

        if size:
            getters.append(
                dedent("""\
                size_t Get{accessor_name}Size() const {{
                    return {size};
                }}
                """).format(accessor_name=to_pascal_case(field.id), size=size))
        else:
            getters.append(
                dedent("""\
                size_t Get{accessor_name}Size() const {{
                    size_t array_size = 0;
                    for (size_t n = 0; n < {field_id}_.size(); n++) {{
                        array_size += {field_id}_[n].GetSize();
                    }}
                    return array_size;
                }}
                """).format(accessor_name=to_pascal_case(field.id), field_id=field.id))

    return '\n'.join(getters)


def generate_packet_size_getter(decl: ast.Declaration) -> List[str]:
    """Generate a size getter for the current packet. Produces the serialized
    size of the packet in bytes."""

    constant_width = 0
    variable_width = []
    for f in core.get_packet_fields(decl):
        field_size = core.get_field_size(f)
        if field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("payload_.size()")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"{f.id}_.GetSize()")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"Get{to_pascal_case(f.id)}Size()")
        else:
            raise Exception("Unsupported field type")

    constant_width = int(constant_width / 8)
    if not variable_width:
        return [f"return {constant_width};"]
    elif len(variable_width) == 1 and constant_width:
        return [f"return {variable_width[0]} + {constant_width};"]
    elif len(variable_width) == 1:
        return [f"return {variable_width[0]};"]
    elif len(variable_width) > 1 and constant_width:
        return ([f"return {constant_width} + ("] + " +\n    ".join(variable_width).split("\n") + [");"])
    elif len(variable_width) > 1:
        return (["return ("] + " +\n    ".join(variable_width).split("\n") + [");"])
    else:
        assert False

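# Worked example for the function above: a hypothetical packet with an 8-bit
# and a 16-bit scalar plus a payload has constant_width = 24 bits = 3 bytes
# and a single variable term, so the generated getter body is:
#   return payload_.size() + 3;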

def generate_packet_view_field_accessors(packet: ast.PacketDeclaration) -> str:
    """Return the declaration of accessors for the named packet fields."""

    accessors = []

    # Add accessors for the backed fields.
    fields = core.get_unconstrained_parent_fields(packet) + packet.fields
    for field in fields:
        if isinstance(field, (ast.PayloadField, ast.BodyField)):
            accessors.append(
                dedent("""\
                std::vector<uint8_t> GetPayload() const {
                    ASSERT(valid_);
                    return payload_.bytes();
                }

                """))
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                std::vector<{element_type}> Get{accessor_name}() const {{
                    ASSERT(valid_);
                    {accessor}
                }}

                """).format(element_type=element_type,
                            accessor_name=accessor_name,
                            accessor=indent(generate_array_field_accessor(field), 1)))
        elif isinstance(field, ast.ScalarField):
            field_type = get_cxx_scalar_type(field.width)
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type} Get{accessor_name}() const {{
                    ASSERT(valid_);
                    return {member_name}_;
                }}

                """).format(field_type=field_type, accessor_name=accessor_name, member_name=field.id))
        elif isinstance(field, ast.TypedefField):
            field_qualifier = "" if isinstance(field.type, ast.EnumDeclaration) else " const&"
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type}{field_qualifier} Get{accessor_name}() const {{
                    ASSERT(valid_);
                    return {member_name}_;
                }}

                """).format(field_type=field.type_id,
                            field_qualifier=field_qualifier,
                            accessor_name=accessor_name,
                            member_name=field.id))

    # Add accessors for constrained parent fields.
    # The accessors return a constant value in this case.
    for c in core.get_parent_constraints(packet):
        field = core.get_packet_field(packet, c.id)
        if isinstance(field, ast.ScalarField):
            field_type = get_cxx_scalar_type(field.width)
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type} Get{accessor_name}() const {{
                    return {value};
                }}

                """).format(field_type=field_type, accessor_name=accessor_name, value=c.value))
        else:
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type} Get{accessor_name}() const {{
                    return {field_type}::{tag_id};
                }}

                """).format(field_type=field.type_id, accessor_name=accessor_name, tag_id=c.tag_id))

    return "".join(accessors)


def generate_packet_stringifier(packet: ast.PacketDeclaration) -> str:
    """Generate the packet printer. TODO """
    return dedent("""\
        std::string ToString() const {
            return "";
        }
        """)


def generate_packet_view_field_parsers(packet: ast.PacketDeclaration) -> str:
    """Generate the packet parser. The validator will extract
    the fields it can in a pre-parsing phase. """

    code = []

    # Generate code to check the validity of the parent,
    # and import parent fields that do not have a fixed value in the
    # current packet.
    if packet.parent:
        code.append(
            dedent("""\
            // Check validity of parent packet.
            if (!parent.IsValid()) {
                return false;
            }
            """))
        parent_fields = core.get_unconstrained_parent_fields(packet)
        if parent_fields:
            code.append("// Copy parent field values.")
            for f in parent_fields:
                code.append(f"{f.id}_ = parent.{f.id}_;")
            code.append("")
        span = "parent.payload_"
    else:
        span = "parent"

    # Validate parent constraints.
    for c in packet.constraints:
        if c.tag_id:
            enum_type = core.get_packet_field(packet.parent, c.id).type_id
            code.append(
                dedent("""\
                if (parent.{field_id}_ != {enum_type}::{tag_id}) {{
                    return false;
                }}
                """).format(field_id=c.id, enum_type=enum_type, tag_id=c.tag_id))
        else:
            code.append(
                dedent("""\
                if (parent.{field_id}_ != {value}) {{
                    return false;
                }}
                """).format(field_id=c.id, value=c.value))

    # Parse fields linearly.
    if packet.fields:
        code.append("// Parse packet field values.")
        code.append(f"pdl::packet::slice span = {span};")
        for f in packet.fields:
            if isinstance(f, ast.SizeField):
                code.append(f"{get_cxx_scalar_type(f.width)} {f.field_id}_size;")
            elif isinstance(f, ast.CountField):
                code.append(f"{get_cxx_scalar_type(f.width)} {f.field_id}_count;")
        parser = FieldParser(extract_arrays=False, byteorder=packet.file.byteorder_short)
        for f in packet.fields:
            parser.parse(f)
        parser.done()
        code.extend(parser.code)

    code.append("return true;")
    return '\n'.join(code)


def generate_packet_view_friend_classes(packet: ast.PacketDeclaration) -> List[str]:
    """Generate the list of friend declarations for a packet.
    These are the direct children of the class."""

    return [f"friend class {decl.id}View;" for (_, decl) in core.get_derived_packets(packet, traverse=False)]


def generate_packet_view(packet: ast.PacketDeclaration) -> str:
    """Generate the implementation of the View class for a
    packet declaration."""

    parent_class = f"{packet.parent.id}View" if packet.parent else "pdl::packet::slice"
    field_members = generate_packet_field_members(packet, view=True)
    field_accessors = generate_packet_view_field_accessors(packet)
    field_parsers = generate_packet_view_field_parsers(packet)
    friend_classes = generate_packet_view_friend_classes(packet)
    stringifier = generate_packet_stringifier(packet)

    return dedent("""\

        class {packet_name}View {{
        public:
            static {packet_name}View Create({parent_class} const& parent) {{
                return {packet_name}View(parent);
            }}

            {field_accessors}
            {stringifier}

            bool IsValid() const {{
                return valid_;
            }}

        protected:
            explicit {packet_name}View({parent_class} const& parent) {{
                valid_ = Parse(parent);
            }}

            bool Parse({parent_class} const& parent) {{
                {field_parsers}
            }}

            bool valid_{{false}};
            {field_members}

            {friend_classes}
        }};
        """).format(packet_name=packet.id,
                    parent_class=parent_class,
                    field_accessors=indent(field_accessors, 1),
                    field_members=indent(field_members, 1),
                    field_parsers=indent(field_parsers, 2),
                    friend_classes=indent(friend_classes, 1),
                    stringifier=indent(stringifier, 1))


def generate_packet_constructor(struct: ast.StructDeclaration, constructor_name: str) -> str:
    """Generate the implementation of the constructor for a
    struct declaration."""

    constructor_params = []
    constructor_initializers = []
    fields = core.get_unconstrained_parent_fields(struct) + struct.fields

    for field in fields:
        if isinstance(field, (ast.PayloadField, ast.BodyField)):
            constructor_params.append("std::vector<uint8_t> payload")
            constructor_initializers.append("payload_(std::move(payload))")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            constructor_params.append(f"std::vector<{element_type}> {field.id}")
            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")
        elif isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            constructor_params.append(f"{backing_type} {field.id}")
            constructor_initializers.append(f"{field.id}_({field.id})")
        elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
            constructor_params.append(f"{field.type_id} {field.id}")
            constructor_initializers.append(f"{field.id}_({field.id})")
        elif isinstance(field, ast.TypedefField):
            constructor_params.append(f"{field.type_id} {field.id}")
            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")

    if not constructor_params:
        return ""

    explicit = 'explicit ' if len(constructor_params) == 1 else ''
    constructor_params = ', '.join(constructor_params)
    constructor_initializers = ', '.join(constructor_initializers)

    return dedent("""\
        {explicit}{constructor_name}({constructor_params})
            : {constructor_initializers} {{}}""").format(explicit=explicit,
                                                         constructor_name=constructor_name,
                                                         constructor_params=constructor_params,
                                                         constructor_initializers=constructor_initializers)

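# For a hypothetical struct `Foo` with an 8-bit scalar `id` and a byte
# payload, the generated constructor is roughly:
#   Foo(uint8_t id, std::vector<uint8_t> payload)
#       : id_(id), payload_(std::move(payload)) {}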

def generate_packet_builder(packet: ast.PacketDeclaration) -> str:
    """Generate the implementation of the Builder class for a
    packet declaration."""

    class_name = f'{packet.id}Builder'
    builder_constructor = generate_packet_constructor(packet, constructor_name=class_name)
    field_members = generate_packet_field_members(packet, view=False)
    field_serializers = generate_packet_field_serializers(packet)
    size_getter = generate_packet_size_getter(packet)
    array_field_size_getters = generate_array_field_size_getters(packet)

    return dedent("""\

        class {class_name} : public pdl::packet::Builder {{
        public:
            ~{class_name}() override = default;
            {class_name}() = default;
            {class_name}({class_name} const&) = default;
            {class_name}({class_name}&&) = default;
            {class_name}& operator=({class_name} const&) = default;
            {builder_constructor}

            void Serialize(std::vector<uint8_t>& output) const override {{
                {field_serializers}
            }}

            size_t GetSize() const override {{
                {size_getter}
            }}

            {array_field_size_getters}
            {field_members}
        }};
        """).format(class_name=f'{packet.id}Builder',
                    builder_constructor=builder_constructor,
                    field_members=indent(field_members, 1),
                    field_serializers=indent(field_serializers, 2),
                    size_getter=indent(size_getter, 1),
                    array_field_size_getters=indent(array_field_size_getters, 1))


def generate_struct_field_parsers(struct: ast.StructDeclaration) -> str:
    """Generate the struct parser. The validator will extract
    the fields it can in a pre-parsing phase. """

    code = []
    parsed_fields = []
    post_processing = []

    for field in struct.fields:
        if isinstance(field, (ast.PayloadField, ast.BodyField)):
            code.append("std::vector<uint8_t> payload_;")
            parsed_fields.append("std::move(payload_)")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            code.append(f"std::vector<{element_type}> {field.id}_;")
            parsed_fields.append(f"std::move({field.id}_)")
        elif isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            code.append(f"{backing_type} {field.id}_;")
            parsed_fields.append(f"{field.id}_")
        elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
            code.append(f"{field.type_id} {field.id}_;")
            parsed_fields.append(f"{field.id}_")
        elif isinstance(field, ast.TypedefField):
            code.append(f"{field.type_id} {field.id}_;")
            parsed_fields.append(f"std::move({field.id}_)")
        elif isinstance(field, ast.SizeField):
            code.append(f"{get_cxx_scalar_type(field.width)} {field.field_id}_size;")
        elif isinstance(field, ast.CountField):
            code.append(f"{get_cxx_scalar_type(field.width)} {field.field_id}_count;")

    parser = FieldParser(extract_arrays=True, byteorder=struct.file.byteorder_short)
    for f in struct.fields:
        parser.parse(f)
    parser.done()
    code.extend(parser.code)

    parsed_fields = ', '.join(parsed_fields)
    code.append(f"*output = {struct.id}({parsed_fields});")
    code.append("return true;")
    return '\n'.join(code)


def generate_struct_declaration(struct: ast.StructDeclaration) -> str:
    """Generate the implementation of the class for a
    struct declaration."""

    if struct.parent:
        raise Exception("Struct declarations with parents are not supported")

    struct_constructor = generate_packet_constructor(struct, constructor_name=struct.id)
    field_members = generate_packet_field_members(struct, view=False)
    field_parsers = generate_struct_field_parsers(struct)
    field_serializers = generate_packet_field_serializers(struct)
    size_getter = generate_packet_size_getter(struct)
    array_field_size_getters = generate_array_field_size_getters(struct)
    stringifier = generate_packet_stringifier(struct)

    return dedent("""\

        class {struct_name} : public pdl::packet::Builder {{
        public:
            ~{struct_name}() override = default;
            {struct_name}() = default;
            {struct_name}({struct_name} const&) = default;
            {struct_name}({struct_name}&&) = default;
            {struct_name}& operator=({struct_name} const&) = default;
            {struct_constructor}

            static bool Parse(pdl::packet::slice& span, {struct_name}* output) {{
                {field_parsers}
            }}

            void Serialize(std::vector<uint8_t>& output) const override {{
                {field_serializers}
            }}

            size_t GetSize() const override {{
                {size_getter}
            }}

            {array_field_size_getters}
            {stringifier}
            {field_members}
        }};
        """).format(struct_name=struct.id,
                    struct_constructor=struct_constructor,
                    field_members=indent(field_members, 1),
                    field_parsers=indent(field_parsers, 2),
                    field_serializers=indent(field_serializers, 2),
                    stringifier=indent(stringifier, 1),
                    size_getter=indent(size_getter, 1),
                    array_field_size_getters=indent(array_field_size_getters, 1))


def run(input: argparse.FileType, output: argparse.FileType, namespace: Optional[str], include_header: List[str],
        using_namespace: List[str]):

    file = ast.File.from_json(json.load(input))
    core.desugar(file)

    include_header = '\n'.join([f'#include <{header}>' for header in include_header])
    using_namespace = '\n'.join([f'using namespace {namespace};' for namespace in using_namespace])
    open_namespace = f"namespace {namespace} {{" if namespace else ""
    close_namespace = f"}}  // {namespace}" if namespace else ""

    # Disable unsupported features in the canonical test suite.
    skipped_decls = [
        'Packet_Custom_Field_ConstantSize',
        'Packet_Custom_Field_VariableSize',
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Custom_Field_ConstantSize',
        'Struct_Custom_Field_VariableSize',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'Struct_Custom_Field_ConstantSize_',
        'Struct_Custom_Field_VariableSize_',
        'Struct_Checksum_Field_FromStart_',
        'Struct_Checksum_Field_FromEnd_',
        'PartialParent5',
        'PartialChild5_A',
        'PartialChild5_B',
        'PartialParent12',
        'PartialChild12_A',
        'PartialChild12_B',
    ]

    output.write(
        dedent("""\
        // File generated from {input_name}, with the command:
        //  {input_command}
        // /!\\ Do not edit by hand

        #pragma once

        #include <cstdint>
        #include <string>
        #include <utility>
        #include <vector>

        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        #ifndef ASSERT
        #include <cassert>
        #define ASSERT assert
        #endif  // !ASSERT

        {open_namespace}
        """).format(input_name=input.name,
                    input_command=' '.join(sys.argv),
                    include_header=include_header,
                    using_namespace=using_namespace,
                    open_namespace=open_namespace))

    for d in file.declarations:
        if d.id in skipped_decls:
            continue

        if isinstance(d, ast.EnumDeclaration):
            output.write(generate_enum_declaration(d))
            output.write(generate_enum_to_text(d))
        elif isinstance(d, ast.PacketDeclaration):
            output.write(generate_packet_view(d))
            output.write(generate_packet_builder(d))
        elif isinstance(d, ast.StructDeclaration):
            output.write(generate_struct_declaration(d))

    output.write(f"{close_namespace}\n")


def main() -> int:
    """Generate cxx PDL backend."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--namespace', type=str, help='Generated module namespace')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    return run(**vars(parser.parse_args()))


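# Example invocation (the script name is a placeholder; the flags are the ones
# declared above):
#   ./generate_cxx_backend.py --input packets.json --output packets.h \
#       --namespace my::ns --include-header my_header.h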
if __name__ == '__main__':
    sys.exit(main())