• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tools for generating Pigweed tests that execute in C++ and Python."""
15
16import argparse
17from dataclasses import dataclass
18from datetime import datetime
19from collections import defaultdict
20import unittest
21
22from typing import (
23    Any,
24    Callable,
25    Dict,
26    Generic,
27    Iterable,
28    Iterator,
29    List,
30    Sequence,
31    TextIO,
32    TypeVar,
33    Union,
34)
35
# Capture the generation time once so that the copyright year and the
# "Generated at" timestamp always describe the same instant, even when the
# module happens to be imported right at a year boundary.
_GENERATED_AT = datetime.now()

# License banner and autogeneration notice placed at the top of every generated
# source file. Written with // comments, which both C++ and JS/TS accept.
_COPYRIGHT = f"""\
// Copyright {_GENERATED_AT.year} The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

// AUTOGENERATED - DO NOT EDIT
//
// Generated at {_GENERATED_AT.isoformat()}
"""

# Preamble for generated C++ test files: the banner plus a clang-format
# directive that turns formatting off for the generated code.
_HEADER_CPP = f'{_COPYRIGHT}// clang-format off\n'

# Preamble for generated JS/TS test files: the banner plus an ESLint
# environment directive declaring the browser and jasmine globals.
_HEADER_JS = f'{_COPYRIGHT}/* eslint-env browser, jasmine */\n'
69
70
class Error(Exception):
    """Something went wrong when generating tests.

    Raised when the list of test cases is malformed or when two generated
    Python tests would end up with the same name.
    """
73
74
T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Describes one test case to a test generator function.

    Attributes:
      group: name of the group the test case belongs to
      count: number of this test case within its group
      total: how many test cases the group contains
      test_case: the test case itself
    """

    group: str
    count: int
    total: int
    test_case: T

    def cc_name(self) -> str:
        """Returns the group name as a CamelCase identifier for C++ tests.

        Words are split on spaces and hyphens; any remaining character that is
        not alphanumeric becomes an underscore. The test case number is
        appended only when the group has more than one test case.
        """
        words = self.group.replace('-', ' ').split(' ')
        camel_case = ''.join(map(str.capitalize, words))
        identifier = ''.join(
            char if char.isalnum() else '_' for char in camel_case
        )

        if self.total > 1:
            return f'{identifier}_{self.count}'
        return identifier

    def py_name(self) -> str:
        """Returns the group name as a snake_case unittest method name.

        The name is lowercased, prefixed with test_, and every non-alphanumeric
        character becomes an underscore. The test case number is appended only
        when the group has more than one test case.
        """
        lowered = self.group.lower()
        identifier = 'test_' + ''.join(
            char if char.isalnum() else '_' for char in lowered
        )

        if self.total > 1:
            return f'{identifier}_{self.count}'
        return identifier

    def ts_name(self) -> str:
        """Returns the group name as a lowercase description for JS/TS tests.

        Every non-alphanumeric character becomes a space. The test case number
        is appended only when the group has more than one test case.
        """
        lowered = self.group.lower()
        description = ''.join(
            char if char.isalnum() else ' ' for char in lowered
        )

        if self.total > 1:
            return f'{description} {self.count}'
        return description
103
104
# Test cases are specified as a sequence of strings or test case instances. The
# strings are used to separate the tests into named groups. For example:
#
#   STR_SPLIT_TEST_CASES = (
#     'Empty input',
#     MyTestCase('', '', []),
#     MyTestCase('', 'foo', []),
#     'Split on single character',
#     MyTestCase('abcde', 'c', ['ab', 'de']),
#     ...
#   )
#
GroupOrTest = Union[str, T]

# Python tests are generated by a function that returns a function usable as a
# unittest.TestCase method.
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# C++ tests are generated with a function that returns or yields lines of C++
# code for the given test case.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]

# JS/TS tests are generated the same way as C++ tests: with a function that
# returns or yields lines of code for the given test case.
JsTestGenerator = Callable[[Context[T]], Iterable[str]]
129
130
class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases."""

    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Groups the test cases under the group name that precedes them.

        Args:
          test_cases: group name strings, each followed by the test cases that
              belong to that group

        Raises:
          Error: fewer than two items were provided, the first item is not a
              group name, a test case follows an empty group name, or the
              sequence contains group names but no test cases
        """
        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string'
            )

        cases: Dict[str, List[T]] = defaultdict(list)
        group = ''

        for case in test_cases:
            if isinstance(case, str):
                group = case
            else:
                cases[group].append(case)

        if '' in cases:
            raise Error('Empty test group names are not permitted')

        # A sequence made up only of group names passes the length check above,
        # but would silently generate test suites without a single test.
        if not cases:
            raise Error('At least one test case must be provided')

        # Keep a plain dict so that looking up an unknown group later cannot
        # silently create an empty group.
        self._cases: Dict[str, List[T]] = dict(cases)

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for every test case, numbered from 1 per group."""
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(
        self, define_py_test: PyTestGenerator
    ) -> Dict[str, Callable[[Any], None]]:
        """Builds the test methods for a unittest.TestCase class.

        Args:
          define_py_test: called with each Context; returns the test function

        Returns:
          The test functions, keyed by the name assigned to each of them.

        Raises:
          Error: two test cases produce the same Python test name
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            test.__name__ = ctx.py_name()

            # Different group names can map to the same identifier; a duplicate
            # would overwrite the earlier test without any warning.
            if test.__name__ in tests:
                raise Error(f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case."""
        return type(
            name,
            (unittest.TestCase,),
            self._generate_python_tests(define_py_test),
        )

    def _generate_cc_tests(
        self, define_cpp_test: CcTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the C++ test file contents, one entry per output line.

        The standard C++ preamble and the caller's header come first, then the
        lines of each test followed by an empty entry, then the footer.
        """
        yield _HEADER_CPP
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''

        yield footer

    def cc_tests(
        self,
        output: TextIO,
        define_cpp_test: CcTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes C++ unit tests for each test case to the given file."""
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')

    def _generate_ts_tests(
        self, define_ts_test: JsTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the JS/TS test file contents, one entry per output line.

        The standard JS preamble and the caller's header come first, then the
        lines of each test, then the footer.
        """
        yield _HEADER_JS
        yield header

        for ctx in self._test_contexts():
            yield from define_ts_test(ctx)
        yield footer

    def ts_tests(
        self,
        output: TextIO,
        define_js_test: JsTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes JS unit tests for each test case to the given file."""
        for line in self._generate_ts_tests(define_js_test, header, footer):
            output.write(line)
            output.write('\n')
227
228
# Characters that a C++ hex escape sequence would absorb if they followed it.
_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')


def _to_chars(data: bytes) -> Iterator[str]:
    """Yields the C++ string literal spelling of each byte in data.

    Printable ASCII characters are yielded unchanged, except for the double
    quote and the backslash, which are backslash-escaped so that they can
    neither end the literal nor start an unintended escape sequence. Every
    other byte is yielded as a two-digit hex escape.
    """
    for byte in data:
        char = chr(byte)

        if char in ('"', '\\'):
            yield '\\' + char
        elif byte < 0x80 and char.isprintable():
            yield char
        else:
            yield fr'\x{byte:02x}'


def cc_string(data: Union[str, bytes]) -> str:
    """Returns a C++ string literal version of a byte string or UTF-8 string.

    Args:
      data: the bytes to represent; a str is encoded as UTF-8 first

    Returns:
      C++ source text made up of one or more adjacent string literals that
      together evaluate to exactly the bytes in data.
    """
    if isinstance(data, str):
        data = data.encode()

    pieces: List[str] = []
    previous_was_hex_escape = False

    for piece in _to_chars(data):
        # A C++ hex escape has no maximum length, so "\x00a" is the single
        # character 0x00a rather than a NUL followed by 'a'. End the literal
        # and open an adjacent one; the compiler concatenates them.
        if previous_was_hex_escape and piece in _HEX_DIGITS:
            pieces.append('" "')

        pieces.append(piece)
        previous_was_hex_escape = piece.startswith('\\x')

    return '"' + ''.join(pieces) + '"'
244
245
def parse_test_generation_args() -> argparse.Namespace:
    """Parses the flags that select which test files to generate.

    Command line arguments that the parser does not recognize are ignored
    rather than treated as errors.

    Returns:
      Namespace with generate_cc_test and generate_ts_test attributes; each is
      the file opened for writing, or None when the flag was not given.
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')

    output_flags = (
        ('--generate-cc-test', 'Generate the C++ test file'),
        ('--generate-ts-test', 'Generate the JS test file'),
    )
    for flag, help_text in output_flags:
        parser.add_argument(flag, type=argparse.FileType('w'), help=help_text)

    known_args, _unrecognized = parser.parse_known_args()
    return known_args
259