# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utilities for using HDLC with pw_rpc."""

import collections
from concurrent.futures import ThreadPoolExecutor
import io
import logging
from queue import SimpleQueue
import random
import sys
import threading
import time
import socket
import subprocess
from typing import (Any, BinaryIO, Callable, Deque, Dict, Iterable, List,
                    NoReturn, Optional, Sequence, Tuple, Union)

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client

from pw_hdlc.decode import Frame, FrameDecoder
from pw_hdlc import encode

_LOG = logging.getLogger(__name__)

# HDLC address whose frames HdlcRpcClient forwards to its `output` function.
STDOUT_ADDRESS = 1
# HDLC address used for RPC packets unless a caller specifies another one.
DEFAULT_ADDRESS = ord('R')
# Custom log level just below DEBUG, used for per-read/per-write byte dumps.
_VERBOSE = logging.DEBUG - 1


def channel_output(writer: Callable[[bytes], Any],
                   address: int = DEFAULT_ADDRESS,
                   delay_s: float = 0) -> Callable[[bytes], None]:
    """Returns a function that can be used as a channel output for pw_rpc.

    The returned function wraps its argument in an HDLC UI frame for `address`
    and hands the encoded frame to `writer`.

    Args:
      writer: function that is called with the encoded bytes
      address: HDLC address to encode into each frame
      delay_s: if nonzero, the frame is written one byte at a time, sleeping
          this many seconds before each byte

    Returns:
      A function that accepts the payload bytes to send.
    """

    if delay_s:

        def slow_write(data: bytes) -> None:
            """Slows down writes in case unbuffered serial is in use."""
            for byte in data:
                time.sleep(delay_s)
                writer(bytes([byte]))

        return lambda data: slow_write(encode.ui_frame(address, data))

    def write_hdlc(data: bytes) -> None:
        """Encodes data as a single HDLC frame and writes it in one call."""
        frame = encode.ui_frame(address, data)
        _LOG.log(_VERBOSE, 'Write %2d B: %s', len(frame), frame)
        writer(frame)

    return write_hdlc


def _handle_error(frame: Frame) -> None:
    """Default error handler: logs the frame's status and, at debug, its data."""
    _LOG.error('Failed to parse frame: %s', frame.status.value)
    _LOG.debug('%s', frame.data)


# Maps an HDLC address to the function that handles frames for that address.
FrameHandlers = Dict[int, Callable[[Frame], Any]]


def read_and_process_data(read: Callable[[], bytes],
                          on_read_error: Callable[[Exception], Any],
                          frame_handlers: FrameHandlers,
                          error_handler: Callable[[Frame],
                                                  Any] = _handle_error,
                          handler_threads: Optional[int] = 1) -> NoReturn:
    """Continuously reads and handles HDLC frames.

    Passes frames to an executor that calls frame handler functions in other
    threads. This function loops forever and does not return.

    Args:
      read: function that returns the next chunk of bytes to decode
      on_read_error: called with the exception when read() raises; reading is
          then retried
      frame_handlers: maps HDLC addresses to the function to call with each
          frame received for that address
      error_handler: called with each frame for which frame.ok() is false
      handler_threads: passed to ThreadPoolExecutor as max_workers
    """
    def handle_frame(frame: Frame) -> None:
        """Dispatches one frame; runs on an executor thread."""
        try:
            if not frame.ok():
                error_handler(frame)
                return

            # Only the lookup is guarded by KeyError. A KeyError raised inside
            # the handler itself must not be misreported as an unhandled
            # address; it falls through to the outer handler below and is
            # logged with its traceback.
            try:
                handler = frame_handlers[frame.address]
            except KeyError:
                _LOG.warning('Unhandled frame for address %d: %s',
                             frame.address, frame)
                return

            handler(frame)
        except:  # pylint: disable=bare-except
            # Deliberately broad: nothing collects the futures' results, so an
            # exception that escaped here would otherwise vanish silently.
            _LOG.exception('Exception in HDLC frame handler thread')

    decoder = FrameDecoder()

    # Execute callbacks in a ThreadPoolExecutor to decouple reading the input
    # stream from handling the data. That way, if a handler function takes a
    # long time or crashes, this reading thread is not interrupted.
    with ThreadPoolExecutor(max_workers=handler_threads) as executor:
        while True:
            try:
                data = read()
            except Exception as exc:  # pylint: disable=broad-except
                on_read_error(exc)
                continue

            if data:
                _LOG.log(_VERBOSE, 'Read %2d B: %s', len(data), data)

                for frame in decoder.process_valid_frames(data):
                    executor.submit(handle_frame, frame)


def write_to_file(data: bytes, output: Optional[BinaryIO] = None) -> None:
    """Writes data followed by a newline to a binary stream and flushes it.

    Args:
      data: bytes to write
      output: stream to write to; defaults to sys.stdout.buffer
    """
    # Resolve the default at call time rather than import time, so importing
    # this module cannot fail when sys.stdout has no `buffer` attribute and so
    # later replacements of sys.stdout are respected.
    if output is None:
        output = sys.stdout.buffer

    output.write(data + b'\n')
    output.flush()


def default_channels(write: Callable[[bytes], Any]) -> List[pw_rpc.Channel]:
    """Returns a list with one channel (ID 1) that HDLC-encodes to `write`."""
    return [pw_rpc.Channel(1, channel_output(write))]


PathsModulesOrProtoLibrary = Union[Iterable[python_protos.PathOrModule],
                                   python_protos.Library]


class HdlcRpcClient:
    """An RPC client configured to run over HDLC."""
    def __init__(self,
                 read: Callable[[], bytes],
                 paths_or_modules: PathsModulesOrProtoLibrary,
                 channels: Iterable[pw_rpc.Channel],
                 output: Callable[[bytes], Any] = write_to_file,
                 client_impl: Optional[pw_rpc.client.ClientImpl] = None,
                 *,
                 _incoming_packet_filter_for_testing: Optional[
                     pw_rpc.ChannelManipulator] = None):
        """Creates an RPC client configured to communicate using HDLC.

        Starts a daemon thread that reads data, decodes HDLC frames, and
        dispatches them by address.

        Args:
          read: Function that reads bytes; e.g. serial_device.read.
          paths_or_modules: paths to .proto files or proto modules
          channels: RPC channels to use for output
          output: where to write "stdout" output from the device
          client_impl: RPC client implementation to use; a
              callback_client.Impl() is created if not provided
        """
        if isinstance(paths_or_modules, python_protos.Library):
            self.protos = paths_or_modules
        else:
            self.protos = python_protos.Library.from_paths(paths_or_modules)

        if client_impl is None:
            client_impl = callback_client.Impl()

        self.client = pw_rpc.Client.from_modules(client_impl, channels,
                                                 self.protos.modules())

        # When a test filter is supplied, incoming RPC packets are routed
        # through it first; it forwards the packets it keeps via send_packet.
        rpc_output: Callable[[bytes], Any] = self._handle_rpc_packet
        if _incoming_packet_filter_for_testing is not None:
            _incoming_packet_filter_for_testing.send_packet = rpc_output
            rpc_output = _incoming_packet_filter_for_testing

        frame_handlers: FrameHandlers = {
            DEFAULT_ADDRESS: lambda frame: rpc_output(frame.data),
            STDOUT_ADDRESS: lambda frame: output(frame.data),
        }

        # Start background thread that reads and processes RPC packets.
        # Read errors are ignored (the on_read_error callback does nothing).
        threading.Thread(target=read_and_process_data,
                         daemon=True,
                         args=(read, lambda exc: None,
                               frame_handlers)).start()

    def rpcs(self, channel_id: Optional[int] = None) -> Any:
        """Returns object for accessing services on the specified channel.

        This skips some intermediate layers to make it simpler to invoke RPCs
        from an HdlcRpcClient. If only one channel is in use, the channel ID is
        not necessary.

        Args:
          channel_id: ID of the channel to use; if None, the first channel
              returned by the client is used
        """
        if channel_id is None:
            return next(iter(self.client.channels())).rpcs

        return self.client.channel(channel_id).rpcs

    def _handle_rpc_packet(self, packet: bytes) -> None:
        """Passes a packet to the RPC client; logs an error if it is rejected."""
        if not self.client.process_packet(packet):
            _LOG.error('Packet not handled by RPC client: %s', packet)


def _try_connect(port: int, attempts: int = 10) -> socket.socket:
    """Tries to connect to the specified port up to the given number of times.

    This is helpful when connecting to a process that was started by this
    script. The process may need time to start listening for connections, and
    that length of time can vary. This retries with a short delay rather than
    having to wait for the worst case delay every time.

    Args:
      port: localhost TCP port to connect to
      attempts: maximum number of connection attempts; at least one attempt is
          always made

    Returns:
      The connected socket.

    Raises:
      ConnectionRefusedError: the final attempt was refused.
    """
    while True:
        attempts -= 1
        time.sleep(0.001)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(('localhost', port))
        except ConnectionRefusedError:
            sock.close()
            if attempts <= 0:
                raise
        except BaseException:
            # Any other failure is not retried, but the socket must still be
            # closed so that it is not leaked.
            sock.close()
            raise
        else:
            return sock


class SocketSubprocess:
    """Executes a subprocess and connects to it with a socket."""
    def __init__(self, command: Sequence, port: int) -> None:
        """Starts `command` and connects to it on localhost at `port`.

        If the connection cannot be established, the subprocess is terminated
        and the exception is re-raised.
        """
        self._server_process = subprocess.Popen(command, stdin=subprocess.PIPE)
        self.stdin = self._server_process.stdin

        try:
            self.socket: socket.socket = _try_connect(port)
        except:
            # Do not leave the subprocess running if the connection failed.
            self._server_process.terminate()
            self._server_process.communicate()
            raise

    def close(self) -> None:
        """Closes the socket, then terminates and waits for the subprocess."""
        try:
            self.socket.close()
        finally:
            self._server_process.terminate()
            self._server_process.communicate()

    def __enter__(self) -> 'SocketSubprocess':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()


class PacketFilter(pw_rpc.ChannelManipulator):
    """Determines if a packet should be kept or dropped for testing purposes."""
    # An action takes the current packet count and returns (keep, repeat):
    # whether to keep the packet, and whether the action stays at the front of
    # the queue for the next packet.
    _Action = Callable[[int], Tuple[bool, bool]]
    _KEEP = lambda _: (True, False)
    _DROP = lambda _: (False, False)

    def __init__(self, name: str) -> None:
        """Creates a filter with no queued actions; name is used in logs."""
        super().__init__()
        self.name = name
        self.packet_count = 0
        self._actions: Deque[PacketFilter._Action] = collections.deque()

    def process_and_send(self, packet: bytes):
        """Forwards the packet with send_packet unless it is to be dropped."""
        if self.keep_packet(packet):
            self.send_packet(packet)

    def reset(self) -> None:
        """Resets the packet count and discards all queued actions."""
        self.packet_count = 0
        self._actions.clear()

    def keep(self, count: int) -> None:
        """Keeps the next count packets."""
        self._actions.extend(PacketFilter._KEEP for _ in range(count))

    def drop(self, count: int) -> None:
        """Drops the next count packets."""
        self._actions.extend(PacketFilter._DROP for _ in range(count))

    def drop_every(self, every: int) -> None:
        """Drops every Nth packet forever."""
        self._actions.append(lambda count: (count % every != 0, True))

    def randomly_drop(self, one_in: int, gen: random.Random) -> None:
        """Drops packets randomly forever."""
        self._actions.append(lambda _: (gen.randrange(one_in) != 0, True))

    def keep_packet(self, packet: bytes) -> bool:
        """Returns whether the provided packet should be kept or dropped."""
        self.packet_count += 1

        # With no queued actions, every packet is kept.
        if not self._actions:
            return True

        keep, repeat = self._actions[0](self.packet_count)

        # One-shot actions are consumed; repeating actions stay at the front.
        if not repeat:
            self._actions.popleft()

        if not keep:
            _LOG.debug('Dropping %s packet %d for testing: %s', self.name,
                       self.packet_count, packet)
        return keep


class HdlcRpcLocalServerAndClient:
    """Runs an RPC server in a subprocess and connects to it over a socket.

    This can be used to run a local RPC server in an integration test.
    """
    def __init__(
        self,
        server_command: Sequence,
        port: int,
        protos: PathsModulesOrProtoLibrary,
        *,
        incoming_processor: Optional[pw_rpc.ChannelManipulator] = None,
        outgoing_processor: Optional[pw_rpc.ChannelManipulator] = None
    ) -> None:
        """Creates a new HdlcRpcLocalServerAndClient.

        Args:
          server_command: command used to start the server subprocess
          port: localhost port on which to connect to the server
          protos: paths to .proto files, proto modules, or a proto library
          incoming_processor: if provided, packets received from the server
              are passed through this manipulator
          outgoing_processor: if provided, data sent to the server is passed
              through this manipulator
        """

        self.server = SocketSubprocess(server_command, port)

        # A separate thread moves data from the socket into a queue; the
        # client's reader thread pulls from the queue.
        self._bytes_queue: 'SimpleQueue[bytes]' = SimpleQueue()
        self._read_thread = threading.Thread(target=self._read_from_socket)
        self._read_thread.start()

        # Collects frames the server sends to STDOUT_ADDRESS.
        self.output = io.BytesIO()

        self.channel_output: Any = self.server.socket.sendall

        self._incoming_processor = incoming_processor
        if outgoing_processor is not None:
            outgoing_processor.send_packet = self.channel_output
            self.channel_output = outgoing_processor

        self.client = HdlcRpcClient(
            self._bytes_queue.get,
            protos,
            default_channels(self.channel_output),
            self.output.write,
            _incoming_packet_filter_for_testing=incoming_processor).client

    def _read_from_socket(self):
        """Queues data received from the socket until an empty read occurs."""
        while True:
            data = self.server.socket.recv(4096)
            self._bytes_queue.put(data)
            if not data:
                return

    def close(self):
        """Shuts down the server, closes the output, and joins the reader."""
        self.server.close()
        self.output.close()
        self._read_thread.join()

    def __enter__(self) -> 'HdlcRpcLocalServerAndClient':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()