# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tools for generating Pigweed tests that execute in C++, Python, and TS."""

import argparse
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict
import unittest

from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
                    Sequence, TextIO, TypeVar, Union)

# Read the clock once so the copyright year and the generation timestamp in
# the generated header always describe the same instant.
_GENERATED_AT = datetime.now()

_COPYRIGHT = f"""\
// Copyright {_GENERATED_AT.year} The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

// AUTOGENERATED - DO NOT EDIT
//
// Generated at {_GENERATED_AT.isoformat()}
"""

# Prefix for generated C++ files: license header plus clang-format suppression.
_HEADER_CPP = _COPYRIGHT + """\
// clang-format off
"""

# Prefix for generated JS/TS files: license header plus the eslint environment.
_HEADER_JS = _COPYRIGHT + """\
/* eslint-env browser, jasmine */
"""


class Error(Exception):
    """Something went wrong when generating tests."""


T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Info passed into test generator functions for each test case."""
    group: str  # Name of the group this test case belongs to.
    count: int  # Position of this case within its group.
    total: int  # Number of test cases in the group.
    test_case: T  # The user-supplied test case object.

    def cc_name(self) -> str:
        """Returns a C++-style CamelCase name for this test case.

        Words of the group name (split on spaces and hyphens) are capitalized
        (the rest of each word is lowercased) and joined, and any remaining
        non-alphanumeric characters become underscores. If the group has more
        than one case, ``_<count>`` is appended.
        """
        name = ''.join(w.capitalize()
                       for w in self.group.replace('-', ' ').split(' '))
        name = ''.join(c if c.isalnum() else '_' for c in name)
        return f'{name}_{self.count}' if self.total > 1 else name

    def py_name(self) -> str:
        """Returns a unittest method name: ``test_`` + the snake_cased group.

        The group name is lowercased and non-alphanumeric characters become
        underscores. If the group has more than one case, ``_<count>`` is
        appended.
        """
        name = 'test_' + ''.join(c if c.isalnum() else '_'
                                 for c in self.group.lower())
        return f'{name}_{self.count}' if self.total > 1 else name

    def ts_name(self) -> str:
        """Returns a human-readable test name for TypeScript/JS tests.

        The group name is lowercased and non-alphanumeric characters become
        spaces. If the group has more than one case, `` <count>`` is appended.
        """
        name = ''.join(c if c.isalnum() else ' ' for c in self.group.lower())
        return f'{name} {self.count}' if self.total > 1 else name


# Test cases are specified as a sequence of strings or test case instances. The
# strings are used to separate the tests into named groups. For example:
#
#   STR_SPLIT_TEST_CASES = (
#       'Empty input',
#       MyTestCase('', '', []),
#       MyTestCase('', 'foo', []),
#       'Split on single character',
#       MyTestCase('abcde', 'c', ['ab', 'de']),
#       ...
#   )
#
GroupOrTest = Union[str, T]

# Python tests are generated by a function that returns a function usable as a
# unittest.TestCase method.
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# C++ tests are generated with a function that returns or yields lines of C++
# code for the given test case.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]

# JavaScript/TypeScript tests are generated the same way as C++ tests: a
# function that returns or yields lines of code for the given test case.
JsTestGenerator = Callable[[Context[T]], Iterable[str]]


class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases.

    Each test case belongs to the most recent group-name string preceding it
    in the input sequence. A group name that is not followed by any test case
    produces no tests.
    """
    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Groups the test cases by the group-name strings preceding them.

        Args:
          test_cases: group-name strings interleaved with test case objects;
              the first item must be a group name.

        Raises:
          Error: if no test cases are provided, the first item is not a group
              name string, or a group name is empty.
        """
        self._cases: Dict[str, List[T]] = defaultdict(list)
        group = ''

        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string')

        for case in test_cases:
            if isinstance(case, str):
                group = case
            else:
                self._cases[group].append(case)

        # A sequence containing only group names passes the length check above
        # but has nothing to generate.
        if not self._cases:
            raise Error('At least one test case must be provided')

        if '' in self._cases:
            raise Error('Empty test group names are not permitted')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context per test case, numbered from 1 within each group."""
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(
            self, define_py_test: PyTestGenerator
    ) -> Dict[str, Callable[[Any], None]]:
        """Returns a mapping of method name to test function for every case.

        Each function returned by define_py_test is renamed to its Context's
        py_name().

        Raises:
          Error: if two test cases map to the same Python method name.
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            test.__name__ = ctx.py_name()

            if test.__name__ in tests:
                raise Error(
                    f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case.

        Args:
          name: name of the generated TestCase subclass.
          define_py_test: called with each test case's Context; returns a
              function to install as a test method.
        """
        return type(name, (unittest.TestCase, ),
                    self._generate_python_tests(define_py_test))

    def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
                           footer: str) -> Iterator[str]:
        """Yields the lines of the C++ test file.

        Tests are separated by blank lines.
        """
        yield _HEADER_CPP
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''  # Blank line between generated tests.

        yield footer

    def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
                 header: str, footer: str):
        """Writes C++ unit tests for each test case to the given file.

        Args:
          output: text stream that receives the generated C++ source.
          define_cpp_test: called with each test case's Context; returns or
              yields the lines of C++ for that case.
          header: text written after the autogenerated license header.
          footer: text written after all generated tests.
        """
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')

    def _generate_ts_tests(self, define_ts_test: JsTestGenerator, header: str,
                           footer: str) -> Iterator[str]:
        """Yields the lines of the JS/TS test file."""
        yield _HEADER_JS
        yield header

        for ctx in self._test_contexts():
            yield from define_ts_test(ctx)
        yield footer

    def ts_tests(self, output: TextIO, define_js_test: JsTestGenerator,
                 header: str, footer: str):
        """Writes JS unit tests for each test case to the given file.

        Args:
          output: text stream that receives the generated JS/TS source.
          define_js_test: called with each test case's Context; returns or
              yields the lines of code for that case.
          header: text written after the autogenerated license header.
          footer: text written after all generated tests.
        """
        for line in self._generate_ts_tests(define_js_test, header, footer):
            output.write(line)
            output.write('\n')


def _to_chars(data: bytes) -> Iterator[str]:
    r"""Yields the pieces of a C++ string literal body that encodes data.

    Printable ASCII bytes are emitted as-is, except '"' and '\'. These two are
    backslash-escaped so they cannot terminate or corrupt the literal.

    Every other byte becomes a two-digit \xNN escape. This covers control
    characters, DEL, and all bytes >= 0x80, including each byte of a
    multi-byte UTF-8 sequence.

    A C++ hex escape consumes every hex digit that follows it, so "\x01a"
    would be read as the single escape \x01a. When a literal hex digit
    follows a hex escape, an empty-literal break ("") is emitted between
    them. C++ concatenates adjacent string literals after processing escapes,
    so "\x01""a" is the byte 0x01 followed by 'a'.
    """
    hex_digits = '0123456789abcdefABCDEF'
    after_hex_escape = False

    for byte in data:
        char = chr(byte)
        if byte < 0x80 and char.isprintable():
            if char in '"\\':
                yield '\\' + char
            else:
                if after_hex_escape and char in hex_digits:
                    yield '""'
                yield char
            after_hex_escape = False
        else:
            yield fr'\x{byte:02x}'
            after_hex_escape = True


def cc_string(data: Union[str, bytes]) -> str:
    """Returns a C++ string literal version of a byte string or UTF-8 string.

    str input is UTF-8 encoded first; see _to_chars for the escaping rules.
    """
    if isinstance(data, str):
        data = data.encode()

    return '"' + ''.join(_to_chars(data)) + '"'


def parse_test_generation_args() -> argparse.Namespace:
    """Parses the test-generation options from the command line.

    Unrecognized arguments are ignored, so other argument parsers can also
    run on the same command line. Any --generate-* option given is opened for
    writing by argparse; closing the file is left to the caller.
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')
    parser.add_argument('--generate-cc-test',
                        type=argparse.FileType('w'),
                        help='Generate the C++ test file')
    parser.add_argument('--generate-ts-test',
                        type=argparse.FileType('w'),
                        help='Generate the JS test file')
    return parser.parse_known_args()[0]