#!/usr/bin/env python3
"""Generate serialization test vectors for PDL packets.

The input is a PDL-JSON file. For every selected packet without derived
packets, combinations of field values are generated and serialized; the
results are written to stdout as JSON.
"""

import argparse
import collections
import copy
import json
from pathlib import Path
import pprint
import traceback
from typing import Iterable, List, Optional, Union
import sys

from pdl import ast, core

# Upper bound on the octet size of generated typedef arrays whose size is
# not otherwise bounded (by a size field or padding).
MAX_ARRAY_SIZE = 256
# Upper bound on the element count of generated typedef arrays.
MAX_ARRAY_COUNT = 32
# Element count of generated scalar arrays without a size or count field.
DEFAULT_ARRAY_COUNT = 3
# Octet size of generated payloads without a size field.
DEFAULT_PAYLOAD_SIZE = 5


class BitSerializer:
    """Accumulate bit fields and flush them to an octet stream.

    Each appended field is placed above the bits already pending (the first
    field occupies the least significant bits). Whenever the number of
    pending bits is a multiple of 8, the pending value is emitted as octets
    in the selected byte order.
    """

    def __init__(self, big_endian: bool):
        self.stream = []  # emitted octets (ints)
        self.value = 0  # pending bits not yet emitted
        self.shift = 0  # number of pending bits
        self.byteorder = "big" if big_endian else "little"

    def append(self, value: int, width: int):
        """Append `value` as a field of `width` bits."""
        self.value = self.value | (value << self.shift)
        self.shift += width

        if (self.shift % 8) == 0:
            octets = self.shift // 8
            self.stream.extend(self.value.to_bytes(octets, byteorder=self.byteorder))
            self.shift = 0
            self.value = 0


class Value:
    """A generated field value.

    `value` is an int (scalar), a list of `Value` (array), a `Packet`
    (struct or payload), None (absent optional field), or a callable
    `f(parent_packet) -> int` resolved by `finalize`.
    `width` is the width in bits, or a callable `f(parent_packet) -> int`
    resolved by `finalize`.
    """

    def __init__(self, value: object, width: Optional[int] = None):
        self.value = value
        if width is not None:
            self.width = width
        elif isinstance(value, int) or callable(value):
            raise Exception("Creating scalar value of unspecified width")
        elif isinstance(value, list):
            self.width = sum(v.width for v in value)
        elif isinstance(value, Packet):
            self.width = value.width
        else:
            raise Exception(f"Malformed value {value}")

    def finalize(self, parent: "Packet"):
        """Resolve deferred (callable) width and value against `parent`."""
        if callable(self.width):
            self.width = self.width(parent)

        if callable(self.value):
            self.value = self.value(parent)
        elif isinstance(self.value, list):
            for v in self.value:
                v.finalize(parent)
        elif isinstance(self.value, Packet):
            self.value.finalize()

    def serialize_(self, serializer: BitSerializer):
        """Write this value into `serializer`; absent (None) values write nothing."""
        if isinstance(self.value, int):
            serializer.append(self.value, self.width)
        elif isinstance(self.value, list):
            for v in self.value:
                v.serialize_(serializer)
        elif isinstance(self.value, Packet):
            self.value.serialize_(serializer)
        elif self.value is None:
            pass
        else:
            raise Exception(f"Malformed value {self.value}")

    def show(self, indent: int = 0, name: Optional[str] = None):
        """Print this value, labelled `name`, indented by `indent` spaces."""
        # `Value` has no `name` attribute of its own; the label is supplied by
        # the caller (previously this read `self.name` and always crashed).
        if name is None:
            name = getattr(self, "name", "value")
        space = " " * indent
        if isinstance(self.value, int):
            print(f"{space}{name}: {hex(self.value)}")
        elif isinstance(self.value, list):
            print(f"{space}{name}[{len(self.value)}]:")
            for index, v in enumerate(self.value):
                v.show(indent + 2, f"{name}[{index}]")
        elif isinstance(self.value, Packet):
            print(f"{space}{name}:")
            self.value.show(indent + 2)

    def to_json(self) -> object:
        """Return the JSON representation (None for absent values)."""
        if isinstance(self.value, int):
            return self.value
        elif isinstance(self.value, list):
            return [v.to_json() for v in self.value]
        elif isinstance(self.value, Packet):
            return self.value.to_json()


class Field:
    """A generated `Value` bound to the AST field it instantiates."""

    def __init__(self, value: Value, ref: ast.Field):
        self.value = value
        self.ref = ref

    def finalize(self, parent: "Packet"):
        self.value.finalize(parent)

    def serialize_(self, serializer: BitSerializer):
        self.value.serialize_(serializer)

    def clone(self):
        """Return a copy with its own (shallow-copied) `Value`, so that
        finalizing the clone does not affect the original."""
        return Field(copy.copy(self.value), self.ref)


def _field_name(ref: ast.Field) -> str:
    """Return a display name for a field, matching the keys used by `Packet.to_json`."""
    if isinstance(ref, (ast.PayloadField, ast.BodyField)):
        return "payload"
    return getattr(ref, "id", None) or type(ref).__name__


class Packet:
    """A generated packet or struct: an ordered list of `Field`s."""

    def __init__(self, fields: List[Field], ref: ast.Declaration):
        self.fields = fields
        self.ref = ref

    def finalize(self, parent: Optional["Packet"] = None):
        """Resolve the deferred values of all fields, in field order."""
        for f in self.fields:
            f.finalize(self)

    def serialize_(self, serializer: BitSerializer):
        for f in self.fields:
            f.serialize_(serializer)

    def serialize(self, big_endian: bool) -> bytes:
        """Serialize the packet; raises if it is not a whole number of octets."""
        serializer = BitSerializer(big_endian)
        self.serialize_(serializer)
        if serializer.shift != 0:
            raise Exception("The packet size is not an integral number of octets")
        return bytes(serializer.stream)

    def show(self, indent: int = 0):
        for f in self.fields:
            f.value.show(indent, _field_name(f.ref))

    def to_json(self) -> dict:
        """Return a dict of field id -> JSON value.

        Nested payload packets are flattened into the result; opaque payloads
        are stored under "payload". Fields without an id are omitted.
        """
        result = dict()
        for f in self.fields:
            if isinstance(f.ref, (ast.PayloadField, ast.BodyField)) and isinstance(
                f.value.value, Packet
            ):
                result.update(f.value.to_json())
            elif isinstance(f.ref, (ast.PayloadField, ast.BodyField)):
                result["payload"] = f.value.to_json()
            elif hasattr(f.ref, "id"):
                result[f.ref.id] = f.value.to_json()
        return result

    @property
    def width(self) -> int:
        """Total width in bits (finalizes the packet first)."""
        self.finalize()
        return sum(f.value.width for f in self.fields)


class BitGenerator:
    """Deterministic bit source for generated scalar values.

    Bits are read from an octet counter, which advances after every 8 bits
    consumed and cycles through 0x00..0xff.
    """

    def __init__(self):
        self.value = 0
        self.shift = 0

    def generate(self, width: int) -> Value:
        """Generate an integer value of the selected width."""
        value = 0
        remains = width
        while remains > 0:
            w = min(8 - self.shift, remains)
            mask = (1 << w) - 1
            value = (value << w) | ((self.value >> self.shift) & mask)
            remains -= w
            self.shift += w
            if self.shift >= 8:
                self.shift = 0
                # Wrap at 0x100 so the octet 0xff is also produced
                # (`% 0xFF` never yielded it).
                self.value = (self.value + 1) % 0x100
        return Value(value, width)

    def generate_list(self, width: int, count: int) -> List[Value]:
        """Generate `count` values of `width` bits."""
        return [self.generate(width) for _ in range(count)]


generator = BitGenerator()


def _matches_field_id(ref: ast.Field, field_id: str) -> bool:
    """Return True if `ref` is the field named `field_id`, including the
    special `_payload_` and `_body_` identifiers."""
    return (
        (field_id == "_payload_" and isinstance(ref, ast.PayloadField))
        or (field_id == "_body_" and isinstance(ref, ast.BodyField))
        or getattr(ref, "id", None) == field_id
    )


def generate_cond_field_values(field: ast.ScalarField) -> List[Value]:
    """Generate the condition flag guarding an optional field.

    The flag is resolved at finalize time: `field.cond_for.cond.value` when
    the optional field has a value, and a different value (0, or 1 if the
    condition value is 0) when it is absent.
    """
    cond_value_present = field.cond_for.cond.value
    cond_value_absent = 0 if cond_value_present != 0 else 1

    def get_cond_value(parent: Packet, optional_field: ast.Field) -> int:
        for f in parent.fields:
            if f.ref is optional_field:
                return cond_value_absent if f.value.value is None else cond_value_present
        raise Exception(f"Optional field not found in packet {parent.ref.id}")

    return [Value(lambda p: get_cond_value(p, field.cond_for), field.width)]


def generate_size_field_values(field: ast.SizeField) -> List[Value]:
    """Generate a size field, resolved at finalize time to the octet size of
    the sized field plus that field's size modifier."""

    def get_field_size(parent: Packet, field_id: str) -> int:
        for f in parent.fields:
            if _matches_field_id(f.ref, field_id):
                assert f.value.width % 8 == 0
                size_modifier = int(getattr(f.ref, "size_modifier", None) or 0)
                return f.value.width // 8 + size_modifier
        raise Exception(f"Field {field_id} not found in packet {parent.ref.id}")

    return [Value(lambda p: get_field_size(p, field.field_id), field.width)]


def generate_count_field_values(field: ast.CountField) -> List[Value]:
    """Generate a count field, resolved at finalize time to the element
    count of the counted array."""

    def get_array_count(parent: Packet, field_id: str) -> int:
        for f in parent.fields:
            if getattr(f.ref, "id", None) == field_id:
                assert isinstance(f.value.value, list)
                return len(f.value.value)
        raise Exception(f"Field {field_id} not found in packet {parent.ref.id}")

    return [Value(lambda p: get_array_count(p, field.field_id), field.width)]


def generate_checksum_field_values(field: ast.TypedefField) -> List[Value]:
    """Generate a checksum value over the fields between its `ChecksumField`
    start marker and the checksum field itself."""
    field_width = core.get_field_size(field)

    def basic_checksum(data: List[int], width: int) -> int:
        assert width == 8
        return sum(data) % 256

    def compute_checksum(parent: Packet, field_id: str) -> int:
        serializer = None
        for f in parent.fields:
            if isinstance(f.ref, ast.ChecksumField) and f.ref.field_id == field_id:
                serializer = BitSerializer(
                    f.ref.parent.file.endianness.value == "big_endian"
                )
            elif isinstance(f.ref, ast.TypedefField) and f.ref.id == field_id:
                if serializer is None:
                    # The checksum value appears before its start marker.
                    break
                return basic_checksum(serializer.stream, field_width)
            elif serializer is not None:
                # Fields preceding this one are already finalized.
                f.value.serialize_(serializer)
        raise Exception("malformed checksum")

    return [Value(lambda p: compute_checksum(p, field.id), field_width)]


def generate_padding_field_values(field: ast.PaddingField) -> List[Value]:
    """Generate zero padding whose width (resolved at finalize time) fills
    the padded field up to `field.size` octets."""
    padded_field_id = field.padded_field.id

    def get_padding(parent: Packet, field_id: str, width: int) -> int:
        for f in parent.fields:
            if _matches_field_id(f.ref, field_id):
                assert f.value.width % 8 == 0
                assert f.value.width <= width
                return width - f.value.width
        raise Exception(f"Field {field_id} not found in packet {parent.ref.id}")

    return [Value(0, lambda p: get_padding(p, padded_field_id, 8 * field.size))]


def generate_payload_field_values(
    field: Union[ast.PayloadField, ast.BodyField]
) -> List[Value]:
    """Generate two opaque payloads: an empty one and one of maximum size.

    The maximum size is the largest value of the payload size field minus the
    size modifier, or DEFAULT_PAYLOAD_SIZE octets when the payload has no
    size field.
    """
    payload_size = core.get_payload_field_size(field)
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    max_size = (1 << payload_size.width) - 1 if payload_size else DEFAULT_PAYLOAD_SIZE
    max_size -= size_modifier

    if max_size <= 0:
        raise Exception("Payload size field leaves no room for payload octets")
    return [Value([]), Value(generator.generate_list(8, max_size))]


def generate_scalar_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Generate values for an array of scalar elements of `field.width` bits.

    Statically sized arrays get a single value; other arrays get an empty
    array plus one of maximum (or default) length.

    Raises:
        Exception: if the element width is not a multiple of 8 bits.
    """
    # Previously this tested an unassigned `element_width`, raising
    # UnboundLocalError instead of this error.
    if field.width % 8 != 0:
        raise Exception("Array element size is not a multiple of 8")

    array_size = core.get_array_field_size(field)
    element_octets = field.width // 8

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    # Element count known statically.
    if isinstance(array_size, int):
        return [Value(generator.generate_list(field.width, array_size))]

    # Element count given by a count field: empty and maximum count.
    if isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    # Octet size given by a size field: empty and maximum size.
    if isinstance(array_size, ast.SizeField):
        max_size = (1 << array_size.width) - 1 - size_modifier
        max_count = max_size // element_octets
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    # Size unknown: empty and a few elements.
    return [
        Value([]),
        Value(generator.generate_list(field.width, DEFAULT_ARRAY_COUNT)),
    ]


def generate_typedef_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Generate values for an array whose elements are of a typedef type.

    Element values from `generate_typedef_values` are packed, in order, into
    arrays honoring the array bounds: static count, count field, size field
    and padding. Elements that individually exceed the octet bound are
    skipped. An empty array is added when zero elements are allowed.

    Raises:
        Exception: if the element size is not a multiple of 8 bits, or a
            statically counted array cannot be filled within its bounds.
    """
    element_width = core.get_array_element_size(field)
    if element_width and element_width % 8 != 0:
        raise Exception("Array element size is not a multiple of 8")

    array_size = core.get_array_field_size(field)
    type_decl = field.parent.file.typedef_scope[field.type_id]
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    min_count = 0
    max_count = MAX_ARRAY_COUNT
    if isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
    elif isinstance(array_size, int):
        min_count = max_count = array_size

    # Octet size bound (None = unbounded). A size field and padding both
    # bound the size; a static count alone needs no octet cap.
    size_bounds = []
    if isinstance(array_size, ast.SizeField):
        size_bounds.append((1 << array_size.width) - 1 - size_modifier)
    if field.padded_size:
        size_bounds.append(field.padded_size)
    if size_bounds:
        max_size: Optional[int] = min(size_bounds)
    elif isinstance(array_size, int):
        max_size = None
    else:
        max_size = MAX_ARRAY_SIZE

    def fits(size: int) -> bool:
        return max_size is None or size <= max_size

    if max_count == 0:
        return [Value([])]

    values = []
    chunk = []
    chunk_size = 0
    while not values:
        candidates = [
            v for v in generate_typedef_values(type_decl) if fits(v.width // 8)
        ]
        if not candidates:
            # No element can ever fit: only the empty array is valid.
            if min_count == 0:
                return [Value([])]
            raise Exception(f"Cannot generate elements for array field {field.id}")

        for element_value in candidates:
            element_size = element_value.width // 8
            if chunk and (
                len(chunk) >= max_count or not fits(chunk_size + element_size)
            ):
                if len(chunk) < min_count:
                    raise Exception(
                        f"Array field {field.id} cannot hold {min_count} elements"
                    )
                values.append(Value(chunk))
                chunk = []
                chunk_size = 0

            chunk.append(element_value)
            chunk_size += element_size

    if min_count == 0:
        values.append(Value([]))

    return values


def generate_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Dispatch on scalar (`field.width` set) vs typedef array elements."""
    if field.width is not None:
        return generate_scalar_array_field_values(field)
    else:
        return generate_typedef_array_field_values(field)


def generate_typedef_field_values(
    field: ast.TypedefField, constraints: List[ast.Constraint]
) -> List[Value]:
    """Generate values for a typedef field, honoring an enum constraint if any."""
    type_decl = field.parent.file.typedef_scope[field.type_id]

    # Check for constraint on enum field.
    if isinstance(type_decl, ast.EnumDeclaration):
        for c in constraints:
            if c.id == field.id:
                for tag in type_decl.tags:
                    if tag.id == c.tag_id:
                        return [Value(tag.value, type_decl.width)]
                raise Exception("undefined enum tag")

    # Checksum field needs to known the checksum range.
    if isinstance(type_decl, ast.ChecksumDeclaration):
        return generate_checksum_field_values(field)

    return generate_typedef_values(type_decl)


def generate_field_values(
    field: ast.Field, constraints: List[ast.Constraint], payload: Optional[List[Packet]]
) -> List[Value]:
    """Return the candidate values for one field.

    `constraints` fix the value of constrained scalar/enum fields; `payload`,
    when given, supplies the candidate packets for a payload/body field.
    """
    if field.cond_for:
        return generate_cond_field_values(field)

    elif isinstance(field, ast.ChecksumField):
        # Checksum fields are just markers.
        return [Value(0, 0)]

    elif isinstance(field, ast.PaddingField):
        return generate_padding_field_values(field)

    elif isinstance(field, ast.SizeField):
        return generate_size_field_values(field)

    elif isinstance(field, ast.CountField):
        return generate_count_field_values(field)

    elif isinstance(field, (ast.BodyField, ast.PayloadField)) and payload:
        return [Value(p) for p in payload]

    elif isinstance(field, (ast.BodyField, ast.PayloadField)):
        return generate_payload_field_values(field)

    elif isinstance(field, ast.FixedField) and field.enum_id:
        enum_decl = field.parent.file.typedef_scope[field.enum_id]
        for tag in enum_decl.tags:
            if tag.id == field.tag_id:
                return [Value(tag.value, enum_decl.width)]
        raise Exception("undefined enum tag")

    elif isinstance(field, ast.FixedField) and field.width:
        return [Value(field.value, field.width)]

    elif isinstance(field, ast.ReservedField):
        return [Value(0, field.width)]

    elif isinstance(field, ast.ArrayField):
        return generate_array_field_values(field)

    elif isinstance(field, ast.ScalarField):
        for c in constraints:
            if c.id == field.id:
                return [Value(c.value, field.width)]
        # Minimum, maximum, and one generated value.
        mask = (1 << field.width) - 1
        return [
            Value(0, field.width),
            Value(mask, field.width),
            generator.generate(field.width),
        ]

    elif isinstance(field, ast.TypedefField):
        return generate_typedef_field_values(field, constraints)

    else:
        raise Exception("unsupported field kind")


def generate_fields(
    decl: ast.Declaration,
    constraints: List[ast.Constraint],
    payload: Optional[List[Packet]],
) -> List[List[Field]]:
    """Return, for each field of `decl`, its list of candidate `Field`s.

    Optional fields (`f.cond` set) additionally get an absent (None) candidate.
    """
    fields = []
    for f in decl.fields:
        values = generate_field_values(f, constraints, payload)
        optional_none = [Field(Value(None, 0), f)] if f.cond else []
        fields.append(optional_none + [Field(v, f) for v in values])
    return fields

def generate_fields_recursive(
    scope: dict,
    decl: "ast.Declaration",
    constraints: Optional[List["ast.Constraint"]] = None,
    payload: Optional[List["Packet"]] = None,
) -> List[List["Field"]]:
    """Generate candidate fields for `decl`, nested into its ancestors.

    The candidates of `decl` are combined into packets that become the
    payload candidates of the parent declaration (looked up in `scope`), up
    to the root declaration whose field candidates are returned. Constraints
    accumulate from child to parent.
    """
    if constraints is None:
        constraints = []
    fields = generate_fields(decl, constraints, payload)

    if not decl.parent_id:
        return fields

    packets = [Packet(combination, decl) for combination in product(fields)]
    parent_decl = scope[decl.parent_id]
    return generate_fields_recursive(
        scope, parent_decl, constraints + decl.constraints, payload=packets
    )


def generate_struct_values(decl: "ast.StructDeclaration") -> List["Packet"]:
    """Generate all sampled values of a struct declaration."""
    fields = generate_fields_recursive(decl.file.typedef_scope, decl)
    return [Packet(combination, decl) for combination in product(fields)]


def generate_packet_values(decl: "ast.PacketDeclaration") -> List["Packet"]:
    """Generate all sampled values of a packet declaration, including its
    ancestors' fields."""
    fields = generate_fields_recursive(decl.file.packet_scope, decl)
    return [Packet(combination, decl) for combination in product(fields)]


def generate_typedef_values(decl: "ast.Declaration") -> List["Value"]:
    """Generate candidate values for a typedef (enum or struct) declaration."""
    if isinstance(decl, ast.EnumDeclaration):
        return [Value(t.value, decl.width) for t in decl.tags]

    elif isinstance(decl, ast.ChecksumDeclaration):
        raise Exception("ChecksumDeclaration handled in typedef field")

    elif isinstance(decl, ast.CustomFieldDeclaration):
        raise Exception("TODO custom field")

    elif isinstance(decl, ast.StructDeclaration):
        return [Value(p) for p in generate_struct_values(decl)]

    else:
        raise Exception("unsupported typedef declaration type")


def product(fields: List[List["Field"]], max_count: int = 32) -> List[List["Field"]]:
    """Combine the candidate values of each field into field lists.

    When the full cartesian product has at most `max_count` combinations it
    is returned entirely. Otherwise `max(len(options))` samples are returned,
    where sample `i` uses option `i % len(options)` of every field, so each
    option of each field appears at least once.

    Every returned field is a fresh clone. Finalizing a packet replaces
    deferred values (sizes, counts, checksums) in place, so fields must never
    be shared between combinations.
    """

    def combinations(options: List[List["Field"]]) -> List[List["Field"]]:
        if not options:
            return [[]]
        # The inner iterable is re-evaluated per item, yielding fresh clones.
        return [
            [item.clone()] + rest
            for item in options[0]
            for rest in combinations(options[1:])
        ]

    count = 1
    max_len = 0
    for options in fields:
        count *= len(options)
        max_len = max(max_len, len(options))

    # Limit products to prevent combinatorial explosion.
    if count <= max_count:
        return combinations(fields)

    # Too many products: sample so that every field option is used at least
    # once. (count > max_count implies every option list is non-empty.)
    return [
        [options[idx % len(options)].clone() for options in fields]
        for idx in range(max_len)
    ]


def serialize_values(file: "ast.File", values: List["Packet"]) -> List[dict]:
    """Finalize and serialize each generated packet.

    Each result has the hex-encoded "packed" bytes and the "unpacked" JSON
    view; derived packets also record their "packet" id.
    """
    big_endian = file.endianness.value == "big_endian"
    results = []
    for v in values:
        v.finalize()
        packed = v.serialize(big_endian)
        result = {
            "packed": "".join(f"{b:02x}" for b in packed),
            "unpacked": v.to_json(),
        }
        if v.ref.parent_id:
            result["packet"] = v.ref.id
        results.append(result)
    return results


def run(input: Path, packet: List[str]) -> int:
    """Generate test vectors for the packets of the PDL-JSON file `input`.

    Only packets without derived packets are generated; when `packet` is
    non-empty, only packets whose id is listed. Results are grouped by root
    ancestor and written to stdout as JSON. Packets whose generation fails
    are reported on stderr and skipped. Returns 0.
    """
    with open(input, encoding="utf-8") as f:
        file = ast.File.from_json(json.load(f))
    core.desugar(file)

    results = {}
    for decl in file.packet_scope.values():
        if core.get_derived_packets(decl) or (packet and decl.id not in packet):
            continue

        try:
            values = generate_packet_values(decl)
            ancestor = core.get_packet_ancestor(decl)
            tests = serialize_values(file, values)
            # Only register the group once serialization succeeded.
            results.setdefault(ancestor.id, []).extend(tests)
        except Exception as exn:
            print(
                f"Skipping packet {decl.id}; cannot generate values: {exn}",
                file=sys.stderr,
            )

    output = [{"packet": k, "tests": v} for k, v in results.items()]
    json.dump(output, sys.stdout, indent=2)
    return 0


def main() -> int:
    """Generate test vectors for top-level PDL packets."""
    # `__doc__` inside a function is the *module* docstring; use main's own.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument(
        "--input", type=Path, required=True, help="Input PDL-JSON source"
    )
    parser.add_argument(
        "--packet",
        type=lambda x: x.split(","),
        required=False,
        action="extend",
        default=[],
        help="Select PDL packet to test",
    )
    return run(**vars(parser.parse_args()))


if __name__ == "__main__":
    sys.exit(main())