• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Utilities for running unit tests over Pigweed RPC."""
15
16import enum
17import abc
18from dataclasses import dataclass
19import logging
20from typing import Iterable
21
22from pw_rpc.client import Services
23from pw_rpc.callback_client import OptionalTimeout, UseDefault
24from pw_unit_test_proto import unit_test_pb2
25
# Module logger, keyed by the enclosing package's name.
_LOG = logging.getLogger(name=__package__)
27
28
@dataclass(frozen=True)
class TestCase:
    """Identifies one test case reported during a unit test run."""

    suite_name: str  # Suite (fixture) the test belongs to.
    test_name: str  # Name of the test within its suite.
    file_name: str  # File containing the test; used in expectation logs.

    def __str__(self) -> str:
        """Returns the qualified test name in ``Suite.Test`` form."""
        return '{}.{}'.format(self.suite_name, self.test_name)

    def __repr__(self) -> str:
        return f'TestCase({self})'
40
41
def _test_case(raw_test_case: unit_test_pb2.TestCaseDescriptor) -> TestCase:
    """Builds a ``TestCase`` from a raw protobuf test case descriptor."""
    return TestCase(
        suite_name=raw_test_case.suite_name,
        test_name=raw_test_case.test_name,
        file_name=raw_test_case.file_name,
    )
48
49
@dataclass(frozen=True)
class TestExpectation:
    """Result of a single expect/assert statement inside a test case."""

    expression: str  # The expectation as written; logged as "Expected".
    evaluated_expression: str  # The evaluated form; logged as "Actual".
    line_number: int  # Line of the statement within the test's file.
    success: bool  # True if the expectation held.

    def __str__(self) -> str:
        """Returns the expectation expression unchanged."""
        return self.expression

    def __repr__(self) -> str:
        return 'TestExpectation({})'.format(self)
62
63
class TestCaseResult(enum.IntEnum):
    """Overall outcome of a single test case.

    Member values are taken directly from the protobuf ``TestCaseResult``
    enum, so a raw ``test_case_end`` value converts with
    ``TestCaseResult(raw_value)``.
    """

    SUCCESS = unit_test_pb2.TestCaseResult.SUCCESS
    FAILURE = unit_test_pb2.TestCaseResult.FAILURE
    SKIPPED = unit_test_pb2.TestCaseResult.SKIPPED
68
69
class EventHandler(abc.ABC):
    """Receives callbacks as a remote unit test run progresses.

    Subclasses must implement every method; ``run_tests`` invokes them on each
    registered handler as the corresponding events arrive.
    """

    @abc.abstractmethod
    def run_all_tests_start(self) -> None:
        """Invoked when a test run begins, before any of its test cases."""

    @abc.abstractmethod
    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        """Invoked once the test run has finished.

        Args:
            passed_tests: Number of test cases that passed.
            failed_tests: Number of test cases that failed.
        """

    @abc.abstractmethod
    def test_case_start(self, test_case: TestCase) -> None:
        """Invoked as each test case begins running."""

    @abc.abstractmethod
    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        """Invoked when a test case finishes.

        Args:
            test_case: The test case that finished.
            result: Its overall outcome.
        """

    @abc.abstractmethod
    def test_case_disabled(self, test_case: TestCase) -> None:
        """Invoked for a test case that is disabled and will not run."""

    @abc.abstractmethod
    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        """Invoked for each expect/assert statement a test case reports.

        Args:
            test_case: The test case containing the statement.
            expectation: The statement's expression, values, and outcome.
        """
98
99
class LoggingEventHandler(EventHandler):
    """Event handler that logs test events using Google Test format."""

    def run_all_tests_start(self) -> None:
        _LOG.info('[==========] Running all tests.')

    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        _LOG.info('[==========] Done running all tests.')
        _LOG.info('[  PASSED  ] %d test(s).', passed_tests)
        # The failure summary line is only printed when something failed.
        if failed_tests:
            _LOG.info('[  FAILED  ] %d test(s).', failed_tests)

    def test_case_start(self, test_case: TestCase) -> None:
        _LOG.info('[ RUN      ] %s', test_case)

    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        """Logs the closing banner for a test case according to its result.

        SKIPPED results get their own Google Test style ``[  SKIPPED ]``
        banner; only genuine failures are reported as ``[  FAILED  ]``.
        """
        if result == TestCaseResult.SUCCESS:
            _LOG.info('[       OK ] %s', test_case)
        elif result == TestCaseResult.SKIPPED:
            _LOG.info('[  SKIPPED ] %s', test_case)
        else:
            _LOG.info('[  FAILED  ] %s', test_case)

    def test_case_disabled(self, test_case: TestCase) -> None:
        _LOG.info('Skipping disabled test %s', test_case)

    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        """Logs an expectation's location, outcome, and expected/actual text.

        Passing expectations are logged at INFO level, failing ones at ERROR.
        """
        result = 'Success' if expectation.success else 'Failure'
        log = _LOG.info if expectation.success else _LOG.error
        log('%s:%d: %s', test_case.file_name, expectation.line_number, result)
        log('      Expected: %s', expectation.expression)
        log('        Actual: %s', expectation.evaluated_expression)
134
135
def run_tests(
    rpcs: Services,
    report_passed_expectations: bool = False,
    test_suites: Iterable[str] = (),
    event_handlers: Iterable[EventHandler] = (LoggingEventHandler(),),
    timeout_s: OptionalTimeout = UseDefault.VALUE,
) -> bool:
    """Runs unit tests on a device over Pigweed RPC.

    Calls each of the provided event handlers as test events occur, and returns
    True if all tests pass.

    Args:
        rpcs: RPC services for the target; must expose pw.unit_test.UnitTest.
        report_passed_expectations: Forwarded as the Run request's
            ``report_passed_expectations`` field (presumably asks the device
            to also report expectations that passed).
        test_suites: Forwarded as the Run request's ``test_suite`` field.
        event_handlers: Handlers notified of every test event. They are
            materialized once, so any iterable (including a generator) works.
        timeout_s: Timeout passed through to the Run invocation.

    Returns:
        True if the most recent ``test_run_end`` event reported zero failures.
        False otherwise, including when the stream ends without any
        ``test_run_end`` event.

    Raises:
        StopIteration: The Run stream ended before any response arrived.
        ValueError: The first response was not ``test_run_start``, or a test
            case event arrived before any ``test_case_start`` (for example,
            because a response was dropped).
    """
    # Every event is dispatched to every handler, so the handlers are iterated
    # many times. Materialize them once so a one-shot iterable (e.g. a
    # generator) is not exhausted after the first event.
    handlers = tuple(event_handlers)

    unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]
    request = unit_test_service.Run.request(
        report_passed_expectations=report_passed_expectations,
        test_suite=test_suites,
    )
    call = unit_test_service.Run.invoke(request, timeout_s=timeout_s)
    test_responses = iter(call)

    # Read the first response, which must be a test_run_start message.
    try:
        first_response = next(test_responses)
    except StopIteration:
        _LOG.error(
            'The "test_run_start" message was dropped! UnitTest.Run '
            'concluded with %s.',
            call.status,
        )
        raise

    if not first_response.HasField('test_run_start'):
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but received a different message type. A response may have been '
            'dropped.'
        )

    for handler in handlers:
        handler.run_all_tests_start()

    all_tests_passed = False
    # Most recently started test case; end/expectation events refer to it.
    current_test_case = None

    for response in test_responses:
        if response.HasField('test_run_start'):
            for handler in handlers:
                handler.run_all_tests_start()
        elif response.HasField('test_run_end'):
            run_end = response.test_run_end
            # Decided outside the handler loop so the result is correct even
            # when no event handlers are registered.
            all_tests_passed = run_end.failed == 0
            for handler in handlers:
                handler.run_all_tests_end(run_end.passed, run_end.failed)
        elif response.HasField('test_case_start'):
            current_test_case = _test_case(response.test_case_start)
            for handler in handlers:
                handler.test_case_start(current_test_case)
        elif response.HasField('test_case_end'):
            if current_test_case is None:
                raise ValueError(
                    'Received a "test_case_end" response from '
                    'pw.unit_test.Run before any "test_case_start". A '
                    'response may have been dropped.'
                )
            result = TestCaseResult(response.test_case_end)
            for handler in handlers:
                handler.test_case_end(current_test_case, result)
        elif response.HasField('test_case_disabled'):
            disabled_test_case = _test_case(response.test_case_disabled)
            for handler in handlers:
                handler.test_case_disabled(disabled_test_case)
        elif response.HasField('test_case_expectation'):
            if current_test_case is None:
                raise ValueError(
                    'Received a "test_case_expectation" response from '
                    'pw.unit_test.Run before any "test_case_start". A '
                    'response may have been dropped.'
                )
            raw_expectation = response.test_case_expectation
            expectation = TestExpectation(
                raw_expectation.expression,
                raw_expectation.evaluated_expression,
                raw_expectation.line_number,
                raw_expectation.success,
            )
            for handler in handlers:
                handler.test_case_expect(current_test_case, expectation)

    return all_tests_passed
213