#!/usr/bin/env python3
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""pw_hdlc.rpc unit tests"""

from contextlib import contextmanager
import logging
import queue
import threading
import time
from typing import Optional
import unittest

from pw_hdlc.rpc import RpcClient, HdlcRpcClient, CancellableReader


class QueueFile:
    """A fake file object backed by a queue for testing.

    The "operator" (the test) queues bytes, exceptions, or EOF; the
    "consumer" (code under test) pulls them out through read().
    """

    # Marker queued by set_read_eof(); compared by identity in read().
    EOF = object()

    def __init__(self):
        # Operator puts; consumer gets
        self._q = queue.Queue()

        # Consumer side access only!
        # Leftover bytes from the most recently dequeued chunk.
        self._readbuf = b''
        # Set once EOF has been dequeued; later reads return b''.
        self._eof = False

    ###############
    # Consumer side

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

    def _read_from_buf(self, size: int) -> bytes:
        """Removes and returns up to `size` bytes from the read buffer."""
        data = self._readbuf[:size]
        self._readbuf = self._readbuf[size:]
        return data

    def read(self, size: int = 1) -> bytes:
        """Reads up to `size` bytes, blocking on the queue if needed.

        Bytes left over from a previously dequeued chunk are returned first,
        without blocking. Otherwise this blocks until the operator queues an
        item, which is handled according to its type:

        - bytes: buffered, and up to `size` of them are returned (an empty
          chunk yields b'').
        - an Exception: raised to the caller.
        - EOF: recorded; this and every later read returns b''.

        As with a real file, fewer than `size` bytes may be returned.

        Args:
            size: Maximum number of bytes to return. 0 returns b''
                immediately without touching the queue.

        Returns:
            Up to `size` bytes; b'' after EOF or for an empty queued chunk.

        Raises:
            Exception: The exception queued via cause_read_exc() or close(),
                or if an item of an unexpected type was queued.
        """
        # Like a real file, a zero-length read never blocks. Without this,
        # read(0) would trip the buffer assertion below (if data is buffered)
        # or block on the queue (if not).
        if size == 0:
            return b''

        # First serve any bytes left over from a previously queued chunk.
        data = self._read_from_buf(size)
        if data:
            return data

        # No more data in the buffer
        assert not self._readbuf

        if self._eof:
            return b''

        # Nothing buffered; block on the queue
        item = self._q.get()

        # NOTE: We can't call Queue.task_done() here because the reader hasn't
        # actually *acted* on the read item yet.

        # Queued data
        if isinstance(item, bytes):
            self._readbuf = item
            return self._read_from_buf(size)

        # Queued exception
        if isinstance(item, Exception):
            raise item

        # Report EOF
        if item is self.EOF:
            self._eof = True
            return b''

        raise Exception('unexpected item type')

    def write(self, data: bytes) -> None:
        """Discards written data; this fake only models the read side."""

    #####################
    # Weird middle ground

    # It is a violation of most file-like object APIs for one thread to call
    # close() while another thread is calling read(). The behavior is
    # undefined.
    #
    # - On Linux, close() may wake up a select(), leaving the caller with a bad
    #   file descriptor (which could get reused!)
    # - Or the read() could continue to block indefinitely.
    #
    # We choose to cause a subsequent/parallel read to receive an exception.
    def close(self) -> None:
        self.cause_read_exc(Exception('closed'))

    ###############
    # Operator side

    def put_read_data(self, data: bytes) -> None:
        """Queues `data` to be returned by subsequent read() calls."""
        self._q.put(data)

    def cause_read_exc(self, exc: Exception) -> None:
        """Queues `exc` to be raised by a subsequent read() call."""
        self._q.put(exc)

    def set_read_eof(self) -> None:
        """Queues EOF; once reached, read() returns b'' from then on."""
        self._q.put(self.EOF)

    def wait_for_drain(self, timeout: Optional[float] = None) -> None:
        """Wait for the queue to drain (be fully consumed).

        An empty queue only means the consumer has *dequeued* every item, not
        that it has finished acting on the last one.

        Args:
            timeout: The maximum time (in seconds) to wait, or wait forever
                if None.

        Raises:
            TimeoutError: If timeout is given and has elapsed.
        """
        # It would be great to use Queue.join() here, but that requires the
        # consumer to call Queue.task_done(), and we can't do that because
        # the consumer of read() doesn't know anything about it.
        # Instead, we poll. ¯\_(ツ)_/¯
        #
        # time.monotonic() is used so that wall-clock adjustments cannot make
        # the timeout fire early or late.
        start_time = time.monotonic()
        while not self._q.empty():
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed > timeout:
                    raise TimeoutError(f"Queue not empty after {elapsed} sec")
            time.sleep(0.1)


class QueueFileTest(unittest.TestCase):
    """Test the QueueFile class"""

    def test_read_data(self) -> None:
        file = QueueFile()
        file.put_read_data(b'hello')
        self.assertEqual(file.read(5), b'hello')

    def test_read_data_multi_read(self) -> None:
        file = QueueFile()
        file.put_read_data(b'helloworld')
        self.assertEqual(file.read(5), b'hello')
        self.assertEqual(file.read(5), b'world')

    def test_read_data_multi_put(self) -> None:
        file = QueueFile()
        file.put_read_data(b'hello')
        file.put_read_data(b'world')
        self.assertEqual(file.read(5), b'hello')
        self.assertEqual(file.read(5), b'world')

    def test_read_zero_does_not_block(self) -> None:
        file = QueueFile()
        # Nothing is queued, so this would block forever if read(0) waited.
        self.assertEqual(file.read(0), b'')
        # With data buffered, read(0) must not disturb it.
        file.put_read_data(b'hello')
        self.assertEqual(file.read(2), b'he')
        self.assertEqual(file.read(0), b'')
        self.assertEqual(file.read(3), b'llo')

    def test_read_eof(self) -> None:
        file = QueueFile()
        file.set_read_eof()
        result = file.read(5)
        self.assertEqual(result, b'')

    def test_read_exception(self) -> None:
        file = QueueFile()
        message = 'test exception'
        file.cause_read_exc(ValueError(message))
        with self.assertRaisesRegex(ValueError, message):
            file.read(5)

    def test_wait_for_drain_works(self) -> None:
        file = QueueFile()
        file.put_read_data(b'hello')
        file.read()
        try:
            # Timeout is arbitrary; will return immediately.
            file.wait_for_drain(0.1)
        except TimeoutError:
            self.fail("wait_for_drain raised TimeoutError")

    def test_wait_for_drain_raises(self) -> None:
        file = QueueFile()
        file.put_read_data(b'hello')
        # don't read
        with self.assertRaises(TimeoutError):
            # Timeout is arbitrary; it will raise no matter what.
            file.wait_for_drain(0.1)


class Sentinel:
    """A marker object logged to confirm which records were captured."""

    def __repr__(self):
        return 'Sentinel'


class _QueueReader(CancellableReader):
    """A CancellableReader over a QueueFile.

    Cancelling closes the QueueFile, which makes a blocked (or the next)
    QueueFile.read() raise.
    """

    def cancel_read(self) -> None:
        # NOTE(review): assumes `_base_obj` is the object passed to the
        # CancellableReader constructor (the QueueFile) — confirm in pw_hdlc.
        self._base_obj.close()


def _get_client(file) -> RpcClient:
    """Returns an HdlcRpcClient reading from `file`, with no protos or channels."""
    return HdlcRpcClient(
        _QueueReader(file),
        paths_or_modules=[],
        channels=[],
    )


# This should take <10ms but we'll wait up to 1000x longer.
_QUEUE_DRAIN_TIMEOUT = 10.0


class HdlcRpcClientTest(unittest.TestCase):
    """Tests the pw_hdlc.rpc.HdlcRpcClient class."""

    # NOTE: There is no test here for stream EOF because Serial.read()
    # can return an empty result if configured with timeout != None.
    # The reader thread will continue in this case.

    def test_clean_close_after_stream_close(self) -> None:
        """Assert RpcClient closes cleanly when stream closes."""
        # See b/293595266.
        file = QueueFile()

        with self.assert_no_hdlc_rpc_error_logs():
            with file:
                with _get_client(file):
                    # We want to make sure the reader thread is blocked on
                    # read() and doesn't exit immediately.
                    file.put_read_data(b'')
                    file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

                # RpcClient.__exit__ calls stop() on the reader thread, but
                # it is blocked on file.read().

            # QueueFile.close() is called, triggering an exception in the
            # blocking read() (by implementation choice). The reader should
            # handle it by *not* logging it and exiting immediately.

        self.assert_no_background_threads_running()

    def test_device_handles_read_exception(self) -> None:
        """Assert RpcClient closes cleanly when read raises an exception."""
        # See b/293595266.
        file = QueueFile()

        logger = logging.getLogger('pw_hdlc.rpc')
        test_exc = Exception('boom')
        with self.assertLogs(logger, level=logging.ERROR) as ctx:
            with _get_client(file):
                # Cause read() to raise an exception. The reader should
                # handle it by logging it and exiting immediately.
                file.cause_read_exc(test_exc)
                file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

        # Assert one exception was raised
        self.assertEqual(len(ctx.records), 1)
        rec = ctx.records[0]
        self.assertIsNotNone(rec.exc_info)
        assert rec.exc_info is not None  # for mypy
        self.assertEqual(rec.exc_info[1], test_exc)

        self.assert_no_background_threads_running()

    @contextmanager
    def assert_no_hdlc_rpc_error_logs(self):
        """Fails unless the body logs nothing at ERROR+ to 'pw_hdlc.rpc'.

        Yields the assertLogs() capture context. After the body completes,
        the only captured record must be the Sentinel logged here.
        """
        logger = logging.getLogger('pw_hdlc.rpc')
        sentinel = Sentinel()
        with self.assertLogs(logger, level=logging.ERROR) as ctx:
            # TODO: b/294861320 - use assertNoLogs() in Python 3.10+
            # We actually want to assert there are no errors, but
            # TestCase.assertNoLogs() is not available until Python 3.10.
            # So we log one error to keep the test from failing and manually
            # inspect the list of captured records.
            logger.error(sentinel)

            yield ctx

        self.assertEqual([record.msg for record in ctx.records], [sentinel])

    def assert_no_background_threads_running(self):
        """Fails if any thread other than the current one is alive."""
        self.assertEqual(threading.enumerate(), [threading.current_thread()])


if __name__ == '__main__':
    unittest.main()