• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18from dataclasses import dataclass, field
19import json
20from pathlib import Path
21import sys
22from textwrap import dedent
23from typing import List, Tuple, Union, Optional
24
25from pdl import ast, core
26from pdl.utils import indent, to_pascal_case
27
28
def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type used to back a PDL type of `width` bits.

    The smallest of uint8_t, uint16_t, uint32_t and uint64_t that can hold
    `width` bits is selected.

    Raises:
        ValueError: if `width` is larger than 64, i.e. the PDL type does not
            fit in any non-extended C++ scalar type.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # Raise explicitly instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this function silently return None.
    raise ValueError(f'PDL scalar width {width} does not fit in a non-extended C++ scalar type')
36
37
def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of the parser unit tests for the selected packet.

    One gtest case is emitted per test vector. Each case wraps the `packed`
    bytes in a `pdl::packet::slice`, parses them with the chain of
    `<Packet>View::Create` calls from the root ancestor down to the tested
    packet, asserts `IsValid()`, then asserts the value of every member
    listed in `unpacked`.

    Args:
        parser_test_suite: Name of the gtest fixture the cases belong to.
        packet: Declaration the test vectors are grouped under.
        tests: Test vectors, accessed as dicts with the keys 'packed' (hex
            string), 'unpacked' (expected member values) and an optional
            'packet' naming the packet actually parsed (defaults to `packet`).

    Returns:
        The concatenated C++ source of the generated test cases.
    """

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        # Nest the View::Create calls so the root ancestor parses the raw
        # input and each derived view is created from its parent view.
        parent = parse_packet(decl.parent) if decl.parent else "input"
        return f"{decl.id}View::Create({parent})"

    def input_bytes(packed: str) -> List[str]:
        # Render the hex string as C++ byte literals, 16 bytes per line.
        data = bytes.fromhex(packed)
        return [' '.join(f'0x{b:x},' for b in data[i:i + 16]) for i in range(0, len(data), 16)]

    def get_field(decl: ast.Declaration, var: str, field_id: str) -> str:
        # Struct members are read from their backing data member; packet
        # view members are read through their generated getter.
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{field_id}_"
        return f"{var}.Get{to_pascal_case(field_id)}()"

    def check_members(decl: ast.Declaration, var: str, expected: object) -> List[str]:
        # Return the C++ statements asserting that the members of `var`
        # (an instance of `decl`) match the values of the `expected` dict.
        checks = []
        # Identifiers derived from `var` must not contain the brackets used
        # for array element access, e.g. `packet_a[0]` -> `packet_a_0`.
        sanitized_var = var.replace('[', '_').replace(']', '')
        for (field_id, value) in expected.items():
            member = core.get_packet_field(decl, field_id)
            field_var = f'{sanitized_var}_{field_id}'
            accessor = get_field(decl, var, field_id)

            if isinstance(member, ast.ScalarField):
                checks.append(f"ASSERT_EQ({accessor}, {value});")

            elif (isinstance(member, ast.TypedefField) and
                  isinstance(member.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                checks.append(f"ASSERT_EQ({accessor}, {member.type_id}({value}));")

            elif isinstance(member, ast.TypedefField):
                # Struct member: bind it, then check its own members recursively.
                checks.append(f"{member.type_id} const& {field_var} = {accessor};")
                checks.extend(check_members(member.type, field_var, value))

            elif isinstance(member, (ast.PayloadField, ast.BodyField)):
                checks.append(f"std::vector<uint8_t> expected_{field_var} {{")
                for i in range(0, len(value), 16):
                    checks.append('    ' + ' '.join(f"0x{v:x}," for v in value[i:i + 16]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({accessor}, expected_{field_var});")

            elif isinstance(member, ast.ArrayField) and member.width:
                checks.append(f"std::vector<{get_cxx_scalar_type(member.width)}> expected_{field_var} {{")
                # Emit about 16 bytes worth of elements per line.
                step = 16 * 8 // member.width
                for i in range(0, len(value), step):
                    checks.append('    ' + ' '.join(f"0x{v:x}," for v in value[i:i + step]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({accessor}, expected_{field_var});")

            elif isinstance(member, ast.ArrayField) and isinstance(member.type, ast.EnumDeclaration):
                checks.append(f"std::vector<{member.type_id}> expected_{field_var} {{")
                for v in value:
                    checks.append(f"    {member.type_id}({v}),")
                checks.append("};")
                checks.append(f"ASSERT_EQ({accessor}, expected_{field_var});")

            elif isinstance(member, ast.ArrayField):
                # Array of structs: check the size, then each element recursively.
                checks.append(f"std::vector<{member.type_id}> {field_var} = {accessor};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, element) in enumerate(value):
                    checks.extend(check_members(member.type, f"{field_var}[{n}]", element))

            # Any other field kind, or an id not resolved by
            # core.get_packet_field, produces no check.

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        generated_tests.append(
            dedent("""\

            TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                    {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
            }}
            """).format(parser_test_suite=parser_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        input_bytes=indent(input_bytes(test['packed']), 2),
                        parse_packet=parse_packet(child_packet),
                        # The generated `packet` variable is a view of the
                        # child packet, so its members are resolved against
                        # the child declaration (as the serializer test does).
                        checks=indent(check_members(child_packet, 'packet', test['unpacked']), 1)))

    return ''.join(generated_tests)
134
135
def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of the serializer unit tests for the selected packet.

    One gtest case is emitted per test vector. Each case constructs the
    tested packet with its generated `<Packet>Builder` from the `unpacked`
    values and asserts that `Serialize()` yields the `packed` bytes.

    Args:
        serializer_test_suite: Name of the gtest fixture the cases belong to.
        packet: Declaration the test vectors are grouped under.
        tests: Test vectors, accessed as dicts with the keys 'packed' (hex
            string), 'unpacked' (member values) and an optional 'packet'
            naming the packet actually built (defaults to `packet`).

    Returns:
        The concatenated C++ source of the generated test cases.
    """

    def build_packet(decl: ast.Declaration, var: str, initializer: object) -> Tuple[str, List[str]]:
        # Return the C++ constructor expression for `decl` initialized from
        # the `initializer` dict, together with the local declarations
        # (byte vectors, element vectors) that expression refers to.
        members = core.get_unconstrained_parent_fields(decl) + decl.fields
        # Identifiers derived from `var` must not contain array brackets.
        sanitized_var = var.replace('[', '_').replace(']', '')
        declarations = []
        parameters = []
        for member in members:
            field_id = getattr(member, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            if isinstance(member, (ast.PayloadField, ast.BodyField)):
                # Payload and body bytes are read from the 'payload' key.
                value = initializer['payload']
            else:
                value = initializer.get(field_id, None)

            if isinstance(member, ast.ScalarField):
                parameters.append(f"{value}")

            elif (isinstance(member, ast.TypedefField) and
                  isinstance(member.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration))):
                # Enum and custom field values are constructed from their
                # integer value, mirroring the parser checks. The recursive
                # branch below reads `decl.fields`, which applies to struct
                # declarations, not to custom field declarations.
                parameters.append(f"{member.type_id}({value})")

            elif isinstance(member, ast.TypedefField):
                (element, intermediate_declarations) = build_packet(member.type, field_var, value)
                declarations.extend(intermediate_declarations)
                parameters.append(element)

            elif isinstance(member, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {field_var} {{")
                for i in range(0, len(value), 16):
                    declarations.append('    ' + ' '.join(f"0x{v:x}," for v in value[i:i + 16]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(member, ast.ArrayField) and member.width:
                declarations.append(f"std::vector<{get_cxx_scalar_type(member.width)}> {field_var} {{")
                # Emit about 16 bytes worth of elements per line.
                step = 16 * 8 // member.width
                for i in range(0, len(value), step):
                    declarations.append('    ' + ' '.join(f"0x{v:x}," for v in value[i:i + step]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(member, ast.ArrayField) and isinstance(member.type, ast.EnumDeclaration):
                declarations.append(f"std::vector<{member.type_id}> {field_var} {{")
                for v in value:
                    declarations.append(f"    {member.type_id}({v}),")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(member, ast.ArrayField):
                # Array of structs: build every element, then collect them.
                elements = []
                for (n, element_value) in enumerate(value):
                    (element, intermediate_declarations) = build_packet(member.type, f'{field_var}_{n}',
                                                                        element_value)
                    elements.append(element)
                    declarations.extend(intermediate_declarations)
                declarations.append(f"std::vector<{member.type_id}> {field_var} {{")
                for element in elements:
                    declarations.append(f"    {element},")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            # Any other field kind contributes no constructor parameter.

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    def output_bytes(packed: str) -> List[str]:
        # Render the hex string as C++ byte literals, 16 bytes per line.
        data = bytes.fromhex(packed)
        return [' '.join(f'0x{b:x},' for b in data[i:i + 16]) for i in range(0, len(data), 16)]

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

            TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                    {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.pdl::packet::Builder::Serialize(), expected_output);
            }}
            """).format(serializer_test_suite=serializer_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        output_bytes=indent(output_bytes(test['packed']), 2),
                        built_packet=built_packet,
                        intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)
235
236
def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write the C++ parser and serializer unit tests for a PDL file.

    `input` is read as PDL-JSON and desugared; `test_vectors` is read as a
    JSON list of groups, each with a 'packet' id and a list of 'tests'.
    For every packet declaration that has test vectors (and is not in the
    skip list), parser then serializer test cases are written to `output`,
    wrapped in `namespace` after a preamble with the extra includes and
    `using namespace` statements.
    """
    pdl_file = ast.File.from_json(json.load(input))
    test_groups = json.load(test_vectors)
    core.desugar(pdl_file)

    includes = '\n'.join(f'#include <{header}>' for header in include_header)
    usings = '\n'.join(f'using namespace {ns};' for ns in using_namespace)

    # Declarations whose test vectors are not turned into C++ tests.
    skipped = {
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
    }

    preamble = dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        //  {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """).format(input_name=input.name,
                    test_vectors_name=test_vectors.name,
                    input_command=' '.join(sys.argv),
                    include_header=includes,
                    using_namespace=usings,
                    namespace=namespace,
                    parser_test_suite=parser_test_suite,
                    serializer_test_suite=serializer_test_suite)
    output.write(preamble)

    for decl in pdl_file.declarations:
        if decl.id in skipped or not isinstance(decl, ast.PacketDeclaration):
            continue

        # Gather the cases of every group targeting this packet, in order.
        cases = [case for group in test_groups if group['packet'] == decl.id for case in group['tests']]
        if not cases:
            continue

        output.write(generate_packet_parser_test(parser_test_suite, decl, cases))
        output.write(generate_packet_serializer_test(serializer_test_suite, decl, cases))

    output.write(f"}}  // namespace {namespace}\n")
295
296
def main() -> int:
    """Parse the command line and generate the C++ PDL backend unit tests.

    Returns:
        The process exit status (0 on success).
    """
    # The module has no docstring, so `description=__doc__` left the help
    # description empty; state it explicitly instead.
    parser = argparse.ArgumentParser(
        description='Generate C++ unit tests for the PDL C++ backend from PDL test vectors.')
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    run(**vars(parser.parse_args()))
    # run() returns None; return an explicit status matching the `-> int`
    # annotation (sys.exit(None) and sys.exit(0) both exit with status 0).
    return 0


if __name__ == '__main__':
    sys.exit(main())
320