# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tools for generating Pigweed tests that execute in C++, Python, and TS/JS."""

import argparse
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict
import unittest

from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Sequence,
    TextIO,
    TypeVar,
    Union,
)

# Capture the generation time once so the copyright year and the "Generated at"
# timestamp below always agree. Two separate now() calls could straddle
# midnight on New Year's Eve and produce mismatched values.
_GENERATION_TIME = datetime.now()

# License banner prepended to every generated source file.
_COPYRIGHT = f"""\
// Copyright {_GENERATION_TIME.year} The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

// AUTOGENERATED - DO NOT EDIT
//
// Generated at {_GENERATION_TIME.isoformat()}
"""

# Header for generated C++ files; formatting is disabled so clang-format leaves
# the generated code untouched.
_HEADER_CPP = (
    _COPYRIGHT
    + """\
// clang-format off
"""
)

# Header for generated JS/TS files; declares the browser and jasmine globals
# for eslint.
_HEADER_JS = (
    _COPYRIGHT
    + """\
/* eslint-env browser, jasmine */
"""
)


class Error(Exception):
    """Something went wrong when generating tests."""


T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Info passed into test generator functions for each test case."""

    # Name of the group this test case belongs to.
    group: str
    # 1-based position of this test case within its group.
    count: int
    # Number of test cases in the group.
    total: int
    # The user-supplied test case object.
    test_case: T

    def cc_name(self) -> str:
        """Returns a C++-identifier-safe test name.

        Each space- or hyphen-separated word of the group name is capitalized
        and the words are joined. Any remaining non-alphanumeric character
        becomes '_'. When the group holds more than one case, '_<count>' is
        appended.
        """
        name = ''.join(
            w.capitalize() for w in self.group.replace('-', ' ').split(' ')
        )
        name = ''.join(c if c.isalnum() else '_' for c in name)
        return f'{name}_{self.count}' if self.total > 1 else name

    def py_name(self) -> str:
        """Returns a unittest method name: 'test_' + the lowercased group name.

        Non-alphanumeric characters are replaced with '_'. When the group
        holds more than one case, '_<count>' is appended.
        """
        name = 'test_' + ''.join(
            c if c.isalnum() else '_' for c in self.group.lower()
        )
        return f'{name}_{self.count}' if self.total > 1 else name

    def ts_name(self) -> str:
        """Returns a human-readable TS/JS test description.

        The group name is lowercased and non-alphanumeric characters are
        replaced with spaces. When the group holds more than one case,
        ' <count>' is appended.
        """
        name = ''.join(c if c.isalnum() else ' ' for c in self.group.lower())
        return f'{name} {self.count}' if self.total > 1 else name


# Test cases are specified as a sequence of strings or test case instances. The
# strings are used to separate the tests into named groups. For example:
#
#   STR_SPLIT_TEST_CASES = (
#       'Empty input',
#       MyTestCase('', '', []),
#       MyTestCase('', 'foo', []),
#       'Split on single character',
#       MyTestCase('abcde', 'c', ['ab', 'de']),
#       ...
#   )
#
GroupOrTest = Union[str, T]

# Python tests are generated by a function that returns a function usable as a
# unittest.TestCase method.
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# C++ tests are generated with a function that returns or yields lines of C++
# code for the given test case.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]

# JS/TS tests are generated the same way: a function that returns or yields
# lines of code for the given test case.
JsTestGenerator = Callable[[Context[T]], Iterable[str]]


class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases."""

    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Groups test cases under the group-name string that precedes them.

        Args:
          test_cases: Group-name strings interleaved with test case objects.
              The first item must be a group name. A group name that appears
              more than once accumulates the cases from every occurrence.

        Raises:
          Error: If no test case is provided, the first item is not a group
              name, or a group name is the empty string.
        """
        self._cases: Dict[str, List[T]] = defaultdict(list)
        message = ''

        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string'
            )

        for case in test_cases:
            if isinstance(case, str):
                message = case
            else:
                self._cases[message].append(case)

        # A sequence made up only of group names passes the length check above
        # but would otherwise silently generate no tests at all.
        if not self._cases:
            raise Error('At least one test case must be provided')

        if '' in self._cases:
            raise Error('Empty test group names are not permitted')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for every test case, numbered from 1 per group."""
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(self, define_py_test: PyTestGenerator):
        """Returns a {method name: test function} dict for a TestCase class.

        Each function returned by define_py_test is renamed in place to its
        Context.py_name().

        Raises:
          Error: If two test cases map to the same Python method name.
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            test.__name__ = ctx.py_name()

            # Distinct group names can normalize to the same identifier
            # (e.g. 'a-b' and 'a b'); refuse rather than silently drop one.
            if test.__name__ in tests:
                raise Error(f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case."""
        return type(
            name,
            (unittest.TestCase,),
            self._generate_python_tests(define_py_test),
        )

    def _generate_cc_tests(
        self, define_cpp_test: CcTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the lines of a C++ test file.

        The output is the C++ banner, then header, then each test followed by
        an empty line, then footer.
        """
        yield _HEADER_CPP
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''

        yield footer

    def cc_tests(
        self,
        output: TextIO,
        define_cpp_test: CcTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes C++ unit tests for each test case to the given file."""
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')

    def _generate_ts_tests(
        self, define_ts_test: JsTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the lines of a JS/TS test file.

        The output is the JS banner, then header, then each test's lines, then
        footer. Unlike the C++ output, no blank line separates tests.
        """
        yield _HEADER_JS
        yield header

        for ctx in self._test_contexts():
            yield from define_ts_test(ctx)
        yield footer

    def ts_tests(
        self,
        output: TextIO,
        define_js_test: JsTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes JS unit tests for each test case to the given file."""
        for line in self._generate_ts_tests(define_js_test, header, footer):
            output.write(line)
            output.write('\n')


# Characters that a C++ \x escape would absorb if they directly followed it.
_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')


def _to_chars(data: bytes) -> Iterator[str]:
    """Yields C++ string-literal fragments that spell out data byte by byte.

    Printable ASCII bytes (0x20-0x7E) are emitted as themselves, except that
    '"' and '\\' are backslash-escaped. Every other byte is emitted as a
    two-digit \\xNN escape.

    A C++ hex escape consumes every hex digit that follows it. When a
    printable hex-digit character comes right after a \\x escape, the fragment
    '" "' is emitted first. This closes the literal and opens an adjacent one,
    which the compiler concatenates.
    """
    after_hex_escape = False

    for byte in data:
        # 0x20-0x7E is exactly the set of single bytes that decode as UTF-8
        # and satisfy str.isprintable().
        if 0x20 <= byte <= 0x7E:
            char = chr(byte)
            if after_hex_escape and char in _HEX_DIGITS:
                yield '" "'
            yield '\\' + char if char in '"\\' else char
            after_hex_escape = False
        else:
            yield fr'\x{byte:02x}'
            after_hex_escape = True


def cc_string(data: Union[str, bytes]) -> str:
    """Returns a C++ string literal version of a byte string or UTF-8 string.

    A str is UTF-8 encoded first. The result may be several adjacent literals
    (e.g. "\\x01" "a"), which C++ concatenates into a single string.
    """
    if isinstance(data, str):
        data = data.encode()

    return '"' + ''.join(_to_chars(data)) + '"'


def parse_test_generation_args() -> argparse.Namespace:
    """Parses the --generate-cc-test and --generate-ts-test options.

    Each option, when given, is opened for writing by argparse.FileType('w').
    This function does not close it, so the caller is responsible for doing
    so. Unrecognized arguments are ignored via parse_known_args(), presumably
    so they can be left for another parser such as unittest's (confirm with
    callers).
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')
    parser.add_argument(
        '--generate-cc-test',
        type=argparse.FileType('w'),
        help='Generate the C++ test file',
    )
    parser.add_argument(
        '--generate-ts-test',
        type=argparse.FileType('w'),
        help='Generate the JS test file',
    )
    return parser.parse_known_args()[0]