#!/usr/bin/env python3

import argparse
import collections
import copy
import json
from pathlib import Path
import pprint
import traceback
from typing import Iterable, List, Optional, Union
import sys

from pdl import ast, core

MAX_ARRAY_SIZE = 256
MAX_ARRAY_COUNT = 32
DEFAULT_ARRAY_COUNT = 3
DEFAULT_PAYLOAD_SIZE = 5


class BitSerializer:
    def __init__(self, big_endian: bool):
        self.stream = []
        self.value = 0
        self.shift = 0
        self.byteorder = "big" if big_endian else "little"

    def append(self, value: int, width: int):
        self.value = self.value | (value << self.shift)
        self.shift += width

        if (self.shift % 8) == 0:
            width = int(self.shift / 8)
            self.stream.extend(self.value.to_bytes(width, byteorder=self.byteorder))
            self.shift = 0
            self.value = 0


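# Illustrative sketch (not executed, shown for clarity): once appended values
# add up to a whole number of octets, BitSerializer flushes them to the byte
# stream in the configured byte order.
#
#   s = BitSerializer(big_endian=False)
#   s.append(0xA, 4)
#   s.append(0xB, 4)    # 4 + 4 = 8 bits, flushes one octet: s.stream == [0xBA]
#   s.append(0xCD, 8)   # flushes immediately: s.stream == [0xBA, 0xCD]
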
class Value:
    def __init__(self, value: object, width: Optional[int] = None):
        self.value = value
        if width is not None:
            self.width = width
        elif isinstance(value, int) or callable(value):
            raise Exception("Creating scalar value of unspecified width")
        elif isinstance(value, list):
            self.width = sum([v.width for v in value])
        elif isinstance(value, Packet):
            self.width = value.width
        else:
            raise Exception(f"Malformed value {value}")

    def finalize(self, parent: "Packet"):
        if callable(self.width):
            self.width = self.width(parent)

        if callable(self.value):
            self.value = self.value(parent)
        elif isinstance(self.value, list):
            for v in self.value:
                v.finalize(parent)
        elif isinstance(self.value, Packet):
            self.value.finalize()

    def serialize_(self, serializer: BitSerializer):
        if isinstance(self.value, int):
            serializer.append(self.value, self.width)
        elif isinstance(self.value, list):
            for v in self.value:
                v.serialize_(serializer)
        elif isinstance(self.value, Packet):
            self.value.serialize_(serializer)
        elif self.value is None:
            pass
        else:
            raise Exception(f"Malformed value {self.value}")

    def show(self, indent: int = 0):
        space = " " * indent
        if isinstance(self.value, int):
            print(f"{space}{self.name}: {hex(self.value)}")
        elif isinstance(self.value, list):
            print(f"{space}{self.name}[{len(self.value)}]:")
            for v in self.value:
                v.show(indent + 2)
        elif isinstance(self.value, Packet):
            print(f"{space}{self.name}:")
            self.value.show(indent + 2)

    def to_json(self) -> object:
        if isinstance(self.value, int):
            return self.value
        elif isinstance(self.value, list):
            return [v.to_json() for v in self.value]
        elif isinstance(self.value, Packet):
            return self.value.to_json()


class Field:
    def __init__(self, value: Value, ref: ast.Field):
        self.value = value
        self.ref = ref

    def finalize(self, parent: "Packet"):
        self.value.finalize(parent)

    def serialize_(self, serializer: BitSerializer):
        self.value.serialize_(serializer)

    def clone(self):
        return Field(copy.copy(self.value), self.ref)


class Packet:
    def __init__(self, fields: List[Field], ref: ast.Declaration):
        self.fields = fields
        self.ref = ref

    def finalize(self, parent: Optional["Packet"] = None):
        for f in self.fields:
            f.finalize(self)

    def serialize_(self, serializer: BitSerializer):
        for f in self.fields:
            f.serialize_(serializer)

    def serialize(self, big_endian: bool) -> bytes:
        serializer = BitSerializer(big_endian)
        self.serialize_(serializer)
        if serializer.shift != 0:
            raise Exception("The packet size is not an integral number of octets")
        return bytes(serializer.stream)

    def show(self, indent: int = 0):
        for f in self.fields:
            f.value.show(indent)

    def to_json(self) -> dict:
        result = dict()
        for f in self.fields:
            if isinstance(f.ref, (ast.PayloadField, ast.BodyField)) and isinstance(
                f.value.value, Packet
            ):
                result.update(f.value.to_json())
            elif isinstance(f.ref, (ast.PayloadField, ast.BodyField)):
                result["payload"] = f.value.to_json()
            elif hasattr(f.ref, "id"):
                result[f.ref.id] = f.value.to_json()
        return result

    @property
    def width(self) -> int:
        self.finalize()
        return sum([f.value.width for f in self.fields])


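# Illustrative sketch (not executed): a Packet is a list of Fields wrapping
# Values. Deferred values (lambdas produced by the generate_*_field_values
# helpers below) are resolved by finalize() against the enclosing packet
# before serialization. The *_ref names are hypothetical ast nodes.
#
#   size = Field(Value(lambda p: 2, 8), size_field_ref)
#   data = Field(Value([Value(0x01, 8), Value(0x02, 8)]), array_field_ref)
#   pkt = Packet([size, data], packet_decl_ref)
#   pkt.finalize()                    # resolves the size lambda to the int 2
#   pkt.serialize(big_endian=True)    # -> b'\x02\x01\x02'
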
class BitGenerator:
    def __init__(self):
        self.value = 0
        self.shift = 0

    def generate(self, width: int) -> Value:
        """Generate an integer value of the selected width."""
        value = 0
        remains = width
        while remains > 0:
            w = min(8 - self.shift, remains)
            mask = (1 << w) - 1
            value = (value << w) | ((self.value >> self.shift) & mask)
            remains -= w
            self.shift += w
            if self.shift >= 8:
                self.shift = 0
                self.value = (self.value + 1) % 0xFF
        return Value(value, width)

    def generate_list(self, width: int, count: int) -> List[Value]:
        return [self.generate(width) for n in range(count)]


generator = BitGenerator()


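# Illustrative sketch (not executed): the shared generator emits a
# deterministic rolling byte pattern (0x00, 0x01, 0x02, ...), keeping the
# generated test vectors reproducible between runs.
#
#   g = BitGenerator()
#   g.generate(8).value    # == 0x00
#   g.generate(8).value    # == 0x01
#   g.generate(4).value    # == 0x2 (low nibble of the next pattern byte 0x02)
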
def generate_cond_field_values(field: ast.ScalarField) -> List[Value]:
    cond_value_present = field.cond_for.cond.value
    cond_value_absent = 0 if field.cond_for.cond.value != 0 else 1

    def get_cond_value(parent: Packet, field: ast.Field) -> int:
        for f in parent.fields:
            if f.ref is field:
                return cond_value_absent if f.value.value is None else cond_value_present

    return [Value(lambda p: get_cond_value(p, field.cond_for), field.width)]


def generate_size_field_values(field: ast.SizeField) -> List[Value]:
    def get_field_size(parent: Packet, field_id: str) -> int:
        for f in parent.fields:
            if (
                (field_id == "_payload_" and isinstance(f.ref, ast.PayloadField))
                or (field_id == "_body_" and isinstance(f.ref, ast.BodyField))
                or (getattr(f.ref, "id", None) == field_id)
            ):
                assert f.value.width % 8 == 0
                size_modifier = int(getattr(f.ref, "size_modifier", None) or 0)
                return int(f.value.width / 8) + size_modifier
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(lambda p: get_field_size(p, field.field_id), field.width)]


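# Illustrative sketch (not executed): size fields are emitted as deferred
# values so that they can be computed after the sized field has been chosen.
# For an 8-bit size field covering a 5-octet payload, the Value
#   Value(lambda p: get_field_size(p, "_payload_"), 8)
# resolves to 5 once Packet.finalize() runs on the enclosing packet.
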
def generate_count_field_values(field: ast.CountField) -> List[Value]:
    def get_array_count(parent: Packet, field_id: str) -> int:
        for f in parent.fields:
            if getattr(f.ref, "id", None) == field_id:
                assert isinstance(f.value.value, list)
                return len(f.value.value)
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(lambda p: get_array_count(p, field.field_id), field.width)]


def generate_checksum_field_values(field: ast.TypedefField) -> List[Value]:
    field_width = core.get_field_size(field)

    def basic_checksum(input: bytes, width: int):
        assert width == 8
        return sum(input) % 256

    def compute_checksum(parent: Packet, field_id: str) -> int:
        serializer = None
        for f in parent.fields:
            if isinstance(f.ref, ast.ChecksumField) and f.ref.field_id == field_id:
                serializer = BitSerializer(
                    f.ref.parent.file.endianness.value == "big_endian"
                )
            elif isinstance(f.ref, ast.TypedefField) and f.ref.id == field_id:
                return basic_checksum(serializer.stream, field_width)
            elif serializer:
                f.value.serialize_(serializer)
        raise Exception("malformed checksum")

    return [Value(lambda p: compute_checksum(p, field.id), field_width)]


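# Illustrative sketch (not executed): the only checksum implemented here is an
# 8-bit byte sum modulo 256, computed over the octets serialized between the
# checksum start marker and the checksum field itself.
#
#   basic_checksum([0x01, 0x02, 0xFF], 8)    # == (1 + 2 + 255) % 256 == 0x02
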
def generate_padding_field_values(field: ast.PaddingField) -> List[Value]:
    preceding_field_id = field.padded_field.id

    def get_padding(parent: Packet, field_id: str, width: int) -> int:
        for f in parent.fields:
            if (
                (field_id == "_payload_" and isinstance(f.ref, ast.PayloadField))
                or (field_id == "_body_" and isinstance(f.ref, ast.BodyField))
                or (getattr(f.ref, "id", None) == field_id)
            ):
                assert f.value.width % 8 == 0
                assert f.value.width <= width
                return width - f.value.width
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(0, lambda p: get_padding(p, preceding_field_id, 8 * field.size))]


def generate_payload_field_values(
    field: Union[ast.PayloadField, ast.BodyField]
) -> List[Value]:
    payload_size = core.get_payload_field_size(field)
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    # If the payload has a size field, generate an empty payload and
    # a payload of maximum size. If not, generate a payload of the default size.
    max_size = (1 << payload_size.width) - 1 if payload_size else DEFAULT_PAYLOAD_SIZE
    max_size -= size_modifier

    assert max_size > 0
    return [Value([]), Value(generator.generate_list(8, max_size))]


def generate_scalar_array_field_values(field: ast.ArrayField) -> List[Value]:
    if field.width % 8 != 0:
        raise Exception("Array element size is not a multiple of 8")

    array_size = core.get_array_field_size(field)
    element_width = int(field.width / 8)

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    # Apply the size modifiers.
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    # The element width is known, and the array element count is known
    # statically.
    if isinstance(array_size, int):
        return [Value(generator.generate_list(field.width, array_size))]

    # The element width is known, and the array element count is known
    # by count field.
    elif isinstance(array_size, ast.CountField):
        min_count = 0
        max_count = (1 << array_size.width) - 1
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    # The element width is known, and the array full size is known
    # by size field.
    elif isinstance(array_size, ast.SizeField):
        min_count = 0
        max_size = (1 << array_size.width) - 1 - size_modifier
        max_count = int(max_size / element_width)
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    # The element width is known, but the array size is unknown.
    # Generate two arrays: one empty and one including some possible element
    # values.
    else:
        return [
            Value([]),
            Value(generator.generate_list(field.width, DEFAULT_ARRAY_COUNT)),
        ]


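# Worked example (assuming no size modifier): for a scalar array with 16-bit
# elements and an 8-bit size field, max_size == (1 << 8) - 1 == 255 octets and
# max_count == int(255 / 2) == 127 elements, so the generated values are an
# empty array and a 127-element array.
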
def generate_typedef_array_field_values(field: ast.ArrayField) -> List[Value]:
    array_size = core.get_array_field_size(field)
    element_width = core.get_array_element_size(field)
    if element_width:
        if element_width % 8 != 0:
            raise Exception("Array element size is not a multiple of 8")
        element_width = int(element_width / 8)

    # Generate element values to use for the generation.
    type_decl = field.parent.file.typedef_scope[field.type_id]

    def generate_list(count: Optional[int]) -> List[Value]:
        """Generate an array of specified length.
        If the count is None, all possible array items are returned."""
        element_values = generate_typedef_values(type_decl)

        # Requested a variable count; send everything in one chunk.
        if count is None:
            return [Value(element_values)]
        # Have more items than the requested count.
        # Split the possible array values into multiple chunks.
        elif len(element_values) > count:
            # Add more elements in case of wrap-over.
            elements_count = len(element_values)
            element_values.extend(generate_typedef_values(type_decl))
            chunk_count = int((elements_count + count - 1) / count)
            return [
                Value(element_values[n * count : (n + 1) * count])
                for n in range(chunk_count)
            ]
        # Have fewer items than the requested count.
        # Generate additional items to fill the gap.
        else:
            chunk = element_values
            while len(chunk) < count:
                chunk.extend(generate_typedef_values(type_decl))
            return [Value(chunk[:count])]

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    # Apply the size modifier.
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    min_size = 0
    max_size = MAX_ARRAY_SIZE
    min_count = 0
    max_count = MAX_ARRAY_COUNT

    if field.padded_size:
        max_size = field.padded_size

    if isinstance(array_size, ast.SizeField):
        max_size = (1 << array_size.width) - 1 - size_modifier
        min_size = size_modifier
    elif isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
    elif isinstance(array_size, int):
        min_count = array_size
        max_count = array_size

    values = []
    chunk = []
    chunk_size = 0

    while not values:
        element_values = generate_typedef_values(type_decl)
        for element_value in element_values:
            element_size = int(element_value.width / 8)

            if len(chunk) >= max_count or chunk_size + element_size > max_size:
                assert len(chunk) >= min_count
                values.append(Value(chunk))
                chunk = []
                chunk_size = 0

            chunk.append(element_value)
            chunk_size += element_size

    if min_count == 0:
        values.append(Value([]))

    return values

    # NOTE: The branches below are unreachable; the chunking loop above
    # already returns a value for every case.
    # The element width is not known, but the array full octet size
    # is known by a size field. Generate two arrays, of minimal and maximum
    # size. All unused element values are packed into arrays of varying size.
    if element_width is None and isinstance(array_size, ast.SizeField):
        element_values = generate_typedef_values(type_decl)
        chunk = []
        chunk_size = 0
        values = [Value([])]
        for element_value in element_values:
            assert element_value.width % 8 == 0
            element_size = int(element_value.width / 8)
            if chunk_size + element_size > max_size:
                values.append(Value(chunk))
                chunk = []
            chunk.append(element_value)
            chunk_size += element_size
        if chunk:
            values.append(Value(chunk))
        return values

    # The element width is not known, but the array element count
    # is known statically or by count field. Generate two arrays:
    # of minimal and maximum length. All unused element values are packed into
    # arrays of varying count.
    elif element_width is None and isinstance(array_size, ast.CountField):
        return [Value([])] + generate_list(max_count)

    # The element width is not known, and the array element count is known
    # statically.
    elif element_width is None and isinstance(array_size, int):
        return generate_list(array_size)

    # Neither the count nor the size is known;
    # generate two arrays: one empty and one including all possible element
    # values.
    elif element_width is None:
        return [Value([])] + generate_list(None)

    # The element width is known, and the array element count is known
    # statically.
    elif isinstance(array_size, int):
        return generate_list(array_size)

    # The element width is known, and the array element count is known
    # by count field.
    elif isinstance(array_size, ast.CountField):
        return [Value([])] + generate_list(max_count)

    # The element width is known, and the array full size is known
    # by size field.
    elif isinstance(array_size, ast.SizeField):
        return [Value([])] + generate_list(max_count)

    # The element width is known, but the array size is unknown.
    # Generate two arrays: one empty and one including all possible element
    # values.
    else:
        return [Value([])] + generate_list(None)


def generate_array_field_values(field: ast.ArrayField) -> List[Value]:
    if field.width is not None:
        return generate_scalar_array_field_values(field)
    else:
        return generate_typedef_array_field_values(field)


def generate_typedef_field_values(
    field: ast.TypedefField, constraints: List[ast.Constraint]
) -> List[Value]:
    type_decl = field.parent.file.typedef_scope[field.type_id]

    # Check for constraint on enum field.
    if isinstance(type_decl, ast.EnumDeclaration):
        for c in constraints:
            if c.id == field.id:
                for tag in type_decl.tags:
                    if tag.id == c.tag_id:
                        return [Value(tag.value, type_decl.width)]
                raise Exception("undefined enum tag")

    # Checksum fields need to know the checksum range.
    if isinstance(type_decl, ast.ChecksumDeclaration):
        return generate_checksum_field_values(field)

    return generate_typedef_values(type_decl)


def generate_field_values(
    field: ast.Field, constraints: List[ast.Constraint], payload: Optional[List[Packet]]
) -> List[Value]:
    if field.cond_for:
        return generate_cond_field_values(field)

    elif isinstance(field, ast.ChecksumField):
        # Checksum fields are just markers.
        return [Value(0, 0)]

    elif isinstance(field, ast.PaddingField):
        return generate_padding_field_values(field)

    elif isinstance(field, ast.SizeField):
        return generate_size_field_values(field)

    elif isinstance(field, ast.CountField):
        return generate_count_field_values(field)

    elif isinstance(field, (ast.BodyField, ast.PayloadField)) and payload:
        return [Value(p) for p in payload]

    elif isinstance(field, (ast.BodyField, ast.PayloadField)):
        return generate_payload_field_values(field)

    elif isinstance(field, ast.FixedField) and field.enum_id:
        enum_decl = field.parent.file.typedef_scope[field.enum_id]
        for tag in enum_decl.tags:
            if tag.id == field.tag_id:
                return [Value(tag.value, enum_decl.width)]
        raise Exception("undefined enum tag")

    elif isinstance(field, ast.FixedField) and field.width:
        return [Value(field.value, field.width)]

    elif isinstance(field, ast.ReservedField):
        return [Value(0, field.width)]

    elif isinstance(field, ast.ArrayField):
        return generate_array_field_values(field)

    elif isinstance(field, ast.ScalarField):
        for c in constraints:
            if c.id == field.id:
                return [Value(c.value, field.width)]
        mask = (1 << field.width) - 1
        return [
            Value(0, field.width),
            Value(-1 & mask, field.width),
            generator.generate(field.width),
        ]

    elif isinstance(field, ast.TypedefField):
        return generate_typedef_field_values(field, constraints)

    else:
        raise Exception("unsupported field kind")


def generate_fields(
    decl: ast.Declaration,
    constraints: List[ast.Constraint],
    payload: Optional[List[Packet]],
) -> List[List[Field]]:
    fields = []
    for f in decl.fields:
        values = generate_field_values(f, constraints, payload)
        optional_none = [] if not f.cond else [Field(Value(None, 0), f)]
        fields.append(optional_none + [Field(v, f) for v in values])
    return fields


def generate_fields_recursive(
    scope: dict,
    decl: ast.Declaration,
    constraints: List[ast.Constraint] = [],
    payload: Optional[List[Packet]] = None,
) -> List[List[Field]]:
    fields = generate_fields(decl, constraints, payload)

    if not decl.parent_id:
        return fields

    packets = [Packet(fields, decl) for fields in product(fields)]
    parent_decl = scope[decl.parent_id]
    return generate_fields_recursive(
        scope, parent_decl, constraints + decl.constraints, payload=packets
    )


def generate_struct_values(decl: ast.StructDeclaration) -> List[Packet]:
    fields = generate_fields_recursive(decl.file.typedef_scope, decl)
    return [Packet(fields, decl) for fields in product(fields)]


def generate_packet_values(decl: ast.PacketDeclaration) -> List[Packet]:
    fields = generate_fields_recursive(decl.file.packet_scope, decl)
    return [Packet(fields, decl) for fields in product(fields)]


def generate_typedef_values(decl: ast.Declaration) -> List[Value]:
    if isinstance(decl, ast.EnumDeclaration):
        return [Value(t.value, decl.width) for t in decl.tags]

    elif isinstance(decl, ast.ChecksumDeclaration):
        raise Exception("ChecksumDeclaration handled in typedef field")

    elif isinstance(decl, ast.CustomFieldDeclaration):
        raise Exception("TODO custom field")

    elif isinstance(decl, ast.StructDeclaration):
        return [Value(p) for p in generate_struct_values(decl)]

    else:
        raise Exception("unsupported typedef declaration type")


def product(fields: List[List[Field]]) -> List[List[Field]]:
    """Perform a cartesian product of generated options for packet field values."""

    def aux(vec: List[List[Field]]) -> List[List[Field]]:
        if len(vec) == 0:
            return [[]]
        return [[item.clone()] + items for item in vec[0] for items in aux(vec[1:])]

    count = 1
    max_len = 0
    for f in fields:
        count *= len(f)
        max_len = max(max_len, len(f))

    # Limit products to 32 elements to prevent combinatorial
    # explosion.
    if count <= 32:
        return aux(fields)

    # If there are too many products, select samples that cover every field
    # value at least once.
    else:
        return [[f[idx % len(f)] for f in fields] for idx in range(0, max_len + 1)]


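# Illustrative sketch (not executed): with three fields offering 2, 3, and 2
# candidate values, the full product has 2 * 3 * 2 == 12 <= 32 combinations
# and is returned whole. With 4 * 4 * 4 == 64 combinations, only
# max_len + 1 == 5 sampled combinations are returned, cycling through each
# field's candidates so that every candidate value appears at least once.
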
def serialize_values(file: ast.File, values: List[Packet]) -> List[dict]:
    results = []
    for v in values:
        v.finalize()
        packed = v.serialize(file.endianness.value == "big_endian")
        result = {
            "packed": "".join([f"{b:02x}" for b in packed]),
            "unpacked": v.to_json(),
        }
        if v.ref.parent_id:
            result["packet"] = v.ref.id
        results.append(result)
    return results


def run(input: Path, packet: List[str]):
    with open(input) as f:
        file = ast.File.from_json(json.load(f))
    core.desugar(file)

    results = dict()
    for decl in file.packet_scope.values():
        if core.get_derived_packets(decl) or (packet and decl.id not in packet):
            continue

        try:
            values = generate_packet_values(decl)
            ancestor = core.get_packet_ancestor(decl)
            results[ancestor.id] = results.get(ancestor.id, []) + serialize_values(
                file, values
            )
        except Exception as exn:
            print(
                f"Skipping packet {decl.id}; cannot generate values: {exn}",
                file=sys.stderr,
            )

    results = [{"packet": k, "tests": v} for (k, v) in results.items()]
    json.dump(results, sys.stdout, indent=2)


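# Example invocation (illustrative; the script and file names are
# hypothetical). --input takes the PDL-JSON source produced by the PDL
# compiler, and --packet optionally restricts generation to a comma-separated
# list of packet identifiers:
#
#   python3 generate_test_vector.py --input grammar.json --packet TestPacket,OtherPacket
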
def main() -> int:
    """Generate test vectors for top-level PDL packets."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input", type=Path, required=True, help="Input PDL-JSON source"
    )
    parser.add_argument(
        "--packet",
        type=lambda x: x.split(","),
        required=False,
        action="extend",
        default=[],
        help="Select PDL packet to test",
    )
    return run(**vars(parser.parse_args()))


if __name__ == "__main__":
    sys.exit(main())