1#!/usr/bin/env python3 2 3# Copyright 2023 Google LLC 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# https://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17# Tests the generated python backend against standard PDL 18# constructs, with matching input vectors. 19 20import dataclasses 21import enum 22import json 23import typing 24import unittest 25from importlib import resources 26 27# (le|be)_backend are the names of the modules generated from the canonical 28# little endian and big endian test grammars. The purpose of this module 29# is to validate the generated parsers against the set of pre-generated 30# test vectors in canonical/(le|be)_test_vectors.json. 
import le_backend
import be_backend


SKIPPED_TESTS = [
    "Packet_Array_Field_VariableElementSize_ConstantSize",
    "Packet_Array_Field_VariableElementSize_VariableSize",
    "Packet_Array_Field_VariableElementSize_VariableCount",
    "Packet_Array_Field_VariableElementSize_UnknownSize",
]


def match_object(self, left, right):
    """Recursively match a python class object against a reference
    json object.

    Args:
        self: the running unittest.TestCase, used for its assertion methods.
        left: the value produced by the generated backend (parsed packet,
            struct, list or scalar).
        right: the reference value decoded from the json test vector.

    Dicts are matched field by field: every key of `right` must exist as an
    attribute of `left` (attributes of `left` absent from `right` are not
    checked). Lists are matched element by element after a length check.
    A json null requires `left` to be None; any other scalar (int, bool,
    string, ...) must compare equal, so unexpected value kinds can no
    longer pass silently.
    """
    if isinstance(right, dict):
        for key, expected in right.items():
            self.assertTrue(
                hasattr(left, key),
                f"{type(left).__name__} has no field {key!r}")
            match_object(self, getattr(left, key), expected)
    elif isinstance(right, list):
        self.assertEqual(len(left), len(right))
        for left_item, right_item in zip(left, right):
            match_object(self, left_item, right_item)
    elif right is None:
        self.assertIsNone(left)
    else:
        self.assertEqual(left, right)


def create_object(typ, value):
    """Build an object of the selected type using the input value.

    Args:
        typ: the type annotation to instantiate: a dataclass from the
            generated backend, `List[T]`/`list[T]`, `Optional[T]`, `bytes`,
            `bytearray`, an enum class providing a `from_int` constructor,
            or `int`.
        value: the json-decoded value (dict for dataclasses, list for lists,
            int for scalars, None for an absent optional value).

    Returns:
        An instance of `typ` built recursively from `value`.

    Raises:
        Exception: if `typ` is not one of the supported annotations.
    """
    if dataclasses.is_dataclass(typ):
        field_types = {f.name: f.type for f in dataclasses.fields(typ)}
        return typ(**{name: create_object(field_types[name], v)
                      for name, v in value.items()})

    origin = typing.get_origin(typ)
    if origin is list:
        element_type = typing.get_args(typ)[0]
        return [create_object(element_type, v) for v in value]
    if origin is typing.Union:
        # typing.Optional[int] expands to typing.Union[int, None]; build the
        # non-None alternative, whatever its position in the union.
        if value is None:
            return None
        inner = next(arg for arg in typing.get_args(typ)
                     if arg is not type(None))
        return create_object(inner, value)
    if typ is bytes:
        return bytes(value)
    if typ is bytearray:
        return bytearray(value)
    # Guard issubclass: it raises TypeError for annotations that are not
    # classes (e.g. string forward references); those should reach the
    # explicit "unsupported" error below instead.
    if isinstance(typ, type) and issubclass(typ, enum.Enum):
        return typ.from_int(value)
    if typ is int:
        return value
    raise Exception(f"unsupported type annotation {typ}")


def _load_test_vectors(filename):
    """Load and decode a json test vector file from tests.canonical.

    The document is a list of items, each holding a 'packet' name and a
    list of 'tests' input vectors.
    """
    # Reading raw bytes lets json detect the encoding itself instead of
    # depending on the platform's default text encoding.
    data = resources.files('tests.canonical').joinpath(filename).read_bytes()
    return json.loads(data)


def _iter_test_vectors(filename):
    """Yield (packet, tests) pairs from a test vector file.

    'packet' is the name of the packet being tested, 'tests' lists the
    input vectors that must match the selected packet. Packets listed in
    SKIPPED_TESTS are omitted.
    """
    for item in _load_test_vectors(filename):
        packet = item['packet']
        if packet in SKIPPED_TESTS:
            continue
        yield packet, item['tests']


class PacketParserTest(unittest.TestCase):
    """Validate the generated parser against pre-generated test
    vectors in canonical/(le|be)_test_vectors.json"""

    def _check_parser(self, backend, filename):
        """Parse every 'packed' vector of `filename` with `backend` and
        match the result against the 'unpacked' reference."""
        for packet, tests in _iter_test_vectors(filename):
            with self.subTest(packet=packet):
                # Retrieve the class object from the generated module, in
                # order to invoke the proper parse method for this test.
                cls = getattr(backend, packet)
                for index, test in enumerate(tests):
                    # One sub-test per vector so a failure does not hide
                    # the remaining vectors of the same packet.
                    with self.subTest(vector=index):
                        result = cls.parse_all(bytes.fromhex(test['packed']))
                        match_object(self, result, test['unpacked'])

    def testLittleEndian(self):
        self._check_parser(le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        self._check_parser(be_backend, 'be_test_vectors.json')


class PacketSerializerTest(unittest.TestCase):
    """Validate the generated serializer against pre-generated test
    vectors in canonical/(le|be)_test_vectors.json"""

    def _check_serializer(self, backend, filename):
        """Build every 'unpacked' vector of `filename` with `backend` and
        compare its serialization with the 'packed' reference."""
        for packet, tests in _iter_test_vectors(filename):
            with self.subTest(packet=packet):
                for index, test in enumerate(tests):
                    with self.subTest(vector=index):
                        # Retrieve the class object from the generated
                        # module, in order to invoke the proper constructor
                        # for this test. A vector may name its own packet,
                        # overriding the item-level one.
                        cls = getattr(backend, test.get('packet', packet))
                        obj = create_object(cls, test['unpacked'])
                        result = obj.serialize()
                        self.assertEqual(result, bytes.fromhex(test['packed']))

    def testLittleEndian(self):
        self._check_serializer(le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        self._check_serializer(be_backend, 'be_test_vectors.json')


class CustomPacketParserTest(unittest.TestCase):
    """Manual testing for custom fields."""

    def testCustomField(self):
        # The same custom field declarations are checked in both backends.
        # Input is given as bytes, the same type the vector-driven parser
        # tests in this file pass to parse_all.
        span = bytes([1])
        for backend in (le_backend, be_backend):
            with self.subTest(backend=backend.__name__):
                result = backend.Packet_Custom_Field_ConstantSize.parse_all(span)
                self.assertEqual(result.a.value, 1)

                result = backend.Packet_Custom_Field_VariableSize.parse_all(span)
                self.assertEqual(result.a.value, 1)

                result = backend.Struct_Custom_Field_ConstantSize.parse_all(span)
                self.assertEqual(result.s.a.value, 1)

                result = backend.Struct_Custom_Field_VariableSize.parse_all(span)
                self.assertEqual(result.s.a.value, 1)


if __name__ == '__main__':
    unittest.main(verbosity=3)