• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14
15"""RPC log stream handler tests."""
16
17from dataclasses import dataclass
18import logging
19from typing import Any, Callable
20from unittest import TestCase, main, mock
21
22from google.protobuf import message
23from pw_log.log_decoder import Log, LogStreamDecoder
24from pw_log.proto import log_pb2
25from pw_log_rpc.rpc_log_stream import LogStreamHandler
26from pw_rpc import callback_client, client, packets
27from pw_rpc.internal import packet_pb2
28from pw_status import Status
29
# Logger for this test module, keyed by the module's import name.
_LOG: logging.Logger = logging.getLogger(__name__)
31
32
def _encode_rpc_packet(rpc: packets.RpcIds, **fields: Any) -> bytes:
    """Serializes an ``RpcPacket`` addressed to the call identified by ``rpc``.

    Args:
        rpc: Channel, service, method, and call IDs copied into the packet.
        **fields: Remaining ``RpcPacket`` fields to set, e.g. ``type``,
            ``status`` or ``payload``.

    Returns:
        The serialized packet, ready to pass to ``Client.process_packet``.
    """
    return packet_pb2.RpcPacket(
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        **fields,
    ).SerializeToString()


def _encode_server_stream_packet(
    rpc: packets.RpcIds, payload: message.Message
) -> bytes:
    """Encodes a SERVER_STREAM packet carrying the serialized ``payload``."""
    return _encode_rpc_packet(
        rpc,
        type=packet_pb2.PacketType.SERVER_STREAM,
        payload=payload.SerializeToString(),
    )


def _encode_cancel(rpc: packets.RpcIds) -> bytes:
    """Encodes a SERVER_ERROR packet that ends the call with CANCELLED."""
    return _encode_error(rpc, Status.CANCELLED)


def _encode_error(
    rpc: packets.RpcIds, status: Status = Status.UNKNOWN
) -> bytes:
    """Encodes a SERVER_ERROR packet that ends the call with ``status``.

    Args:
        rpc: IDs of the call to terminate.
        status: Error status to report; defaults to ``Status.UNKNOWN``.

    Returns:
        The serialized packet.
    """
    return _encode_rpc_packet(
        rpc,
        type=packet_pb2.PacketType.SERVER_ERROR,
        status=status.value,
    )


def _encode_completed(rpc: packets.RpcIds, status: Status) -> bytes:
    """Encodes a RESPONSE packet that completes the call with ``status``."""
    return _encode_rpc_packet(
        rpc,
        type=packet_pb2.PacketType.RESPONSE,
        status=status.value,
    )
77
78
class _CallableWithCounter:
    """Wraps a function and records every call made through the wrapper.

    Each call's positional and keyword arguments are appended to ``calls``
    before the call is forwarded to the wrapped function. Tests can then
    check both how many times and with which arguments it was invoked.
    """

    @dataclass
    class CallParams:
        """Arguments of a single recorded call."""

        args: tuple[Any, ...]
        kwargs: dict[str, Any]

    def __init__(self, func: Callable[..., Any]):
        # Callable[..., Any]: wrapped functions take any signature (e.g. the
        # zero-argument listen_to_logs or a one-argument status handler).
        self._func = func
        self.calls: list[_CallableWithCounter.CallParams] = []

    def call_count(self) -> int:
        """Returns how many times the wrapper has been called."""
        return len(self.calls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Records the call, forwards it, and returns the wrapped result."""
        self.calls.append(_CallableWithCounter.CallParams(args, kwargs))
        return self._func(*args, **kwargs)
97
98
class TestRpcLogStreamHandler(TestCase):
    """Tests for LogStreamHandler."""

    def setUp(self) -> None:
        """Creates an RPC client, a log decoder, and the handler under test.

        Decoded logs are appended to ``self.captured_logs`` so tests can
        check how many entries made it through the decoder.
        """
        self._channel_id = 1
        self.client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(self._channel_id, lambda _: None)],
            [log_pb2],
        )

        self.captured_logs: list[Log] = []

        def decoded_log_handler(log: Log) -> None:
            self.captured_logs.append(log)

        log_decoder = LogStreamDecoder(
            decoded_log_handler=decoded_log_handler,
            source_name='source',
        )
        self.log_stream_handler = LogStreamHandler(
            self.client.channel(self._channel_id).rpcs, log_decoder
        )

    def _get_rpc_ids(self) -> packets.RpcIds:
        """Returns the IDs that server packets for the log stream use.

        Uses the first method of the first service registered from
        ``log_pb2``. NOTE(review): this relies on that being the log listen
        RPC.
        """
        service = next(iter(self.client.services))
        method = next(iter(service.methods))

        # To handle unrequested log streams, packets' call IDs are set to
        # kOpenCallId.
        call_id = client.OPEN_CALL_ID
        return packets.RpcIds(self._channel_id, service.id, method.id, call_id)

    def _process_ok(self, packet: bytes) -> None:
        """Feeds ``packet`` to the client and asserts it was handled."""
        self.assertIs(self.client.process_packet(packet), Status.OK)

    def _send_log_entries(
        self, first_entry_sequence_id: int, *messages: bytes
    ) -> None:
        """Streams one ``LogEntries`` packet carrying ``messages``."""
        self._process_ok(
            _encode_server_stream_packet(
                self._get_rpc_ids(),
                log_pb2.LogEntries(
                    first_entry_sequence_id=first_entry_sequence_id,
                    entries=[
                        log_pb2.LogEntry(message=message)
                        for message in messages
                    ],
                ),
            )
        )

    def _count_listen_calls(self) -> _CallableWithCounter:
        """Wraps ``listen_to_logs`` so tests can count stream (re)starts."""
        listen_function = _CallableWithCounter(
            self.log_stream_handler.listen_to_logs
        )
        self.log_stream_handler.listen_to_logs = listen_function
        return listen_function

    def _assert_completion_restarts_stream(self, status: Status) -> None:
        """Completes the log stream with ``status`` and checks it restarts."""
        on_error = mock.Mock()
        self.log_stream_handler.handle_log_stream_error = on_error

        # Wrap (not mock) the real handler so its restart logic still runs.
        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler
        listen_function = self._count_listen_calls()
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to completion.
        self._send_log_entries(0, b'message0', b'message1')
        self._process_ok(_encode_completed(self._get_rpc_ids(), status))

        on_error.assert_not_called()
        self.assertEqual(len(self.captured_logs), 2)
        # The initial call plus one restart triggered by the completion.
        self.assertEqual(listen_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (status,))

    def test_listen_to_logs_subsequent_calls(self):
        """Tests that consecutive log stream packets are all decoded."""
        on_error = mock.Mock()
        on_completed = mock.Mock()
        self.log_stream_handler.handle_log_stream_error = on_error
        self.log_stream_handler.handle_log_stream_completed = on_completed
        self.log_stream_handler.listen_to_logs()

        self._send_log_entries(0, b'message0', b'message1')
        on_error.assert_not_called()
        on_completed.assert_not_called()
        self.assertEqual(len(self.captured_logs), 2)

        # A subsequent RPC packet should be handled successfully.
        self._send_log_entries(2, b'message2', b'message3')
        on_error.assert_not_called()
        on_completed.assert_not_called()
        self.assertEqual(len(self.captured_logs), 4)

    def test_log_stream_cancelled(self):
        """Tests that a cancelled log stream is not restarted."""
        on_completed = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = on_completed

        # Wrap the real error handler instead of mocking it; otherwise its
        # restart decision never runs and the listen count is meaningless.
        error_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_error
        )
        self.log_stream_handler.handle_log_stream_error = error_handler
        listen_function = self._count_listen_calls()
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to cancellation.
        self._send_log_entries(0, b'message0', b'message1')
        self._process_ok(_encode_cancel(self._get_rpc_ids()))

        self.assertEqual(error_handler.call_count(), 1)
        self.assertEqual(error_handler.calls[0].args, (Status.CANCELLED,))
        on_completed.assert_not_called()
        self.assertEqual(len(self.captured_logs), 2)
        # Only the initial call; the cancellation must not re-listen.
        self.assertEqual(listen_function.call_count(), 1)

    def test_log_stream_error_stream_restarted(self):
        """Tests that an error on the log stream restarts the stream."""
        on_completed = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = on_completed

        error_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_error
        )
        self.log_stream_handler.handle_log_stream_error = error_handler
        listen_function = self._count_listen_calls()
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to the stream error.
        self._send_log_entries(0, b'message0', b'message1')
        self._process_ok(_encode_error(self._get_rpc_ids()))

        on_completed.assert_not_called()
        self.assertEqual(len(self.captured_logs), 2)
        # The initial call plus one restart triggered by the error.
        self.assertEqual(listen_function.call_count(), 2)
        self.assertEqual(error_handler.call_count(), 1)
        self.assertEqual(error_handler.calls[0].args, (Status.UNKNOWN,))

    def test_log_stream_completed_ok_stream_restarted(self):
        """Tests that a stream completed with OK is restarted."""
        self._assert_completion_restarts_stream(Status.OK)

    def test_log_stream_completed_with_error_stream_restarted(self):
        """Tests that a stream completed with an error status is restarted."""
        self._assert_completion_restarts_stream(Status.UNKNOWN)
353
354
355if __name__ == '__main__':
356    main()
357