• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Tests the generated python backend against standard PDL
18# constructs, with matching input vectors.
19
20import dataclasses
21import enum
22import json
23import typing
24import unittest
25from importlib import resources
26
27# (le|be)_backend are the names of the modules generated from the canonical
28# little endian and big endian test grammars. The purpose of this module
29# is to validate the generated parsers against the set of pre-generated
30# test vectors in canonical/(le|be)_test_vectors.json.
31import le_backend
32import be_backend
33
34
# Packets whose canonical test vectors are skipped by the parser and
# serializer test-vector tests. NOTE(review): all four exercise arrays with
# variable element size, presumably not yet supported by the python
# backend — confirm before removing entries.
SKIPPED_TESTS = """
    Packet_Array_Field_VariableElementSize_ConstantSize
    Packet_Array_Field_VariableElementSize_VariableSize
    Packet_Array_Field_VariableElementSize_VariableCount
    Packet_Array_Field_VariableElementSize_UnknownSize
""".split()
41
42
def match_object(self, left, right):
    """Recursively match a python class object against a reference
    json object.

    `self` is the running test case; only its assertEqual/assertTrue
    methods are used. `left` is the value produced by the generated
    backend and `right` the decoded JSON reference:

    - list: `left` must have the same length and match element-wise;
    - dict: `left` must expose every key as an attribute whose value
      matches recursively (extra attributes on `left` are not checked);
    - any other JSON value (int, bool, null, str, float): must compare
      equal to `left`.

    Fails through `self`'s assertion methods on the first mismatch.
    """
    if isinstance(right, list):
        self.assertEqual(len(left), len(right))
        for left_item, right_item in zip(left, right):
            match_object(self, left_item, right_item)
    elif isinstance(right, dict):
        for key, expected in right.items():
            self.assertTrue(hasattr(left, key),
                            f"{type(left).__name__} has no field '{key}'")
            match_object(self, getattr(left, key), expected)
    else:
        # Every remaining JSON scalar is compared; letting e.g. a null
        # reference fall through unchecked would make mismatches pass.
        self.assertEqual(left, right)
56
57
def create_object(typ, value):
    """Build an object of the selected type using the input value.

    `typ` is a type annotation (typically a field type of a generated
    backend class) and `value` the decoded JSON describing it:

    - dataclass: `value` is a dict of field name -> field value; each field
      is built recursively from its declared annotation;
    - list[T]: `value` is a list; each element is built as T;
    - Optional[T] / Union[T, None] / T | None: None stays None, anything
      else is built as the first non-None member of the union;
    - bytes / bytearray: built from `value` (e.g. a list of ints);
    - enum.Enum subclass: converted with the class's `from_int` method;
    - int: returned unchanged.

    Raises KeyError if `value` names a field the dataclass does not
    declare, and TypeError for any other annotation.
    """
    # Local import: only needed to recognise PEP 604 `X | None` unions.
    import types

    if dataclasses.is_dataclass(typ):
        field_types = {f.name: f.type for f in dataclasses.fields(typ)}
        values = {name: create_object(field_types[name], v)
                  for name, v in value.items()}
        return typ(**values)

    origin = typing.get_origin(typ)
    if origin is list:
        item_type = typing.get_args(typ)[0]
        return [create_object(item_type, v) for v in value]

    # types.UnionType only exists on Python >= 3.10.
    union_origins = (typing.Union, getattr(types, 'UnionType', typing.Union))
    if origin in union_origins:
        if value is None:
            return None
        # Skip NoneType so both Optional[T] and Union[None, T] build a T.
        member = next(t for t in typing.get_args(typ) if t is not type(None))
        return create_object(member, value)

    if typ is bytes:
        return bytes(value)
    if typ is bytearray:
        return bytearray(value)
    # issubclass() raises on non-class annotations such as typing.Tuple[int];
    # guard so those are reported as unsupported below instead.
    if isinstance(typ, type) and issubclass(typ, enum.Enum):
        return typ.from_int(value)
    if typ is int:
        return value
    raise TypeError(f"unsupported type annotation {typ}")
85
86
class PacketParserTest(unittest.TestCase):
    """Validate the generated parser against pre-generated test
       vectors in canonical/(le|be)_test_vectors.json"""

    def _check_vectors(self, backend, vectors_file):
        """Parse every vector in `vectors_file` and match the result.

        backend: generated module (le_backend or be_backend) providing one
            class per tested packet.
        vectors_file: name of a JSON file inside the `tests.canonical`
            package.
        """
        vectors = resources.files('tests.canonical').joinpath(vectors_file)
        with vectors.open('r', encoding='utf-8') as f:
            reference = json.load(f)

        for item in reference:
            # 'packet' is the name of the packet being tested,
            # 'tests' lists input vectors that must match the
            # selected packet.
            packet = item['packet']
            if packet in SKIPPED_TESTS:
                continue

            with self.subTest(packet=packet):
                # Retrieve the class object from the generated module in
                # order to invoke its parse method for this test.
                cls = getattr(backend, packet)
                for test in item['tests']:
                    result = cls.parse_all(bytes.fromhex(test['packed']))
                    match_object(self, result, test['unpacked'])

    def testLittleEndian(self):
        self._check_vectors(le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        self._check_vectors(be_backend, 'be_test_vectors.json')
136
137
class PacketSerializerTest(unittest.TestCase):
    """Validate the generated serializer against pre-generated test
       vectors in canonical/(le|be)_test_vectors.json"""

    def _check_vectors(self, backend, vectors_file):
        """Build and serialize every vector in `vectors_file`, then compare.

        backend: generated module (le_backend or be_backend) providing one
            class per tested packet.
        vectors_file: name of a JSON file inside the `tests.canonical`
            package.
        """
        vectors = resources.files('tests.canonical').joinpath(vectors_file)
        with vectors.open('r', encoding='utf-8') as f:
            reference = json.load(f)

        for item in reference:
            # 'packet' is the name of the packet being tested,
            # 'tests' lists input vectors that must match the
            # selected packet.
            packet = item['packet']
            if packet in SKIPPED_TESTS:
                continue

            with self.subTest(packet=packet):
                for test in item['tests']:
                    # A vector may name its own class in a 'packet' key;
                    # otherwise the item's packet is used. The class object
                    # is retrieved from the generated module to invoke its
                    # constructor.
                    cls = getattr(backend, test.get('packet', packet))
                    obj = create_object(cls, test['unpacked'])
                    result = obj.serialize()
                    self.assertEqual(result, bytes.fromhex(test['packed']))

    def testLittleEndian(self):
        self._check_vectors(le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        self._check_vectors(be_backend, 'be_test_vectors.json')
189
190
class CustomPacketParserTest(unittest.TestCase):
    """Hand-written parse checks for packets and structs with custom fields."""

    def testCustomField(self):
        # Each case names a generated class and the attribute path leading
        # from the parsed result to the custom field whose `.value` is read.
        cases = (
            ('Packet_Custom_Field_ConstantSize', ('a',)),
            ('Packet_Custom_Field_VariableSize', ('a',)),
            ('Struct_Custom_Field_ConstantSize', ('s', 'a')),
            ('Struct_Custom_Field_VariableSize', ('s', 'a')),
        )
        # Little endian first, then big endian, parsing the one-byte input
        # [1] with every class.
        for backend in (le_backend, be_backend):
            for class_name, path in cases:
                node = getattr(backend, class_name).parse_all([1])
                for attr in path:
                    node = getattr(node, attr)
                self.assertEqual(node.value, 1)
218
219
def _main():
    """Run every test case in this module with verbose output when the
    file is executed directly."""
    unittest.main(verbosity=3)


if __name__ == '__main__':
    _main()
222