1#!/usr/bin/env python3 2 3# Copyright 2023 Google LLC 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# https://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17import argparse 18from dataclasses import dataclass, field 19import json 20from pathlib import Path 21import sys 22from textwrap import dedent 23from typing import List, Tuple, Union, Optional 24 25from pdl import ast, core 26from pdl.utils import indent 27 28 29def mask(width: int) -> str: 30 return hex((1 << width) - 1) 31 32 33def generate_prelude() -> str: 34 return dedent("""\ 35 from dataclasses import dataclass, field, fields 36 from typing import Optional, List, Tuple 37 import enum 38 import inspect 39 import math 40 41 @dataclass 42 class Packet: 43 payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False) 44 45 @classmethod 46 def parse_all(cls, span: bytes) -> 'Packet': 47 packet, remain = getattr(cls, 'parse')(span) 48 if len(remain) > 0: 49 raise Exception('Unexpected parsing remainder') 50 return packet 51 52 @property 53 def size(self) -> int: 54 pass 55 56 def show(self, prefix: str = ''): 57 print(f'{self.__class__.__name__}') 58 59 def print_val(p: str, pp: str, name: str, align: int, typ, val): 60 if name == 'payload': 61 pass 62 63 # Scalar fields. 64 elif typ is int: 65 print(f'{p}{name:{align}} = {val} (0x{val:x})') 66 67 # Byte fields. 68 elif typ is bytes: 69 print(f'{p}{name:{align}} = [', end='') 70 line = '' 71 n_pp = '' 72 for (idx, b) in enumerate(val): 73 if idx > 0 and idx % 8 == 0: 74 print(f'{n_pp}{line}') 75 line = '' 76 n_pp = pp + (' ' * (align + 4)) 77 line += f' {b:02x}' 78 print(f'{n_pp}{line} ]') 79 80 # Enum fields. 81 elif inspect.isclass(typ) and issubclass(typ, enum.IntEnum): 82 print(f'{p}{name:{align}} = {typ.__name__}::{val.name} (0x{val:x})') 83 84 # Struct fields. 85 elif inspect.isclass(typ) and issubclass(typ, globals().get('Packet')): 86 print(f'{p}{name:{align}} = ', end='') 87 val.show(prefix=pp) 88 89 # Array fields. 90 elif getattr(typ, '__origin__', None) == list: 91 print(f'{p}{name:{align}}') 92 last = len(val) - 1 93 align = 5 94 for (idx, elt) in enumerate(val): 95 n_p = pp + ('├── ' if idx != last else '└── ') 96 n_pp = pp + ('│ ' if idx != last else ' ') 97 print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx]) 98 99 # Custom fields. 100 elif inspect.isclass(typ): 101 print(f'{p}{name:{align}} = {repr(val)}') 102 103 else: 104 print(f'{p}{name:{align}} = ##{typ}##') 105 106 last = len(fields(self)) - 1 107 align = max(len(f.name) for f in fields(self) if f.name != 'payload') 108 109 for (idx, f) in enumerate(fields(self)): 110 p = prefix + ('├── ' if idx != last else '└── ') 111 pp = prefix + ('│ ' if idx != last else ' ') 112 val = getattr(self, f.name) 113 114 print_val(p, pp, f.name, align, f.type, val) 115 """) 116 117 118@dataclass 119class FieldParser: 120 byteorder: str 121 offset: int = 0 122 shift: int = 0 123 chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: []) 124 unchecked_code: List[str] = field(default_factory=lambda: []) 125 code: List[str] = field(default_factory=lambda: []) 126 127 def unchecked_append_(self, line: str): 128 """Append unchecked field parsing code. 129 The function check_size_ must be called to generate a size guard 130 after parsing is completed.""" 131 self.unchecked_code.append(line) 132 133 def append_(self, line: str): 134 """Append field parsing code. 135 There must be no unchecked code left before this function is called.""" 136 assert len(self.unchecked_code) == 0 137 self.code.append(line) 138 139 def check_size_(self, size: str): 140 """Generate a check of the current span size.""" 141 self.append_(f"if len(span) < {size}:") 142 self.append_(f" raise Exception('Invalid packet size')") 143 144 def check_code_(self): 145 """Generate a size check for pending field parsing.""" 146 if len(self.unchecked_code) > 0: 147 assert len(self.chunk) == 0 148 unchecked_code = self.unchecked_code 149 self.unchecked_code = [] 150 self.check_size_(str(self.offset)) 151 self.code.extend(unchecked_code) 152 153 def consume_span_(self, keep: int = 0) -> str: 154 """Skip consumed span bytes.""" 155 if self.offset > 0: 156 self.check_code_() 157 self.append_(f'span = span[{self.offset - keep}:]') 158 self.offset = 0 159 160 def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str): 161 """Parse a single array field element of variable size.""" 162 if isinstance(field.type, ast.StructDeclaration): 163 self.append_(f" element, {span} = {field.type_id}.parse({span})") 164 self.append_(f" {field.id}.append(element)") 165 else: 166 raise Exception(f'Unexpected array element type {field.type_id} {field.width}') 167 168 def parse_array_element_static_(self, field: ast.ArrayField, span: str): 169 """Parse a single array field element of constant size.""" 170 if field.width is not None: 171 element = f"int.from_bytes({span}, byteorder='{self.byteorder}')" 172 self.append_(f" {field.id}.append({element})") 173 elif isinstance(field.type, ast.EnumDeclaration): 174 element = f"int.from_bytes({span}, byteorder='{self.byteorder}')" 175 element = f"{field.type_id}({element})" 176 self.append_(f" {field.id}.append({element})") 177 else: 178 element = f"{field.type_id}.parse_all({span})" 179 self.append_(f" {field.id}.append({element})") 180 181 def parse_byte_array_field_(self, field: ast.ArrayField): 182 """Parse the selected u8 array field.""" 183 array_size = core.get_array_field_size(field) 184 padded_size = field.padded_size 185 186 # Shift the span to reset the offset to 0. 187 self.consume_span_() 188 189 # Derive the array size. 190 if isinstance(array_size, int): 191 size = array_size 192 elif isinstance(array_size, ast.SizeField): 193 size = f'{field.id}_size - {field.size_modifier}' if field.size_modifier else f'{field.id}_size' 194 elif isinstance(array_size, ast.CountField): 195 size = f'{field.id}_count' 196 else: 197 size = None 198 199 # Parse from the padded array if padding is present. 200 if padded_size and size is not None: 201 self.check_size_(padded_size) 202 self.append_(f"if {size} > {padded_size}:") 203 self.append_(" raise Exception('Array size is larger than the padding size')") 204 self.append_(f"fields['{field.id}'] = list(span[:{size}])") 205 self.append_(f"span = span[{padded_size}:]") 206 207 elif size is not None: 208 self.check_size_(size) 209 self.append_(f"fields['{field.id}'] = list(span[:{size}])") 210 self.append_(f"span = span[{size}:]") 211 212 else: 213 self.append_(f"fields['{field.id}'] = list(span)") 214 self.append_(f"span = bytes()") 215 216 def parse_array_field_(self, field: ast.ArrayField): 217 """Parse the selected array field.""" 218 array_size = core.get_array_field_size(field) 219 element_width = core.get_array_element_size(field) 220 padded_size = field.padded_size 221 222 if element_width: 223 if element_width % 8 != 0: 224 raise Exception('Array element size is not a multiple of 8') 225 element_width = int(element_width / 8) 226 227 if isinstance(array_size, int): 228 size = None 229 count = array_size 230 elif isinstance(array_size, ast.SizeField): 231 size = f'{field.id}_size' 232 count = None 233 elif isinstance(array_size, ast.CountField): 234 size = None 235 count = f'{field.id}_count' 236 else: 237 size = None 238 count = None 239 240 # Shift the span to reset the offset to 0. 241 self.consume_span_() 242 243 # Apply the size modifier. 244 if field.size_modifier and size: 245 self.append_(f"{size} = {size} - {field.size_modifier}") 246 247 # Parse from the padded array if padding is present. 248 if padded_size: 249 self.check_size_(padded_size) 250 self.append_(f"remaining_span = span[{padded_size}:]") 251 self.append_(f"span = span[:{padded_size}]") 252 253 # The element width is not known, but the array full octet size 254 # is known by size field. Parse elements item by item as a vector. 255 if element_width is None and size is not None: 256 self.check_size_(size) 257 self.append_(f"array_span = span[:{size}]") 258 self.append_(f"{field.id} = []") 259 self.append_("while len(array_span) > 0:") 260 self.parse_array_element_dynamic_(field, 'array_span') 261 self.append_(f"fields['{field.id}'] = {field.id}") 262 self.append_(f"span = span[{size}:]") 263 264 # The element width is not known, but the array element count 265 # is known statically or by count field. 266 # Parse elements item by item as a vector. 267 elif element_width is None and count is not None: 268 self.append_(f"{field.id} = []") 269 self.append_(f"for n in range({count}):") 270 self.parse_array_element_dynamic_(field, 'span') 271 self.append_(f"fields['{field.id}'] = {field.id}") 272 273 # Neither the count not size is known, 274 # parse elements until the end of the span. 275 elif element_width is None: 276 self.append_(f"{field.id} = []") 277 self.append_("while len(span) > 0:") 278 self.parse_array_element_dynamic_(field, 'span') 279 self.append_(f"fields['{field.id}'] = {field.id}") 280 281 # The element width is known, and the array element count is known 282 # statically, or by count field. 283 elif count is not None: 284 array_size = (f'{count}' if element_width == 1 else f'{count} * {element_width}') 285 self.check_size_(array_size) 286 self.append_(f"{field.id} = []") 287 self.append_(f"for n in range({count}):") 288 span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]') 289 self.parse_array_element_static_(field, span) 290 self.append_(f"fields['{field.id}'] = {field.id}") 291 self.append_(f"span = span[{array_size}:]") 292 293 # The element width is known, and the array full size is known 294 # by size field, or unknown (in which case it is the remaining span 295 # length). 296 else: 297 if size is not None: 298 self.check_size_(size) 299 array_size = size or 'len(span)' 300 if element_width != 1: 301 self.append_(f"if {array_size} % {element_width} != 0:") 302 self.append_(" raise Exception('Array size is not a multiple of the element size')") 303 self.append_(f"{field.id}_count = int({array_size} / {element_width})") 304 array_count = f'{field.id}_count' 305 else: 306 array_count = array_size 307 self.append_(f"{field.id} = []") 308 self.append_(f"for n in range({array_count}):") 309 span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]') 310 self.parse_array_element_static_(field, span) 311 self.append_(f"fields['{field.id}'] = {field.id}") 312 if size is not None: 313 self.append_(f"span = span[{size}:]") 314 else: 315 self.append_(f"span = bytes()") 316 317 # Drop the padding 318 if padded_size: 319 self.append_(f"span = remaining_span") 320 321 def parse_bit_field_(self, field: ast.Field): 322 """Parse the selected field as a bit field. 323 The field is added to the current chunk. When a byte boundary 324 is reached all saved fields are extracted together.""" 325 326 # Add to current chunk. 327 width = core.get_field_size(field) 328 self.chunk.append((self.shift, width, field)) 329 self.shift += width 330 331 # Wait for more fields if not on a byte boundary. 332 if (self.shift % 8) != 0: 333 return 334 335 # Parse the backing integer using the configured endiannes, 336 # extract field values. 337 size = int(self.shift / 8) 338 end_offset = self.offset + size 339 340 if size == 1: 341 value = f"span[{self.offset}]" 342 else: 343 span = f"span[{self.offset}:{end_offset}]" 344 self.unchecked_append_(f"value_ = int.from_bytes({span}, byteorder='{self.byteorder}')") 345 value = "value_" 346 347 for shift, width, field in self.chunk: 348 v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}") 349 350 if isinstance(field, ast.ScalarField): 351 self.unchecked_append_(f"fields['{field.id}'] = {v}") 352 elif isinstance(field, ast.FixedField) and field.enum_id: 353 self.unchecked_append_(f"if {v} != {field.enum_id}.{field.tag_id}:") 354 self.unchecked_append_(f" raise Exception('Unexpected fixed field value')") 355 elif isinstance(field, ast.FixedField): 356 self.unchecked_append_(f"if {v} != {hex(field.value)}:") 357 self.unchecked_append_(f" raise Exception('Unexpected fixed field value')") 358 elif isinstance(field, ast.TypedefField): 359 self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}({v})") 360 elif isinstance(field, ast.SizeField): 361 self.unchecked_append_(f"{field.field_id}_size = {v}") 362 elif isinstance(field, ast.CountField): 363 self.unchecked_append_(f"{field.field_id}_count = {v}") 364 elif isinstance(field, ast.ReservedField): 365 pass 366 else: 367 raise Exception(f'Unsupported bit field type {field.kind}') 368 369 # Reset state. 370 self.offset = end_offset 371 self.shift = 0 372 self.chunk = [] 373 374 def parse_typedef_field_(self, field: ast.TypedefField): 375 """Parse a typedef field, to the exclusion of Enum fields.""" 376 377 if self.shift != 0: 378 raise Exception('Typedef field does not start on an octet boundary') 379 if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None): 380 raise Exception('Derived struct used in typedef field') 381 382 width = core.get_declaration_size(field.type) 383 if width is None: 384 self.consume_span_() 385 self.append_(f"{field.id}, span = {field.type_id}.parse(span)") 386 self.append_(f"fields['{field.id}'] = {field.id}") 387 else: 388 if width % 8 != 0: 389 raise Exception('Typedef field type size is not a multiple of 8') 390 width = int(width / 8) 391 end_offset = self.offset + width 392 # Checksum value field is generated alongside checksum start. 393 # Deal with this field as padding. 394 if not isinstance(field.type, ast.ChecksumDeclaration): 395 span = f'span[{self.offset}:{end_offset}]' 396 self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.parse_all({span})") 397 self.offset = end_offset 398 399 def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]): 400 """Parse body and payload fields.""" 401 402 payload_size = core.get_payload_field_size(field) 403 offset_from_end = core.get_field_offset_from_end(field) 404 405 # If the payload is not byte aligned, do parse the bit fields 406 # that can be extracted, but do not consume the input bytes as 407 # they will also be included in the payload span. 408 if self.shift != 0: 409 if payload_size: 410 raise Exception("Unexpected payload size for non byte aligned payload") 411 412 rounded_size = int((self.shift + 7) / 8) 413 padding_bits = 8 * rounded_size - self.shift 414 self.parse_bit_field_(core.make_reserved_field(padding_bits)) 415 self.consume_span_(rounded_size) 416 else: 417 self.consume_span_() 418 419 # The payload or body has a known size. 420 # Consume the payload and update the span in case 421 # fields are placed after the payload. 422 if payload_size: 423 if getattr(field, 'size_modifier', None): 424 self.append_(f"{field.id}_size -= {field.size_modifier}") 425 self.check_size_(f'{field.id}_size') 426 self.append_(f"payload = span[:{field.id}_size]") 427 self.append_(f"span = span[{field.id}_size:]") 428 # The payload or body is the last field of a packet, 429 # consume the remaining span. 430 elif offset_from_end == 0: 431 self.append_(f"payload = span") 432 self.append_(f"span = bytes([])") 433 # The payload or body is followed by fields of static size. 434 # Consume the span that is not reserved for the following fields. 435 elif offset_from_end is not None: 436 if (offset_from_end % 8) != 0: 437 raise Exception('Payload field offset from end of packet is not a multiple of 8') 438 offset_from_end = int(offset_from_end / 8) 439 self.check_size_(f'{offset_from_end}') 440 self.append_(f"payload = span[:-{offset_from_end}]") 441 self.append_(f"span = span[-{offset_from_end}:]") 442 self.append_(f"fields['payload'] = payload") 443 444 def parse_checksum_field_(self, field: ast.ChecksumField): 445 """Generate a checksum check.""" 446 447 # The checksum value field can be read starting from the current 448 # offset if the fields in between are of fixed size, or from the end 449 # of the span otherwise. 450 self.consume_span_() 451 value_field = core.get_packet_field(field.parent, field.field_id) 452 offset_from_start = 0 453 offset_from_end = 0 454 start_index = field.parent.fields.index(field) 455 value_index = field.parent.fields.index(value_field) 456 value_size = int(core.get_field_size(value_field) / 8) 457 458 for f in field.parent.fields[start_index + 1:value_index]: 459 size = core.get_field_size(f) 460 if size is None: 461 offset_from_start = None 462 break 463 else: 464 offset_from_start += size 465 466 trailing_fields = field.parent.fields[value_index:] 467 trailing_fields.reverse() 468 for f in trailing_fields: 469 size = core.get_field_size(f) 470 if size is None: 471 offset_from_end = None 472 break 473 else: 474 offset_from_end += size 475 476 if offset_from_start is not None: 477 if offset_from_start % 8 != 0: 478 raise Exception('Checksum value field is not aligned to an octet boundary') 479 offset_from_start = int(offset_from_start / 8) 480 checksum_span = f'span[:{offset_from_start}]' 481 if value_size > 1: 482 start = offset_from_start 483 end = offset_from_start + value_size 484 value = f"int.from_bytes(span[{start}:{end}], byteorder='{self.byteorder}')" 485 else: 486 value = f'span[{offset_from_start}]' 487 self.check_size_(offset_from_start + value_size) 488 489 elif offset_from_end is not None: 490 sign = '' 491 if offset_from_end % 8 != 0: 492 raise Exception('Checksum value field is not aligned to an octet boundary') 493 offset_from_end = int(offset_from_end / 8) 494 checksum_span = f'span[:-{offset_from_end}]' 495 if value_size > 1: 496 start = offset_from_end 497 end = offset_from_end - value_size 498 value = f"int.from_bytes(span[-{start}:-{end}], byteorder='{self.byteorder}')" 499 else: 500 value = f'span[-{offset_from_end}]' 501 self.check_size_(offset_from_end) 502 503 else: 504 raise Exception('Checksum value field cannot be read at constant offset') 505 506 self.append_(f"{value_field.id} = {value}") 507 self.append_(f"fields['{value_field.id}'] = {value_field.id}") 508 self.append_(f"computed_{value_field.id} = {value_field.type.function}({checksum_span})") 509 self.append_(f"if computed_{value_field.id} != {value_field.id}:") 510 self.append_(" raise Exception(f'Invalid checksum computation:" + 511 f" {{computed_{value_field.id}}} != {{{value_field.id}}}')") 512 513 def parse(self, field: ast.Field): 514 # Field has bit granularity. 515 # Append the field to the current chunk, 516 # check if a byte boundary was reached. 517 if core.is_bit_field(field): 518 self.parse_bit_field_(field) 519 520 # Padding fields. 521 elif isinstance(field, ast.PaddingField): 522 pass 523 524 # Array fields. 525 elif isinstance(field, ast.ArrayField) and field.width == 8: 526 self.parse_byte_array_field_(field) 527 528 elif isinstance(field, ast.ArrayField): 529 self.parse_array_field_(field) 530 531 # Other typedef fields. 532 elif isinstance(field, ast.TypedefField): 533 self.parse_typedef_field_(field) 534 535 # Payload and body fields. 536 elif isinstance(field, (ast.PayloadField, ast.BodyField)): 537 self.parse_payload_field_(field) 538 539 # Checksum fields. 540 elif isinstance(field, ast.ChecksumField): 541 self.parse_checksum_field_(field) 542 543 else: 544 raise Exception(f'Unimplemented field type {field.kind}') 545 546 def done(self): 547 self.consume_span_() 548 549 550@dataclass 551class FieldSerializer: 552 byteorder: str 553 shift: int = 0 554 value: List[str] = field(default_factory=lambda: []) 555 code: List[str] = field(default_factory=lambda: []) 556 indent: int = 0 557 558 def indent_(self): 559 self.indent += 1 560 561 def unindent_(self): 562 self.indent -= 1 563 564 def append_(self, line: str): 565 """Append field serializing code.""" 566 lines = line.split('\n') 567 self.code.extend([' ' * self.indent + line for line in lines]) 568 569 def extend_(self, value: str, length: int): 570 """Append data to the span being constructed.""" 571 if length == 1: 572 self.append_(f"_span.append({value})") 573 else: 574 self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))") 575 576 def serialize_array_element_(self, field: ast.ArrayField): 577 """Serialize a single array field element.""" 578 if field.width is not None: 579 length = int(field.width / 8) 580 self.extend_('_elt', length) 581 elif isinstance(field.type, ast.EnumDeclaration): 582 length = int(field.type.width / 8) 583 self.extend_('_elt', length) 584 else: 585 self.append_("_span.extend(_elt.serialize())") 586 587 def serialize_array_field_(self, field: ast.ArrayField): 588 """Serialize the selected array field.""" 589 if field.padded_size: 590 self.append_(f"_{field.id}_start = len(_span)") 591 592 if field.width == 8: 593 self.append_(f"_span.extend(self.{field.id})") 594 else: 595 self.append_(f"for _elt in self.{field.id}:") 596 self.indent_() 597 self.serialize_array_element_(field) 598 self.unindent_() 599 600 if field.padded_size: 601 self.append_(f"_span.extend([0] * ({field.padded_size} - len(_span) + _{field.id}_start))") 602 603 def serialize_bit_field_(self, field: ast.Field): 604 """Serialize the selected field as a bit field. 605 The field is added to the current chunk. When a byte boundary 606 is reached all saved fields are serialized together.""" 607 608 # Add to current chunk. 609 width = core.get_field_size(field) 610 shift = self.shift 611 612 if isinstance(field, str): 613 self.value.append(f"({field} << {shift})") 614 elif isinstance(field, ast.ScalarField): 615 max_value = (1 << field.width) - 1 616 self.append_(f"if self.{field.id} > {max_value}:") 617 self.append_(f" print(f\"Invalid value for field {field.parent.id}::{field.id}:" + 618 f" {{self.{field.id}}} > {max_value}; the value will be truncated\")") 619 self.append_(f" self.{field.id} &= {max_value}") 620 self.value.append(f"(self.{field.id} << {shift})") 621 elif isinstance(field, ast.FixedField) and field.enum_id: 622 self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})") 623 elif isinstance(field, ast.FixedField): 624 self.value.append(f"({field.value} << {shift})") 625 elif isinstance(field, ast.TypedefField): 626 self.value.append(f"(self.{field.id} << {shift})") 627 628 elif isinstance(field, ast.SizeField): 629 max_size = (1 << field.width) - 1 630 value_field = core.get_packet_field(field.parent, field.field_id) 631 size_modifier = '' 632 633 if getattr(value_field, 'size_modifier', None): 634 size_modifier = f' + {value_field.size_modifier}' 635 636 if isinstance(value_field, (ast.PayloadField, ast.BodyField)): 637 self.append_(f"_payload_size = len(payload or self.payload or []){size_modifier}") 638 self.append_(f"if _payload_size > {max_size}:") 639 self.append_(f" print(f\"Invalid length for payload field:" + 640 f" {{_payload_size}} > {max_size}; the packet cannot be generated\")") 641 self.append_(f" raise Exception(\"Invalid payload length\")") 642 array_size = "_payload_size" 643 elif isinstance(value_field, ast.ArrayField) and value_field.width: 644 array_size = f"(len(self.{value_field.id}) * {int(value_field.width / 8)}{size_modifier})" 645 elif isinstance(value_field, ast.ArrayField) and isinstance(value_field.type, ast.EnumDeclaration): 646 array_size = f"(len(self.{value_field.id}) * {int(value_field.type.width / 8)}{size_modifier})" 647 elif isinstance(value_field, ast.ArrayField): 648 self.append_( 649 f"_{value_field.id}_size = sum([elt.size for elt in self.{value_field.id}]){size_modifier}") 650 array_size = f"_{value_field.id}_size" 651 else: 652 raise Exception("Unsupported field type") 653 self.value.append(f"({array_size} << {shift})") 654 655 elif isinstance(field, ast.CountField): 656 max_count = (1 << field.width) - 1 657 self.append_(f"if len(self.{field.field_id}) > {max_count}:") 658 self.append_(f" print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" + 659 f" {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")") 660 self.append_(f" del self.{field.field_id}[{max_count}:]") 661 self.value.append(f"(len(self.{field.field_id}) << {shift})") 662 elif isinstance(field, ast.ReservedField): 663 pass 664 else: 665 raise Exception(f'Unsupported bit field type {field.kind}') 666 667 # Check if a byte boundary is reached. 668 self.shift += width 669 if (self.shift % 8) == 0: 670 self.pack_bit_fields_() 671 672 def pack_bit_fields_(self): 673 """Pack serialized bit fields.""" 674 675 # Should have an integral number of bytes now. 676 assert (self.shift % 8) == 0 677 678 # Generate the backing integer, and serialize it 679 # using the configured endiannes, 680 size = int(self.shift / 8) 681 682 if len(self.value) == 0: 683 self.append_(f"_span.extend([0] * {size})") 684 elif len(self.value) == 1: 685 self.extend_(self.value[0], size) 686 else: 687 self.append_(f"_value = (") 688 self.append_(" " + " |\n ".join(self.value)) 689 self.append_(")") 690 self.extend_('_value', size) 691 692 # Reset state. 693 self.shift = 0 694 self.value = [] 695 696 def serialize_typedef_field_(self, field: ast.TypedefField): 697 """Serialize a typedef field, to the exclusion of Enum fields.""" 698 699 if self.shift != 0: 700 raise Exception('Typedef field does not start on an octet boundary') 701 if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None): 702 raise Exception('Derived struct used in typedef field') 703 704 if isinstance(field.type, ast.ChecksumDeclaration): 705 size = int(field.type.width / 8) 706 self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])") 707 self.extend_('_checksum', size) 708 else: 709 self.append_(f"_span.extend(self.{field.id}.serialize())") 710 711 def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]): 712 """Serialize body and payload fields.""" 713 714 if self.shift != 0 and self.byteorder == 'big': 715 raise Exception('Payload field does not start on an octet boundary') 716 717 if self.shift == 0: 718 self.append_(f"_span.extend(payload or self.payload or [])") 719 else: 720 # Supported case of packet inheritance; 721 # the incomplete fields are serialized into 722 # the payload, rather than separately. 723 # First extract the padding bits from the payload, 724 # then recombine them with the bit fields to be serialized. 725 rounded_size = int((self.shift + 7) / 8) 726 padding_bits = 8 * rounded_size - self.shift 727 self.append_(f"_payload = payload or self.payload or bytes()") 728 self.append_(f"if len(_payload) < {rounded_size}:") 729 self.append_(f" raise Exception(f\"Invalid length for payload field:" + 730 f" {{len(_payload)}} < {rounded_size}\")") 731 self.append_( 732 f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}") 733 self.value.append(f"(_padding << {self.shift})") 734 self.shift += padding_bits 735 self.pack_bit_fields_() 736 self.append_(f"_span.extend(_payload[{rounded_size}:])") 737 738 def serialize_checksum_field_(self, field: ast.ChecksumField): 739 """Generate a checksum check.""" 740 741 self.append_("_checksum_start = len(_span)") 742 743 def serialize(self, field: ast.Field): 744 # Field has bit granularity. 745 # Append the field to the current chunk, 746 # check if a byte boundary was reached. 747 if core.is_bit_field(field): 748 self.serialize_bit_field_(field) 749 750 # Padding fields. 751 elif isinstance(field, ast.PaddingField): 752 pass 753 754 # Array fields. 755 elif isinstance(field, ast.ArrayField): 756 self.serialize_array_field_(field) 757 758 # Other typedef fields. 759 elif isinstance(field, ast.TypedefField): 760 self.serialize_typedef_field_(field) 761 762 # Payload and body fields. 763 elif isinstance(field, (ast.PayloadField, ast.BodyField)): 764 self.serialize_payload_field_(field) 765 766 # Checksum fields. 767 elif isinstance(field, ast.ChecksumField): 768 self.serialize_checksum_field_(field) 769 770 else: 771 raise Exception(f'Unimplemented field type {field.kind}') 772 773 774def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]: 775 """Generate the serialize() function for a toplevel Packet or Struct 776 declaration.""" 777 778 serializer = FieldSerializer(byteorder=packet.file.byteorder) 779 for f in packet.fields: 780 serializer.serialize(f) 781 return ['_span = bytearray()'] + serializer.code + ['return bytes(_span)'] 782 783 784def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]: 785 """Generate the serialize() function for a derived Packet or Struct 786 declaration.""" 787 788 packet_shift = core.get_packet_shift(packet) 789 if packet_shift and packet.file.byteorder == 'big': 790 raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift") 791 792 serializer = FieldSerializer(byteorder=packet.file.byteorder, shift=packet_shift) 793 for f in packet.fields: 794 serializer.serialize(f) 795 return ['_span = bytearray()' 796 ] + serializer.code + [f'return {packet.parent.id}.serialize(self, payload = bytes(_span))'] 797 798 799def generate_packet_parser(packet: ast.Declaration) -> List[str]: 800 """Generate the parse() function for a toplevel Packet or Struct 801 declaration.""" 802 803 packet_shift = core.get_packet_shift(packet) 804 if packet_shift and packet.file.byteorder == 'big': 805 raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift") 806 807 # Convert the packet constraints to a boolean expression. 808 validation = [] 809 if packet.constraints: 810 cond = [] 811 for c in packet.constraints: 812 if c.value is not None: 813 cond.append(f"fields['{c.id}'] != {hex(c.value)}") 814 else: 815 field = core.get_packet_field(packet, c.id) 816 cond.append(f"fields['{c.id}'] != {field.type_id}.{c.tag_id}") 817 818 validation = [f"if {' or '.join(cond)}:", " raise Exception(\"Invalid constraint field values\")"] 819 820 # Parse fields iteratively. 821 parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift) 822 for f in packet.fields: 823 parser.parse(f) 824 parser.done() 825 826 # Specialize to child packets. 827 children = core.get_derived_packets(packet) 828 decl = [] if packet.parent_id else ['fields = {\'payload\': None}'] 829 specialization = [] 830 831 if len(children) != 0: 832 # Try parsing every child packet successively until one is 833 # successfully parsed. Return a parsing error if none is valid. 834 # Return parent packet if no child packet matches. 835 # TODO: order child packets by decreasing size in case no constraint 836 # is given for specialization. 837 for _, child in children: 838 specialization.append("try:") 839 specialization.append(f" return {child.id}.parse(fields.copy(), payload)") 840 specialization.append("except Exception as exn:") 841 specialization.append(" pass") 842 843 return decl + validation + parser.code + specialization + [f"return {packet.id}(**fields), span"] 844 845 846def generate_packet_size_getter(packet: ast.Declaration) -> List[str]: 847 constant_width = 0 848 variable_width = [] 849 for f in packet.fields: 850 field_size = core.get_field_size(f) 851 if field_size is not None: 852 constant_width += field_size 853 elif isinstance(f, (ast.PayloadField, ast.BodyField)): 854 variable_width.append("len(self.payload)") 855 elif isinstance(f, ast.TypedefField): 856 variable_width.append(f"self.{f.id}.size") 857 elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)): 858 variable_width.append(f"sum([elt.size for elt in self.{f.id}])") 859 elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration): 860 variable_width.append(f"len(self.{f.id}) * {f.type.width}") 861 elif isinstance(f, ast.ArrayField): 862 variable_width.append(f"len(self.{f.id}) * {int(f.width / 8)}") 863 else: 864 raise Exception("Unsupported field type") 865 866 constant_width = int(constant_width / 8) 867 if len(variable_width) == 0: 868 return [f"return {constant_width}"] 869 elif len(variable_width) == 1 and constant_width: 870 return [f"return {variable_width[0]} + {constant_width}"] 871 elif len(variable_width) == 1: 872 return [f"return {variable_width[0]}"] 873 elif len(variable_width) > 1 and constant_width: 874 return ([f"return {constant_width} + ("] + " +\n ".join(variable_width).split("\n") + [")"]) 875 elif len(variable_width) > 1: 876 return (["return ("] + " +\n ".join(variable_width).split("\n") + [")"]) 877 else: 878 assert False 879 880 881def generate_packet_post_init(decl: ast.Declaration) -> List[str]: 882 """Generate __post_init__ function to set constraint field values.""" 883 884 # Gather all constraints from parent packets. 885 constraints = [] 886 current = decl 887 while current.parent_id: 888 constraints.extend(current.constraints) 889 current = current.parent 890 891 if constraints: 892 code = [] 893 for c in constraints: 894 if c.value is not None: 895 code.append(f"self.{c.id} = {c.value}") 896 else: 897 field = core.get_packet_field(decl, c.id) 898 code.append(f"self.{c.id} = {field.type_id}.{c.tag_id}") 899 return code 900 901 else: 902 return ["pass"] 903 904 905def generate_enum_declaration(decl: ast.EnumDeclaration) -> str: 906 """Generate the implementation of an enum type.""" 907 908 enum_name = decl.id 909 tag_decls = [] 910 for t in decl.tags: 911 tag_decls.append(f"{t.id} = {hex(t.value)}") 912 913 return dedent("""\ 914 915 class {enum_name}(enum.IntEnum): 916 {tag_decls} 917 """).format(enum_name=enum_name, tag_decls=indent(tag_decls, 1)) 918 919 920def generate_packet_declaration(packet: ast.Declaration) -> str: 921 """Generate the implementation a toplevel Packet or Struct 922 declaration.""" 923 924 packet_name = packet.id 925 field_decls = [] 926 for f in packet.fields: 927 if isinstance(f, ast.ScalarField): 928 field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)") 929 elif isinstance(f, ast.TypedefField): 930 if isinstance(f.type, ast.EnumDeclaration): 931 field_decls.append( 932 f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type_id}.{f.type.tags[0].id})") 933 elif isinstance(f.type, ast.ChecksumDeclaration): 934 field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)") 935 elif isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)): 936 field_decls.append(f"{f.id}: {f.type_id} = field(kw_only=True, default_factory={f.type_id})") 937 else: 938 raise Exception("Unsupported typedef field type") 939 elif isinstance(f, ast.ArrayField) and f.width == 8: 940 field_decls.append(f"{f.id}: bytearray = field(kw_only=True, default_factory=bytearray)") 941 elif isinstance(f, ast.ArrayField) and f.width: 942 field_decls.append(f"{f.id}: List[int] = field(kw_only=True, default_factory=list)") 943 elif isinstance(f, ast.ArrayField) and f.type_id: 944 field_decls.append(f"{f.id}: List[{f.type_id}] = field(kw_only=True, default_factory=list)") 945 946 if packet.parent_id: 947 parent_name = packet.parent_id 948 parent_fields = 'fields: dict, ' 949 serializer = generate_derived_packet_serializer(packet) 950 else: 951 parent_name = 'Packet' 952 parent_fields = '' 953 serializer = generate_toplevel_packet_serializer(packet) 954 955 parser = generate_packet_parser(packet) 956 size = generate_packet_size_getter(packet) 957 post_init = generate_packet_post_init(packet) 958 959 return dedent("""\ 960 961 @dataclass 962 class {packet_name}({parent_name}): 963 {field_decls} 964 965 def __post_init__(self): 966 {post_init} 967 968 @staticmethod 969 def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]: 970 {parser} 971 972 def serialize(self, payload: bytes = None) -> bytes: 973 {serializer} 974 975 @property 976 def size(self) -> int: 977 {size} 978 """).format(packet_name=packet_name, 979 parent_name=parent_name, 980 parent_fields=parent_fields, 981 field_decls=indent(field_decls, 1), 982 post_init=indent(post_init, 2), 983 parser=indent(parser, 2), 984 serializer=indent(serializer, 2), 985 size=indent(size, 2)) 986 987 988def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str: 989 """Generate the code to validate a user custom field implementation. 990 991 This code is to be executed when the generated module is loaded to ensure 992 the user gets an immediate and clear error message when the provided 993 custom types do not fit the expected template. 994 """ 995 return dedent("""\ 996 997 if (not callable(getattr({custom_field_name}, 'parse', None)) or 998 not callable(getattr({custom_field_name}, 'parse_all', None))): 999 raise Exception('The custom field type {custom_field_name} does not implement the parse method') 1000 """).format(custom_field_name=decl.id) 1001 1002 1003def generate_checksum_declaration_check(decl: ast.ChecksumDeclaration) -> str: 1004 """Generate the code to validate a user checksum field implementation. 1005 1006 This code is to be executed when the generated module is loaded to ensure 1007 the user gets an immediate and clear error message when the provided 1008 checksum functions do not fit the expected template. 1009 """ 1010 return dedent("""\ 1011 1012 if not callable({checksum_name}): 1013 raise Exception('{checksum_name} is not callable') 1014 """).format(checksum_name=decl.id) 1015 1016 1017def run(input: argparse.FileType, output: argparse.FileType, custom_type_location: Optional[str]): 1018 file = ast.File.from_json(json.load(input)) 1019 core.desugar(file) 1020 1021 custom_types = [] 1022 custom_type_checks = "" 1023 for d in file.declarations: 1024 if isinstance(d, ast.CustomFieldDeclaration): 1025 custom_types.append(d.id) 1026 custom_type_checks += generate_custom_field_declaration_check(d) 1027 elif isinstance(d, ast.ChecksumDeclaration): 1028 custom_types.append(d.id) 1029 custom_type_checks += generate_checksum_declaration_check(d) 1030 1031 output.write(f"# File generated from {input.name}, with the command:\n") 1032 output.write(f"# {' '.join(sys.argv)}\n") 1033 output.write("# /!\\ Do not edit by hand.\n") 1034 if custom_types and custom_type_location: 1035 output.write(f"\nfrom {custom_type_location} import {', '.join(custom_types)}\n") 1036 output.write(generate_prelude()) 1037 output.write(custom_type_checks) 1038 1039 for d in file.declarations: 1040 if isinstance(d, ast.EnumDeclaration): 1041 output.write(generate_enum_declaration(d)) 1042 elif isinstance(d, (ast.PacketDeclaration, ast.StructDeclaration)): 1043 output.write(generate_packet_declaration(d)) 1044 1045 1046def main() -> int: 1047 """Generate python PDL backend.""" 1048 parser = argparse.ArgumentParser(description=__doc__) 1049 parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source') 1050 parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file') 1051 parser.add_argument('--custom-type-location', 1052 type=str, 1053 required=False, 1054 help='Module of declaration of custom types') 1055 return run(**vars(parser.parse_args())) 1056 1057 1058if __name__ == '__main__': 1059 sys.exit(main()) 1060