• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18from dataclasses import dataclass, field
19import json
20from pathlib import Path
21import sys
22from textwrap import dedent
23from typing import List, Tuple, Union, Optional
24
25from pdl import ast, core
26from pdl.utils import indent
27
28
def mask(width: int) -> str:
    """Return the hexadecimal literal for a bit mask of `width` ones."""
    all_ones = (1 << width) - 1
    return hex(all_ones)
31
32
def generate_prelude() -> str:
    """Return the static Python prelude emitted at the top of every
    generated file.

    The prelude defines a base `Packet` dataclass providing:
    - `parse_all`: parse a span and fail if bytes remain,
    - a `size` property stub overridden by generated subclasses,
    - `show`: a recursive tree pretty-printer for packet fields.
    """
    return dedent("""\
        from dataclasses import dataclass, field, fields
        from typing import Optional, List, Tuple
        import enum
        import inspect
        import math

        @dataclass
        class Packet:
            payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False)

            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
                packet, remain = getattr(cls, 'parse')(span)
                if len(remain) > 0:
                    raise Exception('Unexpected parsing remainder')
                return packet

            @property
            def size(self) -> int:
                pass

            def show(self, prefix: str = ''):
                print(f'{self.__class__.__name__}')

                def print_val(p: str, pp: str, name: str, align: int, typ, val):
                    if name == 'payload':
                        pass

                    # Scalar fields.
                    elif typ is int:
                        print(f'{p}{name:{align}} = {val} (0x{val:x})')

                    # Byte fields.
                    elif typ is bytes:
                        print(f'{p}{name:{align}} = [', end='')
                        line = ''
                        n_pp = ''
                        for (idx, b) in enumerate(val):
                            if idx > 0 and idx % 8 == 0:
                                print(f'{n_pp}{line}')
                                line = ''
                                n_pp = pp + (' ' * (align + 4))
                            line += f' {b:02x}'
                        print(f'{n_pp}{line} ]')

                    # Enum fields.
                    elif inspect.isclass(typ) and issubclass(typ, enum.IntEnum):
                        print(f'{p}{name:{align}} = {typ.__name__}::{val.name} (0x{val:x})')

                    # Struct fields.
                    elif inspect.isclass(typ) and issubclass(typ, globals().get('Packet')):
                        print(f'{p}{name:{align}} = ', end='')
                        val.show(prefix=pp)

                    # Array fields.
                    elif getattr(typ, '__origin__', None) == list:
                        print(f'{p}{name:{align}}')
                        last = len(val) - 1
                        align = 5
                        for (idx, elt) in enumerate(val):
                            n_p  = pp + ('├── ' if idx != last else '└── ')
                            n_pp = pp + ('│   ' if idx != last else '    ')
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])

                    # Custom fields.
                    elif inspect.isclass(typ):
                        print(f'{p}{name:{align}} = {repr(val)}')

                    else:
                        print(f'{p}{name:{align}} = ##{typ}##')

                last = len(fields(self)) - 1
                align = max(len(f.name) for f in fields(self) if f.name != 'payload')

                for (idx, f) in enumerate(fields(self)):
                    p  = prefix + ('├── ' if idx != last else '└── ')
                    pp = prefix + ('│   ' if idx != last else '    ')
                    val = getattr(self, f.name)

                    print_val(p, pp, f.name, align, f.type, val)
        """)
116
117
@dataclass
class FieldParser:
    """Incremental generator of the parsing code for the fields of one
    packet or struct declaration.

    Fields are visited in declaration order through parse(); the
    generated statements accumulate in `code`. Bit fields narrower than
    a byte are buffered in `chunk` until a byte boundary is reached,
    `offset` tracks how many bytes of the current span have been
    consumed, and `unchecked_code` holds statements that still need a
    span-size guard (emitted by check_code_)."""
    # Endianness name passed to int.from_bytes in the generated code.
    byteorder: str
    # Byte offset into the current span.
    offset: int = 0
    # Bit offset within the pending bit-field chunk.
    shift: int = 0
    # Pending (shift, width, field) bit fields awaiting a byte boundary.
    chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: [])
    # Generated statements not yet protected by a size check.
    unchecked_code: List[str] = field(default_factory=lambda: [])
    # Final generated parsing statements.
    code: List[str] = field(default_factory=lambda: [])

    def unchecked_append_(self, line: str):
        """Append unchecked field parsing code.
        The function check_size_ must be called to generate a size guard
        after parsing is completed."""
        self.unchecked_code.append(line)

    def append_(self, line: str):
        """Append field parsing code.
        There must be no unchecked code left before this function is called."""
        assert len(self.unchecked_code) == 0
        self.code.append(line)

    def check_size_(self, size: str):
        """Generate a check of the current span size."""
        self.append_(f"if len(span) < {size}:")
        self.append_(f"    raise Exception('Invalid packet size')")

    def check_code_(self):
        """Generate a size check for pending field parsing."""
        if len(self.unchecked_code) > 0:
            assert len(self.chunk) == 0
            unchecked_code = self.unchecked_code
            self.unchecked_code = []
            # The guard covers all bytes consumed so far.
            self.check_size_(str(self.offset))
            self.code.extend(unchecked_code)

    def consume_span_(self, keep: int = 0) -> None:
        """Skip consumed span bytes, optionally keeping the last `keep`
        bytes in the span (used for non-byte-aligned payloads)."""
        if self.offset > 0:
            self.check_code_()
            self.append_(f'span = span[{self.offset - keep}:]')
            self.offset = 0

    def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str):
        """Parse a single array field element of variable size."""
        if isinstance(field.type, ast.StructDeclaration):
            self.append_(f"    element, {span} = {field.type_id}.parse({span})")
            self.append_(f"    {field.id}.append(element)")
        else:
            raise Exception(f'Unexpected array element type {field.type_id} {field.width}')

    def parse_array_element_static_(self, field: ast.ArrayField, span: str):
        """Parse a single array field element of constant size."""
        if field.width is not None:
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            self.append_(f"    {field.id}.append({element})")
        elif isinstance(field.type, ast.EnumDeclaration):
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            element = f"{field.type_id}({element})"
            self.append_(f"    {field.id}.append({element})")
        else:
            element = f"{field.type_id}.parse_all({span})"
            self.append_(f"    {field.id}.append({element})")

    def parse_byte_array_field_(self, field: ast.ArrayField):
        """Parse the selected u8 array field."""
        array_size = core.get_array_field_size(field)
        padded_size = field.padded_size

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Derive the array size.
        if isinstance(array_size, int):
            size = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size - {field.size_modifier}' if field.size_modifier else f'{field.id}_size'
        elif isinstance(array_size, ast.CountField):
            size = f'{field.id}_count'
        else:
            size = None

        # Parse from the padded array if padding is present.
        if padded_size and size is not None:
            self.check_size_(padded_size)
            self.append_(f"if {size} > {padded_size}:")
            self.append_("    raise Exception('Array size is larger than the padding size')")
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{padded_size}:]")

        elif size is not None:
            self.check_size_(size)
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{size}:]")

        else:
            # Unknown size: the array consumes the remaining span.
            self.append_(f"fields['{field.id}'] = list(span)")
            self.append_(f"span = bytes()")

    def parse_array_field_(self, field: ast.ArrayField):
        """Parse the selected array field."""
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        padded_size = field.padded_size

        if element_width:
            if element_width % 8 != 0:
                raise Exception('Array element size is not a multiple of 8')
            element_width = int(element_width / 8)

        if isinstance(array_size, int):
            size = None
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
            count = None
        elif isinstance(array_size, ast.CountField):
            size = None
            count = f'{field.id}_count'
        else:
            size = None
            count = None

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier}")

        # Parse from the padded array if padding is present.
        if padded_size:
            self.check_size_(padded_size)
            self.append_(f"remaining_span = span[{padded_size}:]")
            self.append_(f"span = span[:{padded_size}]")

        # The element width is not known, but the array full octet size
        # is known by size field. Parse elements item by item as a vector.
        if element_width is None and size is not None:
            self.check_size_(size)
            self.append_(f"array_span = span[:{size}]")
            self.append_(f"{field.id} = []")
            self.append_("while len(array_span) > 0:")
            self.parse_array_element_dynamic_(field, 'array_span')
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{size}:]")

        # The element width is not known, but the array element count
        # is known statically or by count field.
        # Parse elements item by item as a vector.
        elif element_width is None and count is not None:
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # Neither the count nor size is known,
        # parse elements until the end of the span.
        elif element_width is None:
            self.append_(f"{field.id} = []")
            self.append_("while len(span) > 0:")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # The element width is known, and the array element count is known
        # statically, or by count field.
        elif count is not None:
            array_size = (f'{count}' if element_width == 1 else f'{count} * {element_width}')
            self.check_size_(array_size)
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{array_size}:]")

        # The element width is known, and the array full size is known
        # by size field, or unknown (in which case it is the remaining span
        # length).
        else:
            if size is not None:
                self.check_size_(size)
            array_size = size or 'len(span)'
            if element_width != 1:
                self.append_(f"if {array_size} % {element_width} != 0:")
                self.append_("    raise Exception('Array size is not a multiple of the element size')")
                self.append_(f"{field.id}_count = int({array_size} / {element_width})")
                array_count = f'{field.id}_count'
            else:
                array_count = array_size
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({array_count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            if size is not None:
                self.append_(f"span = span[{size}:]")
            else:
                self.append_(f"span = bytes()")

        # Drop the padding
        if padded_size:
            self.append_(f"span = remaining_span")

    def parse_bit_field_(self, field: ast.Field):
        """Parse the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are extracted together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        self.chunk.append((self.shift, width, field))
        self.shift += width

        # Wait for more fields if not on a byte boundary.
        if (self.shift % 8) != 0:
            return

        # Parse the backing integer using the configured endianness,
        # extract field values.
        size = int(self.shift / 8)
        end_offset = self.offset + size

        if size == 1:
            # Single-byte chunk: index directly, no conversion needed.
            value = f"span[{self.offset}]"
        else:
            span = f"span[{self.offset}:{end_offset}]"
            self.unchecked_append_(f"value_ = int.from_bytes({span}, byteorder='{self.byteorder}')")
            value = "value_"

        for shift, width, field in self.chunk:
            # Skip the shift-and-mask when the chunk is a single full field.
            v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

            if isinstance(field, ast.ScalarField):
                self.unchecked_append_(f"fields['{field.id}'] = {v}")
            elif isinstance(field, ast.FixedField) and field.enum_id:
                self.unchecked_append_(f"if {v} != {field.enum_id}.{field.tag_id}:")
                self.unchecked_append_(f"    raise Exception('Unexpected fixed field value')")
            elif isinstance(field, ast.FixedField):
                self.unchecked_append_(f"if {v} != {hex(field.value)}:")
                self.unchecked_append_(f"    raise Exception('Unexpected fixed field value')")
            elif isinstance(field, ast.TypedefField):
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}({v})")
            elif isinstance(field, ast.SizeField):
                self.unchecked_append_(f"{field.field_id}_size = {v}")
            elif isinstance(field, ast.CountField):
                self.unchecked_append_(f"{field.field_id}_count = {v}")
            elif isinstance(field, ast.ReservedField):
                pass
            else:
                raise Exception(f'Unsupported bit field type {field.kind}')

        # Reset state.
        self.offset = end_offset
        self.shift = 0
        self.chunk = []

    def parse_typedef_field_(self, field: ast.TypedefField):
        """Parse a typedef field, to the exclusion of Enum fields."""

        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        width = core.get_declaration_size(field.type)
        if width is None:
            # Variable-size type: delegate to its parse() method, which
            # returns the remaining span.
            self.consume_span_()
            self.append_(f"{field.id}, span = {field.type_id}.parse(span)")
            self.append_(f"fields['{field.id}'] = {field.id}")
        else:
            if width % 8 != 0:
                raise Exception('Typedef field type size is not a multiple of 8')
            width = int(width / 8)
            end_offset = self.offset + width
            # Checksum value field is generated alongside checksum start.
            # Deal with this field as padding.
            if not isinstance(field.type, ast.ChecksumDeclaration):
                span = f'span[{self.offset}:{end_offset}]'
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.parse_all({span})")
            self.offset = end_offset

    def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields."""

        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)

        # If the payload is not byte aligned, do parse the bit fields
        # that can be extracted, but do not consume the input bytes as
        # they will also be included in the payload span.
        if self.shift != 0:
            if payload_size:
                raise Exception("Unexpected payload size for non byte aligned payload")

            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.parse_bit_field_(core.make_reserved_field(padding_bits))
            self.consume_span_(rounded_size)
        else:
            self.consume_span_()

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            if getattr(field, 'size_modifier', None):
                self.append_(f"{field.id}_size -= {field.size_modifier}")
            self.check_size_(f'{field.id}_size')
            self.append_(f"payload = span[:{field.id}_size]")
            self.append_(f"span = span[{field.id}_size:]")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_(f"payload = span")
            self.append_(f"span = bytes([])")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end is not None:
            if (offset_from_end % 8) != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end = int(offset_from_end / 8)
            self.check_size_(f'{offset_from_end}')
            self.append_(f"payload = span[:-{offset_from_end}]")
            self.append_(f"span = span[-{offset_from_end}:]")
        self.append_(f"fields['payload'] = payload")

    def parse_checksum_field_(self, field: ast.ChecksumField):
        """Generate a checksum check."""

        # The checksum value field can be read starting from the current
        # offset if the fields in between are of fixed size, or from the end
        # of the span otherwise.
        self.consume_span_()
        value_field = core.get_packet_field(field.parent, field.field_id)
        offset_from_start = 0
        offset_from_end = 0
        start_index = field.parent.fields.index(field)
        value_index = field.parent.fields.index(value_field)
        value_size = int(core.get_field_size(value_field) / 8)

        # Accumulate the bit offset of the value field from the checksum
        # start; give up (None) on the first variable-size field.
        for f in field.parent.fields[start_index + 1:value_index]:
            size = core.get_field_size(f)
            if size is None:
                offset_from_start = None
                break
            else:
                offset_from_start += size

        # Same accumulation, but measured backwards from the packet end.
        trailing_fields = field.parent.fields[value_index:]
        trailing_fields.reverse()
        for f in trailing_fields:
            size = core.get_field_size(f)
            if size is None:
                offset_from_end = None
                break
            else:
                offset_from_end += size

        if offset_from_start is not None:
            if offset_from_start % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_start = int(offset_from_start / 8)
            checksum_span = f'span[:{offset_from_start}]'
            if value_size > 1:
                start = offset_from_start
                end = offset_from_start + value_size
                value = f"int.from_bytes(span[{start}:{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[{offset_from_start}]'
            self.check_size_(offset_from_start + value_size)

        elif offset_from_end is not None:
            # NOTE(review): `sign` is assigned but never used — confirm it is
            # dead and can be removed upstream.
            sign = ''
            if offset_from_end % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_end = int(offset_from_end / 8)
            checksum_span = f'span[:-{offset_from_end}]'
            if value_size > 1:
                start = offset_from_end
                end = offset_from_end - value_size
                # NOTE(review): if the checksum value is the last packet field,
                # end == 0 and the generated slice span[-start:-0] is empty —
                # verify callers guarantee trailing fields after the value.
                value = f"int.from_bytes(span[-{start}:-{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[-{offset_from_end}]'
            self.check_size_(offset_from_end)

        else:
            raise Exception('Checksum value field cannot be read at constant offset')

        self.append_(f"{value_field.id} = {value}")
        self.append_(f"fields['{value_field.id}'] = {value_field.id}")
        self.append_(f"computed_{value_field.id} = {value_field.type.function}({checksum_span})")
        self.append_(f"if computed_{value_field.id} != {value_field.id}:")
        self.append_("    raise Exception(f'Invalid checksum computation:" +
                     f" {{computed_{value_field.id}}} != {{{value_field.id}}}')")

    def parse(self, field: ast.Field):
        """Generate the parsing code for one field, dispatching on the
        field kind."""
        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        if core.is_bit_field(field):
            self.parse_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField) and field.width == 8:
            self.parse_byte_array_field_(field)

        elif isinstance(field, ast.ArrayField):
            self.parse_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.parse_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.parse_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.parse_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')

    def done(self):
        """Finish parsing: flush pending size checks and span shifts."""
        self.consume_span_()
548
549
550@dataclass
551class FieldSerializer:
552    byteorder: str
553    shift: int = 0
554    value: List[str] = field(default_factory=lambda: [])
555    code: List[str] = field(default_factory=lambda: [])
556    indent: int = 0
557
558    def indent_(self):
559        self.indent += 1
560
561    def unindent_(self):
562        self.indent -= 1
563
564    def append_(self, line: str):
565        """Append field serializing code."""
566        lines = line.split('\n')
567        self.code.extend(['    ' * self.indent + line for line in lines])
568
569    def extend_(self, value: str, length: int):
570        """Append data to the span being constructed."""
571        if length == 1:
572            self.append_(f"_span.append({value})")
573        else:
574            self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))")
575
576    def serialize_array_element_(self, field: ast.ArrayField):
577        """Serialize a single array field element."""
578        if field.width is not None:
579            length = int(field.width / 8)
580            self.extend_('_elt', length)
581        elif isinstance(field.type, ast.EnumDeclaration):
582            length = int(field.type.width / 8)
583            self.extend_('_elt', length)
584        else:
585            self.append_("_span.extend(_elt.serialize())")
586
587    def serialize_array_field_(self, field: ast.ArrayField):
588        """Serialize the selected array field."""
589        if field.padded_size:
590            self.append_(f"_{field.id}_start = len(_span)")
591
592        if field.width == 8:
593            self.append_(f"_span.extend(self.{field.id})")
594        else:
595            self.append_(f"for _elt in self.{field.id}:")
596            self.indent_()
597            self.serialize_array_element_(field)
598            self.unindent_()
599
600        if field.padded_size:
601            self.append_(f"_span.extend([0] * ({field.padded_size} - len(_span) + _{field.id}_start))")
602
    def serialize_bit_field_(self, field: ast.Field):
        """Serialize the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are serialized together."""

        # Add to current chunk.
        # NOTE(review): despite the ast.Field annotation, a plain str is
        # also accepted (first branch below) — confirm what
        # core.get_field_size returns for that case.
        width = core.get_field_size(field)
        shift = self.shift

        if isinstance(field, str):
            # Pre-rendered expression: shifted into place as-is.
            self.value.append(f"({field} << {shift})")
        elif isinstance(field, ast.ScalarField):
            # Truncate out-of-range values with a warning instead of failing.
            max_value = (1 << field.width) - 1
            self.append_(f"if self.{field.id} > {max_value}:")
            self.append_(f"    print(f\"Invalid value for field {field.parent.id}::{field.id}:" +
                         f" {{self.{field.id}}} > {max_value}; the value will be truncated\")")
            self.append_(f"    self.{field.id} &= {max_value}")
            self.value.append(f"(self.{field.id} << {shift})")
        elif isinstance(field, ast.FixedField) and field.enum_id:
            self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})")
        elif isinstance(field, ast.FixedField):
            self.value.append(f"({field.value} << {shift})")
        elif isinstance(field, ast.TypedefField):
            self.value.append(f"(self.{field.id} << {shift})")

        elif isinstance(field, ast.SizeField):
            # The size value is computed from the field it describes.
            max_size = (1 << field.width) - 1
            value_field = core.get_packet_field(field.parent, field.field_id)
            size_modifier = ''

            if getattr(value_field, 'size_modifier', None):
                size_modifier = f' + {value_field.size_modifier}'

            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
                self.append_(f"_payload_size = len(payload or self.payload or []){size_modifier}")
                self.append_(f"if _payload_size > {max_size}:")
                self.append_(f"    print(f\"Invalid length for payload field:" +
                             f"  {{_payload_size}} > {max_size}; the packet cannot be generated\")")
                self.append_(f"    raise Exception(\"Invalid payload length\")")
                array_size = "_payload_size"
            elif isinstance(value_field, ast.ArrayField) and value_field.width:
                array_size = f"(len(self.{value_field.id}) * {int(value_field.width / 8)}{size_modifier})"
            elif isinstance(value_field, ast.ArrayField) and isinstance(value_field.type, ast.EnumDeclaration):
                array_size = f"(len(self.{value_field.id}) * {int(value_field.type.width / 8)}{size_modifier})"
            elif isinstance(value_field, ast.ArrayField):
                # Variable-size elements: sum each element's size at runtime.
                self.append_(
                    f"_{value_field.id}_size = sum([elt.size for elt in self.{value_field.id}]){size_modifier}")
                array_size = f"_{value_field.id}_size"
            else:
                raise Exception("Unsupported field type")
            self.value.append(f"({array_size} << {shift})")

        elif isinstance(field, ast.CountField):
            # Truncate over-long arrays with a warning instead of failing.
            max_count = (1 << field.width) - 1
            self.append_(f"if len(self.{field.field_id}) > {max_count}:")
            self.append_(f"    print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" +
                         f"  {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")")
            self.append_(f"    del self.{field.field_id}[{max_count}:]")
            self.value.append(f"(len(self.{field.field_id}) << {shift})")
        elif isinstance(field, ast.ReservedField):
            # Reserved bits are left at zero.
            pass
        else:
            raise Exception(f'Unsupported bit field type {field.kind}')

        # Check if a byte boundary is reached.
        self.shift += width
        if (self.shift % 8) == 0:
            self.pack_bit_fields_()
671
672    def pack_bit_fields_(self):
673        """Pack serialized bit fields."""
674
675        # Should have an integral number of bytes now.
676        assert (self.shift % 8) == 0
677
678        # Generate the backing integer, and serialize it
679        # using the configured endiannes,
680        size = int(self.shift / 8)
681
682        if len(self.value) == 0:
683            self.append_(f"_span.extend([0] * {size})")
684        elif len(self.value) == 1:
685            self.extend_(self.value[0], size)
686        else:
687            self.append_(f"_value = (")
688            self.append_("    " + " |\n    ".join(self.value))
689            self.append_(")")
690            self.extend_('_value', size)
691
692        # Reset state.
693        self.shift = 0
694        self.value = []
695
696    def serialize_typedef_field_(self, field: ast.TypedefField):
697        """Serialize a typedef field, to the exclusion of Enum fields."""
698
699        if self.shift != 0:
700            raise Exception('Typedef field does not start on an octet boundary')
701        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
702            raise Exception('Derived struct used in typedef field')
703
704        if isinstance(field.type, ast.ChecksumDeclaration):
705            size = int(field.type.width / 8)
706            self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])")
707            self.extend_('_checksum', size)
708        else:
709            self.append_(f"_span.extend(self.{field.id}.serialize())")
710
711    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
712        """Serialize body and payload fields."""
713
714        if self.shift != 0 and self.byteorder == 'big':
715            raise Exception('Payload field does not start on an octet boundary')
716
717        if self.shift == 0:
718            self.append_(f"_span.extend(payload or self.payload or [])")
719        else:
720            # Supported case of packet inheritance;
721            # the incomplete fields are serialized into
722            # the payload, rather than separately.
723            # First extract the padding bits from the payload,
724            # then recombine them with the bit fields to be serialized.
725            rounded_size = int((self.shift + 7) / 8)
726            padding_bits = 8 * rounded_size - self.shift
727            self.append_(f"_payload = payload or self.payload or bytes()")
728            self.append_(f"if len(_payload) < {rounded_size}:")
729            self.append_(f"    raise Exception(f\"Invalid length for payload field:" +
730                         f"  {{len(_payload)}} < {rounded_size}\")")
731            self.append_(
732                f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}")
733            self.value.append(f"(_padding << {self.shift})")
734            self.shift += padding_bits
735            self.pack_bit_fields_()
736            self.append_(f"_span.extend(_payload[{rounded_size}:])")
737
    def serialize_checksum_field_(self, field: ast.ChecksumField):
        """Generate a checksum check."""

        # Record the current span offset; the matching checksum typedef
        # field later computes its checksum over _span from this point on.
        self.append_("_checksum_start = len(_span)")
742
743    def serialize(self, field: ast.Field):
744        # Field has bit granularity.
745        # Append the field to the current chunk,
746        # check if a byte boundary was reached.
747        if core.is_bit_field(field):
748            self.serialize_bit_field_(field)
749
750        # Padding fields.
751        elif isinstance(field, ast.PaddingField):
752            pass
753
754        # Array fields.
755        elif isinstance(field, ast.ArrayField):
756            self.serialize_array_field_(field)
757
758        # Other typedef fields.
759        elif isinstance(field, ast.TypedefField):
760            self.serialize_typedef_field_(field)
761
762        # Payload and body fields.
763        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
764            self.serialize_payload_field_(field)
765
766        # Checksum fields.
767        elif isinstance(field, ast.ChecksumField):
768            self.serialize_checksum_field_(field)
769
770        else:
771            raise Exception(f'Unimplemented field type {field.kind}')
772
773
def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Generate the serialize() function for a toplevel Packet or Struct
       declaration."""

    serializer = FieldSerializer(byteorder=packet.file.byteorder)
    for field in packet.fields:
        serializer.serialize(field)
    return ['_span = bytearray()', *serializer.code, 'return bytes(_span)']
782
783
def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Generate the serialize() function for a derived Packet or Struct
       declaration."""

    packet_shift = core.get_packet_shift(packet)
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    serializer = FieldSerializer(byteorder=packet.file.byteorder, shift=packet_shift)
    for field in packet.fields:
        serializer.serialize(field)

    # Derived packets serialize themselves into the parent's payload.
    epilogue = f'return {packet.parent.id}.serialize(self, payload = bytes(_span))'
    return ['_span = bytearray()', *serializer.code, epilogue]
797
798
def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    """Generate the parse() function for a toplevel Packet or Struct
       declaration."""

    packet_shift = core.get_packet_shift(packet)
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    # Convert the packet constraints to a boolean expression.
    validation = []
    if packet.constraints:
        cond = [(f"fields['{c.id}'] != {hex(c.value)}" if c.value is not None else
                 f"fields['{c.id}'] != {core.get_packet_field(packet, c.id).type_id}.{c.tag_id}")
                for c in packet.constraints]
        validation = [f"if {' or '.join(cond)}:", "    raise Exception(\"Invalid constraint field values\")"]

    # Parse fields iteratively.
    parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift)
    for field in packet.fields:
        parser.parse(field)
    parser.done()

    # Specialize to child packets.
    children = core.get_derived_packets(packet)
    decl = [] if packet.parent_id else ['fields = {\'payload\': None}']

    # Try parsing every child packet successively until one is
    # successfully parsed. Return a parsing error if none is valid.
    # Return parent packet if no child packet matches.
    # TODO: order child packets by decreasing size in case no constraint
    # is given for specialization.
    specialization = []
    for _, child in children:
        specialization.extend([
            "try:",
            f"    return {child.id}.parse(fields.copy(), payload)",
            "except Exception as exn:",
            "    pass",
        ])

    return decl + validation + parser.code + specialization + [f"return {packet.id}(**fields), span"]
844
845
def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
    """Generate the body of the size property, which computes the
    serialized size in bytes of the packet or struct.

    Fixed-size fields are folded into a single constant; dynamically
    sized fields each contribute a runtime expression.
    """

    constant_width = 0      # Accumulated fixed width, in bits.
    variable_width = []     # Runtime size expressions, in bytes.
    for f in packet.fields:
        field_size = core.get_field_size(f)
        if field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("len(self.payload)")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"self.{f.id}.size")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
            variable_width.append(f"sum([elt.size for elt in self.{f.id}])")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration):
            variable_width.append(f"len(self.{f.id}) * {f.type.width}")
        elif isinstance(f, ast.ArrayField):
            # Scalar array: element width is declared in bits; floor
            # division avoids the float round-trip of int(width / 8).
            variable_width.append(f"len(self.{f.id}) * {f.width // 8}")
        else:
            raise Exception("Unsupported field type")

    # Convert the accumulated constant width from bits to bytes.
    constant_width //= 8
    if len(variable_width) == 0:
        return [f"return {constant_width}"]
    elif len(variable_width) == 1 and constant_width:
        return [f"return {variable_width[0]} + {constant_width}"]
    elif len(variable_width) == 1:
        return [f"return {variable_width[0]}"]
    elif len(variable_width) > 1 and constant_width:
        return ([f"return {constant_width} + ("] + " +\n    ".join(variable_width).split("\n") + [")"])
    elif len(variable_width) > 1:
        return (["return ("] + " +\n    ".join(variable_width).split("\n") + [")"])
    else:
        assert False
879
880
def generate_packet_post_init(decl: ast.Declaration) -> List[str]:
    """Generate __post_init__ function to set constraint field values."""

    # Gather the constraints declared along the parent chain.
    constraints = []
    current = decl
    while current.parent_id:
        constraints.extend(current.constraints)
        current = current.parent

    if not constraints:
        return ["pass"]

    # Scalar constraints assign the literal value; typedef constraints
    # assign the matching enum tag.
    return [(f"self.{c.id} = {c.value}" if c.value is not None else
             f"self.{c.id} = {core.get_packet_field(decl, c.id).type_id}.{c.tag_id}")
            for c in constraints]
903
904
def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Generate the implementation of an enum type."""

    # One tag declaration per enum value, rendered in hexadecimal.
    tag_decls = [f"{tag.id} = {hex(tag.value)}" for tag in decl.tags]

    return dedent("""\

        class {enum_name}(enum.IntEnum):
            {tag_decls}
        """).format(enum_name=decl.id, tag_decls=indent(tag_decls, 1))
918
919
def generate_packet_declaration(packet: ast.Declaration) -> str:
    """Generate the implementation of a toplevel Packet or Struct
       declaration."""

    packet_name = packet.id
    # Dataclass member declarations. Field kinds without a direct member
    # (reserved, size, count, payload, ...) are intentionally skipped:
    # their values are recomputed during serialization.
    field_decls = []
    for f in packet.fields:
        if isinstance(f, ast.ScalarField):
            field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)")
        elif isinstance(f, ast.TypedefField):
            if isinstance(f.type, ast.EnumDeclaration):
                # Enum members default to the first declared tag.
                field_decls.append(
                    f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type_id}.{f.type.tags[0].id})")
            elif isinstance(f.type, ast.ChecksumDeclaration):
                field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)")
            elif isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
                field_decls.append(f"{f.id}: {f.type_id} = field(kw_only=True, default_factory={f.type_id})")
            else:
                raise Exception("Unsupported typedef field type")
        elif isinstance(f, ast.ArrayField) and f.width == 8:
            # Arrays of 8-bit scalars are exposed as bytearray members.
            field_decls.append(f"{f.id}: bytearray = field(kw_only=True, default_factory=bytearray)")
        elif isinstance(f, ast.ArrayField) and f.width:
            field_decls.append(f"{f.id}: List[int] = field(kw_only=True, default_factory=list)")
        elif isinstance(f, ast.ArrayField) and f.type_id:
            field_decls.append(f"{f.id}: List[{f.type_id}] = field(kw_only=True, default_factory=list)")

    # Derived packets inherit from their parent declaration, serialize
    # into the parent payload, and take the pre-parsed parent fields as an
    # extra parse() argument.
    if packet.parent_id:
        parent_name = packet.parent_id
        parent_fields = 'fields: dict, '
        serializer = generate_derived_packet_serializer(packet)
    else:
        parent_name = 'Packet'
        parent_fields = ''
        serializer = generate_toplevel_packet_serializer(packet)

    parser = generate_packet_parser(packet)
    size = generate_packet_size_getter(packet)
    post_init = generate_packet_post_init(packet)

    # Assemble the class template; generated bodies are re-indented to
    # their position in the class.
    return dedent("""\

        @dataclass
        class {packet_name}({parent_name}):
            {field_decls}

            def __post_init__(self):
                {post_init}

            @staticmethod
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}

            def serialize(self, payload: bytes = None) -> bytes:
                {serializer}

            @property
            def size(self) -> int:
                {size}
        """).format(packet_name=packet_name,
                    parent_name=parent_name,
                    parent_fields=parent_fields,
                    field_decls=indent(field_decls, 1),
                    post_init=indent(post_init, 2),
                    parser=indent(parser, 2),
                    serializer=indent(serializer, 2),
                    size=indent(size, 2))
986
987
def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str:
    """Generate the code to validate a user custom field implementation.

    The returned snippet is executed when the generated module is loaded,
    so that the user gets an immediate and clear error message when the
    provided custom type does not implement the expected parse methods.
    """
    template = dedent("""\

        if (not callable(getattr({custom_field_name}, 'parse', None)) or
            not callable(getattr({custom_field_name}, 'parse_all', None))):
            raise Exception('The custom field type {custom_field_name} does not implement the parse method')
    """)
    return template.format(custom_field_name=decl.id)
1001
1002
def generate_checksum_declaration_check(decl: ast.ChecksumDeclaration) -> str:
    """Generate the code to validate a user checksum field implementation.

    The returned snippet is executed when the generated module is loaded,
    so that the user gets an immediate and clear error message when the
    provided checksum function does not fit the expected template.
    """
    template = dedent("""\

        if not callable({checksum_name}):
            raise Exception('{checksum_name} is not callable')
    """)
    return template.format(checksum_name=decl.id)
1015
1016
def run(input: argparse.FileType, output: argparse.FileType, custom_type_location: Optional[str]):
    """Load the PDL-JSON AST from `input` and write the generated Python
    backend module to `output`."""

    file = ast.File.from_json(json.load(input))
    core.desugar(file)

    # Collect user-provided custom field and checksum types, together with
    # the load-time validation snippet emitted for each.
    custom_types = []
    custom_type_checks = ""
    check_generators = (
        (ast.CustomFieldDeclaration, generate_custom_field_declaration_check),
        (ast.ChecksumDeclaration, generate_checksum_declaration_check),
    )
    for d in file.declarations:
        for decl_type, generate_check in check_generators:
            if isinstance(d, decl_type):
                custom_types.append(d.id)
                custom_type_checks += generate_check(d)
                break

    output.write(f"# File generated from {input.name}, with the command:\n")
    output.write(f"#  {' '.join(sys.argv)}\n")
    output.write("# /!\\ Do not edit by hand.\n")
    if custom_types and custom_type_location:
        output.write(f"\nfrom {custom_type_location} import {', '.join(custom_types)}\n")
    output.write(generate_prelude())
    output.write(custom_type_checks)

    # Emit one definition per enum, packet, or struct declaration.
    for d in file.declarations:
        if isinstance(d, ast.EnumDeclaration):
            generator = generate_enum_declaration
        elif isinstance(d, (ast.PacketDeclaration, ast.StructDeclaration)):
            generator = generate_packet_declaration
        else:
            continue
        output.write(generator(d))
1044
1045
def main() -> int:
    """Generate python PDL backend."""
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('--input',
                           type=argparse.FileType('r'),
                           default=sys.stdin,
                           help='Input PDL-JSON source')
    argparser.add_argument('--output',
                           type=argparse.FileType('w'),
                           default=sys.stdout,
                           help='Output Python file')
    argparser.add_argument('--custom-type-location',
                           type=str,
                           required=False,
                           help='Module of declaration of custom types')
    return run(**vars(argparser.parse_args()))
1056
1057
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == '__main__':
    sys.exit(main())
1060