# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tools for generating Pigweed tests that execute in C++ and Python."""

import argparse
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict
import unittest

from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
                    Sequence, TextIO, TypeVar, Union)

# Banner emitted at the top of every generated C++ file. The copyright year and
# the generation timestamp are evaluated once, when this module is imported.
_CPP_HEADER = f"""\
// Copyright {datetime.now().year} The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

// AUTOGENERATED - DO NOT EDIT
//
// Generated at {datetime.now().isoformat()}

// clang-format off
"""


class Error(Exception):
    """Something went wrong when generating tests."""


T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Info passed into test generator functions for each test case."""
    # Name of the group this test case belongs to.
    group: str
    # 1-based position of this test case within its group.
    count: int
    # Number of test cases in the group.
    total: int
    # The user-provided test case object.
    test_case: T

    def cc_name(self) -> str:
        """Returns a CamelCase C++ test name derived from the group name.

        Words are split on spaces and hyphens; any remaining non-alphanumeric
        character becomes '_'. The case index is appended only when the group
        holds more than one test case.
        """
        name = ''.join(w.capitalize()
                       for w in self.group.replace('-', ' ').split(' '))
        name = ''.join(c if c.isalnum() else '_' for c in name)
        return f'{name}_{self.count}' if self.total > 1 else name

    def py_name(self) -> str:
        """Returns a snake_case Python test method name for this test case.

        The name is 'test_' plus the lowercased group name with every
        non-alphanumeric character replaced by '_'. The case index is appended
        only when the group holds more than one test case.
        """
        name = 'test_' + ''.join(c if c.isalnum() else '_'
                                 for c in self.group.lower())
        return f'{name}_{self.count}' if self.total > 1 else name


# Test cases are specified as a sequence of strings or test case instances. The
# strings are used to separate the tests into named groups. For example:
#
#   STR_SPLIT_TEST_CASES = (
#     'Empty input',
#     MyTestCase('', '', []),
#     MyTestCase('', 'foo', []),
#     'Split on single character',
#     MyTestCase('abcde', 'c', ['ab', 'de']),
#     ...
#   )
#
GroupOrTest = Union[str, T]

# Python tests are generated by a function that returns a function usable as a
# unittest.TestCase method.
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# C++ tests are generated with a function that returns or yields lines of C++
# code for the given test case.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]


class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases."""

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it (and warn) when it is imported into a test module.
    __test__ = False

    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Groups the test cases under the group name string preceding them.

        Args:
          test_cases: group name strings interleaved with test case objects;
              the first item must be a group name.

        Raises:
          Error: fewer than two items were provided, the first item is not a
              string, or test cases were listed under an empty group name.
        """
        self._cases: Dict[str, List[T]] = defaultdict(list)
        message = ''

        # A group name plus at least one test case means at least two items.
        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string')

        for case in test_cases:
            if isinstance(case, str):
                message = case
            else:
                self._cases[message].append(case)

        if '' in self._cases:
            raise Error('Empty test group names are not permitted')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for each test case, numbered from 1 per group."""
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(
            self,
            define_py_test: PyTestGenerator) -> Dict[str, Callable[[Any], None]]:
        """Builds a mapping of test method name to generated test function.

        Raises:
          Error: two test cases produced the same Python test name.
        """
        tests: Dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            test.__name__ = ctx.py_name()

            if test.__name__ in tests:
                raise Error(
                    f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case."""
        return type(name, (unittest.TestCase, ),
                    self._generate_python_tests(define_py_test))

    def _generate_cc_tests(self, define_cpp_test: CcTestGenerator, header: str,
                           footer: str) -> Iterator[str]:
        """Yields the generated C++ file piece by piece.

        The output is the license/autogenerated banner, the caller's header,
        the lines produced for each test case followed by an empty line, and
        finally the caller's footer.
        """
        yield _CPP_HEADER
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''

        yield footer

    def cc_tests(self, output: TextIO, define_cpp_test: CcTestGenerator,
                 header: str, footer: str) -> None:
        """Writes C++ unit tests for each test case to the given file."""
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')


# Characters that are printable but must be backslash-escaped inside a C++
# string literal.
_CC_ESCAPES = {'"': r'\"', '\\': r'\\'}

_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')


def _to_chars(data: bytes) -> Iterator[str]:
    """Yields the C++ string literal spelling of each byte in data.

    Printable ASCII is emitted as is, except that '"' and '\\' are
    backslash-escaped. Every other byte is emitted as a \\xNN escape.
    """
    previous_was_hex_escape = False

    for byte in data:
        char = chr(byte)

        if char in _CC_ESCAPES:
            yield _CC_ESCAPES[char]
            previous_was_hex_escape = False
        elif (0x20 <= byte <= 0x7e
              and not (previous_was_hex_escape and char in _HEX_DIGITS)):
            yield char
            previous_was_hex_escape = False
        else:
            # C++ hex escapes consume every hex digit that follows them, so a
            # hex-digit character directly after an escape is escaped as well
            # to keep the two bytes separate.
            yield fr'\x{byte:02x}'
            previous_was_hex_escape = True


def cc_string(data: Union[str, bytes]) -> str:
    """Returns a C++ string literal version of a byte string or UTF-8 string."""
    if isinstance(data, str):
        data = data.encode()

    return '"' + ''.join(_to_chars(data)) + '"'


def parse_test_generation_args() -> argparse.Namespace:
    """Parses the test generation command line arguments.

    Unrecognized arguments are ignored rather than reported as errors.
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')
    parser.add_argument('--generate-cc-test',
                        type=argparse.FileType('w'),
                        help='Generate the C++ test file')
    return parser.parse_known_args()[0]