• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
"""Tools for generating Pigweed tests that execute in C++, Python, and JS."""
15
16import argparse
17from dataclasses import dataclass
18from datetime import datetime
19from collections import defaultdict
20import unittest
21
22from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
23                    Sequence, TextIO, TypeVar, Union)
24
25_COPYRIGHT = f"""\
26// Copyright {datetime.now().year} The Pigweed Authors
27//
28// Licensed under the Apache License, Version 2.0 (the "License"); you may not
29// use this file except in compliance with the License. You may obtain a copy of
30// the License at
31//
32//     https://www.apache.org/licenses/LICENSE-2.0
33//
34// Unless required by applicable law or agreed to in writing, software
35// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
36// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
37// License for the specific language governing permissions and limitations under
38// the License.
39
40// AUTOGENERATED - DO NOT EDIT
41//
42// Generated at {datetime.now().isoformat()}
43"""
44
45_HEADER_CPP = _COPYRIGHT + """\
46// clang-format off
47"""
48
49_HEADER_JS = _COPYRIGHT + """\
50/* eslint-env browser, jasmine */
51"""
52
53
class Error(Exception):
    """Raised when tests cannot be generated.

    For example, this is raised for malformed test case input or for test
    cases whose generated names collide.
    """
56
57
# Type of the user-defined test case objects that the generators consume.
T = TypeVar('T')
59
60
@dataclass
class Context(Generic[T]):
    """Per-test-case information handed to each test generator function.

    Attributes:
      group: Name of the group the test case belongs to.
      count: Position of the test case within its group (TestGenerator
        numbers these from 1).
      total: Number of test cases in the group.
      test_case: The test case object itself.
    """
    group: str
    count: int
    total: int
    test_case: T

    def _with_count(self, base: str, separator: str) -> str:
        """Appends separator and count to base if the group has several cases."""
        if self.total <= 1:
            return base
        return f'{base}{separator}{self.count}'

    def cc_name(self) -> str:
        """Returns a CamelCase name for this test case for generated C++.

        The group name is split into words on spaces and hyphens. Each word is
        capitalized and the words are joined. Any remaining non-alphanumeric
        character becomes '_'.
        """
        words = self.group.replace('-', ' ').split(' ')
        camel = ''.join(word.capitalize() for word in words)
        identifier = ''.join(ch if ch.isalnum() else '_' for ch in camel)
        return self._with_count(identifier, '_')

    def py_name(self) -> str:
        """Returns a 'test_'-prefixed, lowercase Python test method name.

        Non-alphanumeric characters of the group name are replaced with '_'.
        """
        lowered = self.group.lower()
        body = ''.join(ch if ch.isalnum() else '_' for ch in lowered)
        return self._with_count('test_' + body, '_')

    def ts_name(self) -> str:
        """Returns a lowercase test name for generated JS tests.

        Non-alphanumeric characters of the group name are replaced with spaces.
        """
        lowered = self.group.lower()
        spaced = ''.join(ch if ch.isalnum() else ' ' for ch in lowered)
        return self._with_count(spaced, ' ')
83
84
# Test cases are given as one flat sequence that mixes group-name strings with
# test case instances. Each string names the group that the test cases after it
# belong to, for example:
#
#   STR_SPLIT_TEST_CASES = (
#     'Empty input',
#     MyTestCase('', '', []),
#     MyTestCase('', 'foo', []),
#     'Split on single character',
#     MyTestCase('abcde', 'c', ['ab', 'de']),
#     ...
#   )
#
GroupOrTest = Union[str, T]

# A generated Python test is a callable usable as a unittest.TestCase method:
# it receives the TestCase instance and returns nothing.
PyTest = Callable[[unittest.TestCase], None]

# A Python test generator builds one such test method from a test case Context.
PyTestGenerator = Callable[[Context[T]], PyTest]

# A C++ test generator returns (or yields) the lines of C++ source for one test
# case Context.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]

# A JS test generator has the same shape as a C++ one: it returns (or yields)
# the lines of JS source for one test case Context.
JsTestGenerator = Callable[[Context[T]], Iterable[str]]
109
110
class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases.

    Test cases are supplied as one sequence. Strings name groups, and every
    other item is a test case belonging to the most recent group name.
    Repeating a group name adds more cases to the existing group. Groups are
    generated in the order their names first appear.
    """
    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Sorts test_cases into named groups.

        Raises:
          Error: if fewer than two items are given, the first item is not a
            group name string, a group name is empty, or the sequence contains
            no test cases at all.
        """
        self._cases: Dict[str, List[T]] = defaultdict(list)

        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string')

        # Always reassigned by test_cases[0] before any case is appended.
        group = ''
        for case in test_cases:
            if not isinstance(case, str):
                self._cases[group].append(case)
            elif case:
                group = case
            else:
                # Checked on the name itself so an empty name is rejected even
                # when no test cases follow it.
                raise Error('Empty test group names are not permitted')

        # A sequence made up only of group names would silently generate
        # nothing.
        if not self._cases:
            raise Error('At least one test case must be provided')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for every test case, group by group.

        Cases are numbered from 1 within each group.
        """
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(self, define_py_test: PyTestGenerator):
        """Returns a dict mapping each Python test name to its test function.

        Raises:
          Error: if two test cases produce the same Python test name.
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            # Rename the function so __name__ matches the key it is stored
            # under (and thus the generated method name).
            test.__name__ = ctx.py_name()

            if test.__name__ in tests:
                raise Error(
                    f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case.

        Args:
          name: Name of the generated unittest.TestCase subclass.
          define_py_test: Called with each test case's Context; returns the
            test method for that case.
        """
        return type(name, (unittest.TestCase, ),
                    self._generate_python_tests(define_py_test))

    @staticmethod
    def _write_lines(output: TextIO, lines: Iterable[str]) -> None:
        """Writes each item of lines to output, followed by a newline."""
        for line in lines:
            output.write(line)
            output.write('\n')

    def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
                           footer: str) -> Iterator[str]:
        """Yields the lines of the C++ test file.

        The output is the standard C++ header, then header, then each test
        case's lines followed by a blank line, then footer.
        """
        yield _HEADER_CPP
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''  # Blank line between test cases.

        yield footer

    def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
                 header: str, footer: str):
        """Writes C++ unit tests for each test case to the given file."""
        lines = self._generate_cc_tests(define_cpp_test, header, footer)
        self._write_lines(output, lines)

    def _generate_ts_tests(self, define_ts_test: JsTestGenerator, header: str,
                           footer: str) -> Iterator[str]:
        """Yields the lines of the JS test file.

        The output is the standard JS header, then header, then each test
        case's lines, then footer.
        """
        yield _HEADER_JS
        yield header

        for ctx in self._test_contexts():
            yield from define_ts_test(ctx)
        yield footer

    def ts_tests(self, output: TextIO, define_js_test: JsTestGenerator,
                 header: str, footer: str):
        """Writes JS unit tests for each test case to the given file."""
        lines = self._generate_ts_tests(define_js_test, header, footer)
        self._write_lines(output, lines)
191
192
193def _to_chars(data: bytes) -> Iterator[str]:
194    for i, byte in enumerate(data):
195        try:
196            char = data[i:i + 1].decode()
197            yield char if char.isprintable() else fr'\x{byte:02x}'
198        except UnicodeDecodeError:
199            yield fr'\x{byte:02x}'
200
201
def cc_string(data: Union[str, bytes]) -> str:
    """Formats data as a double-quoted C++ string literal.

    A str argument is first encoded to bytes as UTF-8. Each byte is then
    spelled by _to_chars and the pieces are wrapped in double quotes.
    """
    raw = data.encode() if isinstance(data, str) else data
    return '"{}"'.format(''.join(_to_chars(raw)))
208
209
def parse_test_generation_args() -> argparse.Namespace:
    """Parses the command-line flags that request generated test files.

    Arguments other than the two flags defined here are ignored rather than
    rejected. parse_known_args returns them separately and they are dropped.

    Returns:
      A namespace with generate_cc_test and generate_ts_test attributes. Each
      is None when its flag is absent; otherwise it is a text file opened for
      writing.
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')
    # Write UTF-8 explicitly instead of the locale's default encoding; the
    # generated tests may contain non-ASCII text taken from the test cases.
    output_file = argparse.FileType('w', encoding='utf-8')
    parser.add_argument('--generate-cc-test',
                        type=output_file,
                        help='Generate the C++ test file')
    parser.add_argument('--generate-ts-test',
                        type=output_file,
                        help='Generate the JS test file')
    return parser.parse_known_args()[0]
219