• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
"""Integration tests for pw_unit_test's RPC-based run_tests()."""
16
17import logging
18from pathlib import Path
19from typing import List, Tuple
20import unittest
21from unittest import mock
22
23from pw_hdlc import rpc
24from pw_rpc import testing
25from pw_unit_test_proto import unit_test_pb2
26from pw_unit_test import run_tests, EventHandler, TestCase
27from pw_status import Status
28
# The three suites (Passing, Failing, and DISABLED_Disabled) have these cases.
_CASES = ('Zero', 'One', 'Two', 'DISABLED_Disabled')
_FILE = 'pw_unit_test/test_rpc_server.cc'

# Enabled suites: every case except the trailing DISABLED_ one is executed.
PASSING, FAILING = (
    tuple(TestCase(suite, name, _FILE) for name in _CASES[:-1])
    for suite in ('Passing', 'Failing')
)
EXECUTED_TESTS = (*PASSING, *FAILING)

# Every case in the disabled suite is reported as disabled.
DISABLED_SUITE = tuple(
    TestCase('DISABLED_Disabled', name, _FILE) for name in _CASES
)

# Disabled cases inside the enabled suites, followed by the disabled suite.
ALL_DISABLED_TESTS = (
    tuple(
        TestCase(suite, 'DISABLED_Disabled', _FILE)
        for suite in ('Passing', 'Failing')
    )
    + DISABLED_SUITE
)
46
47
class RpcIntegrationTest(unittest.TestCase):
    """Calls RPCs on an RPC server through a socket.

    Each test constructs an ``rpc.HdlcRpcLocalServerAndClient`` from
    ``test_server_command`` and ``port``, then drives the unit test service
    through ``run_tests``. Both class attributes must be assigned before the
    tests execute (``_main`` does this).
    """

    # argv used to start the test RPC server.
    test_server_command: Tuple[str, ...] = ()
    # Port used to connect to the test server; intentionally has no default.
    port: int

    def setUp(self) -> None:
        self._context = rpc.HdlcRpcLocalServerAndClient(
            self.test_server_command, self.port, [unit_test_pb2]
        )
        # Register the close immediately: unittest does not call tearDown()
        # when setUp() raises, so a failure in the lines below would
        # otherwise skip closing the context opened above.
        self.addCleanup(self._context.close)
        self.rpcs = self._context.client.channel(1).rpcs
        # spec= restricts the mock to attributes that EventHandler defines.
        self.handler = mock.NonCallableMagicMock(spec=EventHandler)

    def test_run_tests_default_handler(self) -> None:
        """Without explicit handlers, every executed test is logged."""
        with self.assertLogs(logging.getLogger('pw_unit_test'), 'INFO') as logs:
            self.assertFalse(run_tests(self.rpcs))

        for test in EXECUTED_TESTS:
            self.assertTrue(any(str(test) in log for log in logs.output), test)

    def test_run_tests_calls_test_case_start(self) -> None:
        """test_case_start fires for every executed test."""
        self.assertFalse(run_tests(self.rpcs, event_handlers=[self.handler]))

        self.handler.test_case_start.assert_has_calls(
            [mock.call(case) for case in EXECUTED_TESTS], any_order=True
        )

    def test_run_tests_calls_test_case_end(self) -> None:
        """test_case_end reports SUCCESS for Passing, FAILURE for Failing."""
        self.assertFalse(run_tests(self.rpcs, event_handlers=[self.handler]))

        calls = [
            mock.call(
                case,
                unit_test_pb2.SUCCESS
                if case.suite_name == 'Passing'
                else unit_test_pb2.FAILURE,
            )
            for case in EXECUTED_TESTS
        ]
        self.handler.test_case_end.assert_has_calls(calls, any_order=True)

    def test_run_tests_calls_test_case_disabled(self) -> None:
        """test_case_disabled fires for every DISABLED_ case and suite."""
        self.assertFalse(run_tests(self.rpcs, event_handlers=[self.handler]))

        self.handler.test_case_disabled.assert_has_calls(
            [mock.call(case) for case in ALL_DISABLED_TESTS], any_order=True
        )

    def test_passing_tests_only(self) -> None:
        """Filtering to the Passing suite yields an overall success."""
        self.assertTrue(
            run_tests(
                self.rpcs,
                test_suites=['Passing'],
                event_handlers=[self.handler],
            )
        )
        calls = [mock.call(case, unit_test_pb2.SUCCESS) for case in PASSING]
        self.handler.test_case_end.assert_has_calls(calls, any_order=True)

    def test_disabled_tests_only(self) -> None:
        """A fully disabled suite starts/ends nothing and reports disabled."""
        self.assertTrue(
            run_tests(
                self.rpcs,
                test_suites=['DISABLED_Disabled'],
                event_handlers=[self.handler],
            )
        )

        self.handler.test_case_start.assert_not_called()
        self.handler.test_case_end.assert_not_called()
        self.handler.test_case_disabled.assert_has_calls(
            [mock.call(case) for case in DISABLED_SUITE], any_order=True
        )

    def test_failing_tests(self) -> None:
        """Filtering to the Failing suite yields an overall failure."""
        self.assertFalse(
            run_tests(
                self.rpcs,
                test_suites=['Failing'],
                event_handlers=[self.handler],
            )
        )
        calls = [mock.call(case, unit_test_pb2.FAILURE) for case in FAILING]
        self.handler.test_case_end.assert_has_calls(calls, any_order=True)
135
136
def _main(
    test_server_command: List[str], port: int, unittest_args: List[str]
) -> None:
    """Configures RpcIntegrationTest with server settings, then runs unittest.

    Args:
      test_server_command: argv used to start the test server; stored on the
          test class as a tuple.
      port: Port the test class uses to connect to the server.
      unittest_args: Complete argv forwarded to ``unittest.main``.
    """
    test_class = RpcIntegrationTest
    test_class.port = port
    test_class.test_server_command = tuple(test_server_command)
    unittest.main(argv=unittest_args)
143
144
if __name__ == '__main__':
    # The parsed option names match _main's keyword parameters.
    _parsed_args = testing.parse_test_server_args()
    _main(**vars(_parsed_args))
147