# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utilities for running unit tests over Pigweed RPC."""

import enum
import abc
from dataclasses import dataclass
import logging
from typing import Iterable, Optional

from pw_rpc.client import Services
from pw_rpc.callback_client import OptionalTimeout, UseDefault
from pw_unit_test_proto import unit_test_pb2

_LOG = logging.getLogger(__package__)


@dataclass(frozen=True)
class TestCase:
    """Identifies a single test case reported by the device.

    Printed as ``suite_name.test_name``.
    """

    suite_name: str
    test_name: str
    # File name reported by the device for this test; LoggingEventHandler
    # prints it as the file part of "file:line" expectation messages.
    file_name: str

    def __str__(self) -> str:
        return f'{self.suite_name}.{self.test_name}'

    def __repr__(self) -> str:
        return f'TestCase({str(self)})'


def _test_case(raw_test_case: unit_test_pb2.TestCaseDescriptor) -> TestCase:
    """Converts a ``TestCaseDescriptor`` proto message into a TestCase."""
    return TestCase(
        raw_test_case.suite_name,
        raw_test_case.test_name,
        raw_test_case.file_name,
    )


@dataclass(frozen=True)
class TestExpectation:
    """The result of one expect/assert statement within a test case."""

    # The expectation's expression text, as reported by the device.
    expression: str
    # The evaluated form of the expression, as reported by the device
    # (LoggingEventHandler prints it as the "Actual" value).
    evaluated_expression: str
    line_number: int
    success: bool

    def __str__(self) -> str:
        return self.expression

    def __repr__(self) -> str:
        return f'TestExpectation({str(self)})'


class TestCaseResult(enum.IntEnum):
    """Overall test case outcome; values mirror unit_test_pb2.TestCaseResult."""

    SUCCESS = unit_test_pb2.TestCaseResult.SUCCESS
    FAILURE = unit_test_pb2.TestCaseResult.FAILURE
    SKIPPED = unit_test_pb2.TestCaseResult.SKIPPED


class EventHandler(abc.ABC):
    """Interface for receiving the test events that run_tests reports."""

    @abc.abstractmethod
    def run_all_tests_start(self) -> None:
        """Called before all tests are run."""

    @abc.abstractmethod
    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        """Called after the test run is complete."""

    @abc.abstractmethod
    def test_case_start(self, test_case: TestCase) -> None:
        """Called when a new test case is started."""

    @abc.abstractmethod
    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        """Called when a test case completes with its overall result."""

    @abc.abstractmethod
    def test_case_disabled(self, test_case: TestCase) -> None:
        """Called when a disabled test case is encountered."""

    @abc.abstractmethod
    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        """Called after each expect/assert statement within a test case."""


class LoggingEventHandler(EventHandler):
    """Event handler that logs test events using Google Test format."""

    def run_all_tests_start(self) -> None:
        _LOG.info('[==========] Running all tests.')

    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        _LOG.info('[==========] Done running all tests.')
        _LOG.info('[ PASSED ] %d test(s).', passed_tests)
        # The failure summary line is only emitted when something failed.
        if failed_tests:
            _LOG.info('[ FAILED ] %d test(s).', failed_tests)

    def test_case_start(self, test_case: TestCase) -> None:
        _LOG.info('[ RUN ] %s', test_case)

    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        # Any non-SUCCESS result (including SKIPPED) is logged as FAILED.
        if result == TestCaseResult.SUCCESS:
            _LOG.info('[ OK ] %s', test_case)
        else:
            _LOG.info('[ FAILED ] %s', test_case)

    def test_case_disabled(self, test_case: TestCase) -> None:
        _LOG.info('Skipping disabled test %s', test_case)

    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        result = 'Success' if expectation.success else 'Failure'
        # Failed expectations are logged at error level, passed ones at info.
        log = _LOG.info if expectation.success else _LOG.error
        log('%s:%d: %s', test_case.file_name, expectation.line_number, result)
        log(' Expected: %s', expectation.expression)
        log(' Actual: %s', expectation.evaluated_expression)


def _log_orphaned_event(event_name: str) -> None:
    """Logs a test case event that arrived before any test_case_start."""
    _LOG.error(
        'Received a "%s" message from pw.unit_test.Run before any '
        '"test_case_start" message; a response may have been dropped. '
        'Ignoring it.',
        event_name,
    )


def run_tests(
    rpcs: Services,
    report_passed_expectations: bool = False,
    test_suites: Iterable[str] = (),
    event_handlers: Iterable[EventHandler] = (LoggingEventHandler(),),
    timeout_s: OptionalTimeout = UseDefault.VALUE,
) -> bool:
    """Runs unit tests on a device over Pigweed RPC.

    Calls each of the provided event handlers as test events occur, and returns
    True if all tests pass.

    Args:
      rpcs: RPC services that expose ``pw.unit_test.UnitTest``.
      report_passed_expectations: Forwarded as the request's
          ``report_passed_expectations`` field.
      test_suites: Forwarded as the request's ``test_suite`` field.
      event_handlers: Handlers notified of every test event, in order. Any
          iterable is accepted; it is copied once, so generators work too.
      timeout_s: Timeout forwarded to the ``UnitTest.Run`` invocation.

    Returns:
      True if a "test_run_end" response reported zero failed tests, regardless
      of how many event handlers were given; False otherwise, including when
      no "test_run_end" response was received.

    Raises:
      StopIteration: If the RPC produced no responses at all.
      ValueError: If the first response is not a "test_run_start" message.
    """
    unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]

    # Handlers are iterated once per event, so a one-shot iterable (such as a
    # generator) must be materialized up front or later events are lost.
    handlers = tuple(event_handlers)

    request = unit_test_service.Run.request(
        report_passed_expectations=report_passed_expectations,
        test_suite=test_suites,
    )
    call = unit_test_service.Run.invoke(request, timeout_s=timeout_s)
    test_responses = iter(call)

    # Read the first response, which must be a test_run_start message.
    try:
        first_response = next(test_responses)
    except StopIteration:
        _LOG.error(
            'The "test_run_start" message was dropped! UnitTest.Run '
            'concluded with %s.',
            call.status,
        )
        raise

    if not first_response.HasField('test_run_start'):
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but received a different message type. A response may have been '
            'dropped.'
        )

    for handler in handlers:
        handler.run_all_tests_start()

    all_tests_passed = False
    # The most recently started test case; end/expectation events refer to it.
    current_test_case: Optional[TestCase] = None

    for response in test_responses:
        if response.HasField('test_case_start'):
            current_test_case = _test_case(response.test_case_start)

        if response.HasField('test_run_start'):
            for handler in handlers:
                handler.run_all_tests_start()
        elif response.HasField('test_run_end'):
            passed = response.test_run_end.passed
            failed = response.test_run_end.failed
            # Tracked outside the handler loop so the overall result does not
            # depend on whether any event handlers were supplied.
            if failed == 0:
                all_tests_passed = True
            for handler in handlers:
                handler.run_all_tests_end(passed, failed)
        elif response.HasField('test_case_start'):
            for handler in handlers:
                handler.test_case_start(current_test_case)
        elif response.HasField('test_case_end'):
            if current_test_case is None:
                _log_orphaned_event('test_case_end')
                continue
            result = TestCaseResult(response.test_case_end)
            for handler in handlers:
                handler.test_case_end(current_test_case, result)
        elif response.HasField('test_case_disabled'):
            disabled_test_case = _test_case(response.test_case_disabled)
            for handler in handlers:
                handler.test_case_disabled(disabled_test_case)
        elif response.HasField('test_case_expectation'):
            if current_test_case is None:
                _log_orphaned_event('test_case_expectation')
                continue
            raw_expectation = response.test_case_expectation
            expectation = TestExpectation(
                raw_expectation.expression,
                raw_expectation.evaluated_expression,
                raw_expectation.line_number,
                raw_expectation.success,
            )
            for handler in handlers:
                handler.test_case_expect(current_test_case, expectation)

    return all_tests_passed