#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate C++ (gtest) parser and serializer unit tests for the PDL C++ backend from PDL test vectors."""

import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional, Sequence, TextIO

from pdl import ast, core
from pdl.utils import indent, to_pascal_case


def get_cxx_scalar_type(width: int) -> str:
    """Return the C++ unsigned scalar type used to back a PDL type of `width` bits.

    Raises:
        ValueError: if `width` is larger than 64 bits.
    """
    for n in [8, 16, 32, 64]:
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types. Raise explicitly:
    # an `assert` is stripped under `python -O`, which would make this function
    # return None and emit `std::vector<None>` in the generated code.
    raise ValueError(f'PDL type of width {width} does not fit in a 64-bit scalar type')


def _hex_rows(values: Sequence[int], per_row: int) -> List[str]:
    """Format `values` as rows of C++ hex initializer tokens (`0x1, 0x2, ...`).

    Each row holds at most `per_row` values and every token keeps its trailing
    comma, so rows can be pasted directly inside a brace initializer.
    """
    return [' '.join(f'0x{v:x},' for v in values[i:i + per_row]) for i in range(0, len(values), per_row)]


def _vector_initializer(cxx_type: str, name: str, values: Sequence[int], per_row: int) -> List[str]:
    """Return the lines of a `std::vector<cxx_type> name { 0x.., ... };` declaration."""
    lines = [f"std::vector<{cxx_type}> {name} {{"]
    lines.extend('    ' + row for row in _hex_rows(values, per_row))
    lines.append("};")
    return lines


def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[dict]) -> str:
    """Generate the parser unit tests for the selected packet.

    Each test vector's `packed` bytes are decoded through the chain of generated
    `<Packet>View::Create` calls, and the accessors are compared against the
    vector's `unpacked` members.

    Args:
        parser_test_suite: Name of the gtest fixture used in `TEST_F`.
        packet: Declaration the test vectors are attached to.
        tests: Test vectors; each has `packed` (hex string), `unpacked`
            (member values) and optionally `packet` (identifier of the packet
            to parse as, defaulting to `packet.id`).

    Returns:
        The concatenated C++ source of the generated `TEST_F` cases.
    """

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        # Nest View::Create calls from the root ancestor down to `decl`,
        # starting from the `input` slice.
        parent = parse_packet(decl.parent) if decl.parent else "input"
        return f"{decl.id}View::Create({parent})"

    def get_field(decl: ast.Declaration, var: str, field_id: str) -> str:
        # Structs expose members as `<id>_` data members; packet views expose getters.
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{field_id}_"
        return f"{var}.Get{to_pascal_case(field_id)}()"

    def check_members(decl: ast.Declaration, var: str, expected: dict) -> List[str]:
        # Emit one or more C++ assertions per expected member of `decl`.
        checks = []
        for (field_id, value) in expected.items():
            field = core.get_packet_field(decl, field_id)
            # `var` may be an indexed expression such as `packet_items[0]`;
            # derive a valid C++ identifier from it for local variables.
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_var = f'{sanitized_var}_{field_id}'

            if isinstance(field, ast.ScalarField):
                checks.append(f"ASSERT_EQ({get_field(decl, var, field_id)}, {value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                checks.append(f"ASSERT_EQ({get_field(decl, var, field_id)}, {field.type_id}({value}));")

            elif isinstance(field, ast.TypedefField):
                # Struct member: bind it to a local and recurse into its members.
                checks.append(f"{field.type_id} const& {field_var} = {get_field(decl, var, field_id)};")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.extend(_vector_initializer('uint8_t', f'expected_{field_var}', value, 16))
                checks.append(f"ASSERT_EQ({get_field(decl, var, field_id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.width:
                # Scalar array: print roughly 128 bits worth of elements per row.
                checks.extend(
                    _vector_initializer(get_cxx_scalar_type(field.width), f'expected_{field_var}', value,
                                        (16 * 8) // field.width))
                checks.append(f"ASSERT_EQ({get_field(decl, var, field_id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                checks.append(f"std::vector<{field.type_id}> expected_{field_var} {{")
                checks.extend(f"    {field.type_id}({v})," for v in value)
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, field_id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: check the size, then each element recursively.
                checks.append(f"std::vector<{field.type_id}> {field_var} = {get_field(decl, var, field_id)};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, element) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", element))

            # Any other case (including a member `get_packet_field` did not
            # resolve) emits no assertion.

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        # NOTE(review): members are resolved against `packet`, not `child_packet`,
        # although the C++ variable is a `<child_packet_id>View`. If
        # `core.get_packet_field` only searches a declaration and its ancestors,
        # members declared on the child are silently skipped — confirm intended.
        checks = check_members(packet, 'packet', test['unpacked'])

        generated_tests.append(
            dedent("""\

                TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                    pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                        {input_bytes}
                    }}));
                    {child_packet_id}View packet = {parse_packet};
                    ASSERT_TRUE(packet.IsValid());
                    {checks}
                }}
                """).format(parser_test_suite=parser_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            input_bytes=indent(_hex_rows(bytes.fromhex(test['packed']), 16), 2),
                            parse_packet=parse_packet(child_packet),
                            checks=indent(checks, 1)))

    return ''.join(generated_tests)


def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[dict]) -> str:
    """Generate the serializer unit tests for the selected packet.

    Each test vector's `unpacked` members are passed to the generated
    `<Packet>Builder` constructor, and the serialized bytes are compared
    against the vector's `packed` hex string.

    Args:
        serializer_test_suite: Name of the gtest fixture used in `TEST_F`.
        packet: Declaration the test vectors are attached to.
        tests: Test vectors; each has `packed` (hex string), `unpacked`
            (member values) and optionally `packet` (identifier of the packet
            to build, defaulting to `packet.id`).

    Returns:
        The concatenated C++ source of the generated `TEST_F` cases.
    """

    def build_packet(decl: ast.Declaration, var: str, initializer: dict) -> Tuple[str, List[str]]:
        """Return the C++ constructor expression for `decl` and the local declarations it needs.

        Vector-valued members are declared as named locals (second element of
        the result) and moved into the constructor call; struct members are
        built recursively.

        Raises:
            KeyError: if `decl` has a payload/body field and `initializer` has no `payload`.
        """
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        # `var` may carry array indices; derive a valid C++ identifier prefix.
        sanitized_var = var.replace('[', '_').replace(']', '')
        for field in fields:
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            if isinstance(field, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(field_id, None)

            if isinstance(field, ast.ScalarField):
                parameters.append(f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                (element, nested_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(nested_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.extend(_vector_initializer('uint8_t', field_var, value, 16))
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.width:
                # Scalar array: print roughly 128 bits worth of elements per row.
                declarations.extend(
                    _vector_initializer(get_cxx_scalar_type(field.width), field_var, value, (16 * 8) // field.width))
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                declarations.extend(f"    {field.type_id}({v})," for v in value)
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: build each element first, then the vector.
                elements = []
                for (n, element_value) in enumerate(value):
                    (element, nested_declarations) = build_packet(field.type, f'{field_var}_{n}', element_value)
                    elements.append(element)
                    declarations.extend(nested_declarations)
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                declarations.extend(f"    {element}," for element in elements)
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            # Any other field kind contributes no constructor parameter.

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

                TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                    std::vector<uint8_t> expected_output {{
                        {output_bytes}
                    }};
                    {intermediate_declarations}
                    {child_packet_id}Builder packet = {built_packet};
                    ASSERT_EQ(packet.pdl::packet::Builder::Serialize(), expected_output);
                }}
                """).format(serializer_test_suite=serializer_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            output_bytes=indent(_hex_rows(bytes.fromhex(test['packed']), 16), 2),
                            built_packet=built_packet,
                            intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)


def run(input: TextIO, output: TextIO, test_vectors: TextIO, include_header: List[str], using_namespace: List[str],
        namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write the C++ test file for every packet declaration that has test vectors.

    Args:
        input: PDL-JSON source, loaded with `ast.File.from_json` and desugared.
        output: Destination of the generated C++ source.
        test_vectors: JSON list of `{"packet": <id>, "tests": [...]}` groups.
        include_header: Headers emitted as `#include <...>` directives.
        using_namespace: Namespaces emitted as `using namespace ...;` statements.
        namespace: Namespace wrapping the generated tests.
        parser_test_suite: Name of the generated parser gtest fixture.
        serializer_test_suite: Name of the generated serializer gtest fixture.
    """
    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_directives = '\n'.join([f'#include <{header}>' for header in include_header])
    using_directives = '\n'.join([f'using namespace {ns};' for ns in using_namespace])

    # Declarations for which no tests are generated (reason not recorded here).
    skipped_tests = frozenset([
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
    ])

    output.write(
        dedent("""\
            // File generated from {input_name} and {test_vectors_name}, with the command:
            // {input_command}
            // /!\\ Do not edit by hand

            #include <cstdint>
            #include <string>
            #include <gtest/gtest.h>
            #include <packet_runtime.h>

            {include_header}
            {using_namespace}

            namespace {namespace} {{

            class {parser_test_suite} : public testing::Test {{}};
            class {serializer_test_suite} : public testing::Test {{}};
            """).format(parser_test_suite=parser_test_suite,
                        serializer_test_suite=serializer_test_suite,
                        input_name=input.name,
                        input_command=' '.join(sys.argv),
                        test_vectors_name=test_vectors.name,
                        include_header=include_directives,
                        using_namespace=using_directives,
                        namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests:
            continue

        if isinstance(decl, ast.PacketDeclaration):
            # Flatten the tests of every vector group attached to this packet.
            matching_tests = [test for group in tests if group['packet'] == decl.id for test in group['tests']]
            if matching_tests:
                output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
                output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}} // namespace {namespace}\n")


def main() -> int:
    """Parse the command line and generate the C++ backend tests; return the exit status."""
    # `__doc__` here is the module docstring, used as the --help description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    run(**vars(parser.parse_args()))
    return 0


if __name__ == '__main__':
    sys.exit(main())