# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

"""RPC log stream handler tests."""

from dataclasses import dataclass
import logging
from typing import Any, Callable
from unittest import TestCase, main, mock

from google.protobuf import message
from pw_log.log_decoder import Log, LogStreamDecoder
from pw_log.proto import log_pb2
from pw_log_rpc.rpc_log_stream import LogStreamHandler
from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2
from pw_status import Status

_LOG = logging.getLogger(__name__)


def _encode_server_stream_packet(
    rpc: packets.RpcIds, payload: message.Message
) -> bytes:
    """Encodes a SERVER_STREAM packet for ``rpc`` carrying ``payload``."""
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_STREAM,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=payload.SerializeToString(),
    ).SerializeToString()


def _encode_server_error(rpc: packets.RpcIds, status: Status) -> bytes:
    """Encodes a SERVER_ERROR packet for ``rpc`` with the given ``status``."""
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_ERROR,
        status=status.value,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    ).SerializeToString()


def _encode_cancel(rpc: packets.RpcIds) -> bytes:
    """Encodes a SERVER_ERROR packet for ``rpc`` with status CANCELLED."""
    return _encode_server_error(rpc, Status.CANCELLED)


def _encode_error(rpc: packets.RpcIds) -> bytes:
    """Encodes a SERVER_ERROR packet for ``rpc`` with status UNKNOWN."""
    return _encode_server_error(rpc, Status.UNKNOWN)


def _encode_completed(rpc: packets.RpcIds, status: Status) -> bytes:
    """Encodes a RESPONSE packet that completes ``rpc`` with ``status``."""
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        status=status.value,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    ).SerializeToString()


class _CallableWithCounter:
    """Wraps a callable and records the arguments of every call made to it."""

    @dataclass
    class CallParams:
        """Positional and keyword arguments of a single recorded call."""

        args: tuple[Any, ...]
        kwargs: dict[str, Any]

    def __init__(self, func: Callable[..., Any]):
        self._func = func
        self.calls: list[_CallableWithCounter.CallParams] = []

    def call_count(self) -> int:
        """Returns how many times the wrapper has been called."""
        return len(self.calls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Records the call, forwards it to the wrapped callable, and returns
        the wrapped callable's result so the wrapper stays transparent."""
        # Record before forwarding so the call is counted even if the wrapped
        # callable re-enters this wrapper or raises.
        self.calls.append(_CallableWithCounter.CallParams(args, kwargs))
        return self._func(*args, **kwargs)


class TestRpcLogStreamHandler(TestCase):
    """Tests for LogStreamHandler."""

    def setUp(self) -> None:
        """Sets up an RPC client and a LogStreamHandler capturing logs."""
        self._channel_id = 1
        self.client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(self._channel_id, lambda _: None)],
            [log_pb2],
        )

        self.captured_logs: list[Log] = []

        def decoded_log_handler(log: Log) -> None:
            self.captured_logs.append(log)

        log_decoder = LogStreamDecoder(
            decoded_log_handler=decoded_log_handler,
            source_name='source',
        )
        self.log_stream_handler = LogStreamHandler(
            self.client.channel(self._channel_id).rpcs, log_decoder
        )

    def _get_rpc_ids(self) -> packets.RpcIds:
        """Returns the IDs of the first method of the first registered
        service, using the open call ID."""
        service = next(iter(self.client.services))
        method = next(iter(service.methods))

        # To handle unrequested log streams, packets' call IDs are set to
        # kOpenCallId.
        call_id = client.OPEN_CALL_ID
        return packets.RpcIds(self._channel_id, service.id, method.id, call_id)

    def _send_log_entries(
        self, first_entry_sequence_id: int, messages: list[bytes]
    ) -> Status:
        """Feeds one SERVER_STREAM packet containing a LogEntry per message
        into the client and returns the processing status."""
        return self.client.process_packet(
            _encode_server_stream_packet(
                self._get_rpc_ids(),
                log_pb2.LogEntries(
                    first_entry_sequence_id=first_entry_sequence_id,
                    entries=[
                        log_pb2.LogEntry(message=msg) for msg in messages
                    ],
                ),
            )
        )

    def test_listen_to_logs_subsequent_calls(self):
        """Tests that consecutive packets on one log stream are all decoded."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()
        self.log_stream_handler.listen_to_logs()

        self.assertIs(
            self._send_log_entries(0, [b'message0', b'message1']),
            Status.OK,
        )
        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)

        # A subsequent RPC packet should be handled successfully.
        self.assertIs(
            self._send_log_entries(2, [b'message2', b'message3']),
            Status.OK,
        )
        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 4)

    def test_log_stream_cancelled(self):
        """Tests that a cancelled log stream is not restarted."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()

        listen_function = _CallableWithCounter(
            self.log_stream_handler.listen_to_logs
        )
        self.log_stream_handler.listen_to_logs = listen_function
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to cancellation.
        self.assertIs(
            self._send_log_entries(0, [b'message0', b'message1']),
            Status.OK,
        )
        self.assertIs(
            self.client.process_packet(_encode_cancel(self._get_rpc_ids())),
            Status.OK,
        )
        self.log_stream_handler.handle_log_stream_error.assert_called_once_with(
            Status.CANCELLED
        )
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)
        # Only the test's own call: cancellation must not trigger a restart.
        self.assertEqual(listen_function.call_count(), 1)

    def test_log_stream_error_stream_restarted(self):
        """Tests that an error on the log stream restarts the stream."""
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()

        error_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_error
        )
        self.log_stream_handler.handle_log_stream_error = error_handler

        listen_function = _CallableWithCounter(
            self.log_stream_handler.listen_to_logs
        )
        self.log_stream_handler.listen_to_logs = listen_function
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to the stream error.
        self.assertIs(
            self._send_log_entries(0, [b'message0', b'message1']),
            Status.OK,
        )
        self.assertIs(
            self.client.process_packet(_encode_error(self._get_rpc_ids())),
            Status.OK,
        )

        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)
        # The test's own call plus one restart after the error.
        self.assertEqual(listen_function.call_count(), 2)
        self.assertEqual(error_handler.call_count(), 1)
        self.assertEqual(error_handler.calls[0].args, (Status.UNKNOWN,))

    def test_log_stream_completed_ok_stream_restarted(self):
        """Tests that a log stream completing with OK is restarted."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()

        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler

        listen_function = _CallableWithCounter(
            self.log_stream_handler.listen_to_logs
        )
        self.log_stream_handler.listen_to_logs = listen_function
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to stream completion.
        self.assertIs(
            self._send_log_entries(0, [b'message0', b'message1']),
            Status.OK,
        )
        self.assertIs(
            self.client.process_packet(
                _encode_completed(self._get_rpc_ids(), Status.OK)
            ),
            Status.OK,
        )

        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertEqual(len(self.captured_logs), 2)
        # The test's own call plus one restart after completion.
        self.assertEqual(listen_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (Status.OK,))

    def test_log_stream_completed_with_error_stream_restarted(self):
        """Tests that a log stream completing with a non-OK status is
        restarted."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()

        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler

        listen_function = _CallableWithCounter(
            self.log_stream_handler.listen_to_logs
        )
        self.log_stream_handler.listen_to_logs = listen_function
        self.log_stream_handler.listen_to_logs()

        # Send logs prior to stream completion.
        self.assertIs(
            self._send_log_entries(0, [b'message0', b'message1']),
            Status.OK,
        )
        self.assertIs(
            self.client.process_packet(
                _encode_completed(self._get_rpc_ids(), Status.UNKNOWN)
            ),
            Status.OK,
        )

        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertEqual(len(self.captured_logs), 2)
        # The test's own call plus one restart after completion.
        self.assertEqual(listen_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (Status.UNKNOWN,))


if __name__ == '__main__':
    main()