• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Utilities for using HDLC with pw_rpc."""
15
16import collections
17from concurrent.futures import ThreadPoolExecutor
18import io
19import logging
20from queue import SimpleQueue
21import random
22import sys
23import threading
24import time
25import socket
26import subprocess
27from typing import (Any, BinaryIO, Callable, Deque, Dict, Iterable, List,
28                    NoReturn, Optional, Sequence, Tuple, Union)
29
30from pw_protobuf_compiler import python_protos
31import pw_rpc
32from pw_rpc import callback_client
33
34from pw_hdlc.decode import Frame, FrameDecoder
35from pw_hdlc import encode
36
_LOG: logging.Logger = logging.getLogger(__name__)

# HDLC address of frames whose data is "stdout" output from the device.
STDOUT_ADDRESS: int = 1
# HDLC address of frames carrying pw_rpc packets (the ASCII code of 'R').
DEFAULT_ADDRESS: int = ord('R')
# Custom log level one step below DEBUG, used for raw read/write byte dumps.
_VERBOSE: int = logging.DEBUG - 1
42
43
def channel_output(writer: Callable[[bytes], Any],
                   address: int = DEFAULT_ADDRESS,
                   delay_s: float = 0) -> Callable[[bytes], None]:
    """Returns a function that can be used as a channel output for pw_rpc.

    The returned function wraps each packet in an HDLC UI frame (via
    encode.ui_frame) addressed to ``address`` and hands the frame to
    ``writer``.

    Args:
      writer: Function that writes raw bytes.
      address: HDLC address placed in every frame.
      delay_s: If nonzero, the frame is written one byte at a time, sleeping
          this many seconds before each byte. This slows down writes in case
          unbuffered serial is in use.

    Returns:
      A function that takes an RPC packet and writes it as an HDLC frame.
    """
    def write_hdlc(data: bytes) -> None:
        frame = encode.ui_frame(address, data)
        # Log every outgoing frame, including throttled (delay_s) writes, so
        # both output paths produce the same debug trace.
        _LOG.log(_VERBOSE, 'Write %2d B: %s', len(frame), frame)

        if not delay_s:
            writer(frame)
            return

        # Slows down writes in case unbuffered serial is in use.
        for byte in frame:
            time.sleep(delay_s)
            writer(bytes([byte]))

    return write_hdlc
65
66
def _handle_error(frame: Frame) -> None:
    """Default error handler for frames whose ok() check failed.

    Logs the frame's status at ERROR level and its raw data at DEBUG level.
    """
    status = frame.status.value
    _LOG.error('Failed to parse frame: %s', status)
    _LOG.debug('%s', frame.data)
70
71
# Maps an HDLC frame address to the callable that handles frames sent to it.
FrameHandlers = Dict[
    int,
    Callable[[Frame], Any],
]
73
74
def read_and_process_data(read: Callable[[], bytes],
                          on_read_error: Callable[[Exception], Any],
                          frame_handlers: FrameHandlers,
                          error_handler: Callable[[Frame],
                                                  Any] = _handle_error,
                          handler_threads: Optional[int] = 1) -> NoReturn:
    """Continuously reads and handles HDLC frames.

    Passes frames to an executor that calls frame handler functions in other
    threads. Loops forever unless an exception escapes from on_read_error or
    from frame decoding.

    Args:
      read: Function returning the next chunk of input bytes. Empty (falsy)
          results are ignored and read is called again.
      on_read_error: Called with any Exception raised by read; reading then
          continues.
      frame_handlers: Maps HDLC addresses to the handler for frames with that
          address. Frames for addresses not in the mapping are logged at
          WARNING level and dropped.
      error_handler: Called with frames whose ok() check fails.
      handler_threads: max_workers for the ThreadPoolExecutor that runs
          handlers.
    """
    def handle_frame(frame: Frame) -> None:
        try:
            if not frame.ok():
                error_handler(frame)
                return

            # Look the handler up separately from calling it, so that a
            # KeyError raised *inside* a handler is reported as a handler
            # exception rather than misreported as an unhandled address.
            try:
                handler = frame_handlers[frame.address]
            except KeyError:
                _LOG.warning('Unhandled frame for address %d: %s',
                             frame.address, frame)
                return

            handler(frame)
        except:  # pylint: disable=bare-except
            # Handlers run in executor threads whose futures are never
            # inspected, so log every failure here or it would go unnoticed.
            _LOG.exception('Exception in HDLC frame handler thread')

    decoder = FrameDecoder()

    # Execute callbacks in a ThreadPoolExecutor to decouple reading the input
    # stream from handling the data. That way, if a handler function takes a
    # long time or crashes, this reading thread is not interrupted.
    with ThreadPoolExecutor(max_workers=handler_threads) as executor:
        while True:
            try:
                data = read()
            except Exception as exc:  # pylint: disable=broad-except
                on_read_error(exc)
                continue

            if data:
                _LOG.log(_VERBOSE, 'Read %2d B: %s', len(data), data)

                for frame in decoder.process_valid_frames(data):
                    executor.submit(handle_frame, frame)
118
119
def write_to_file(data: bytes, output: Optional[BinaryIO] = None) -> None:
    """Writes data followed by a newline to a binary stream, then flushes it.

    Args:
      data: Bytes to write.
      output: Binary stream to write to. Defaults to the ``buffer`` of
          whatever ``sys.stdout`` is at call time (not at import time), so
          redirecting sys.stdout after this module is imported is respected.
    """
    if output is None:
        output = sys.stdout.buffer
    output.write(data + b'\n')
    output.flush()
123
124
def default_channels(write: Callable[[bytes], Any]) -> List[pw_rpc.Channel]:
    """Builds the default channel list: one pw_rpc channel with ID 1.

    Packets sent on the channel are HDLC-framed by channel_output (with its
    default address) and handed to ``write``.
    """
    hdlc_writer = channel_output(write)
    only_channel = pw_rpc.Channel(1, hdlc_writer)
    return [only_channel]
127
128
# Either an iterable of .proto file paths / proto modules, or an already-built
# python_protos.Library.
PathsModulesOrProtoLibrary = Union[
    Iterable[python_protos.PathOrModule],
    python_protos.Library,
]
131
132
class HdlcRpcClient:
    """An RPC client configured to run over HDLC.

    Incoming frames addressed to DEFAULT_ADDRESS are processed as RPC packets;
    frames addressed to STDOUT_ADDRESS are passed to the ``output`` function.
    """
    def __init__(self,
                 read: Callable[[], bytes],
                 paths_or_modules: PathsModulesOrProtoLibrary,
                 channels: Iterable[pw_rpc.Channel],
                 output: Callable[[bytes], Any] = write_to_file,
                 client_impl: Optional[pw_rpc.client.ClientImpl] = None,
                 *,
                 _incoming_packet_filter_for_testing: Optional[
                     pw_rpc.ChannelManipulator] = None):
        """Creates an RPC client configured to communicate using HDLC.

        Starts a daemon thread that reads from ``read`` and dispatches the
        decoded frames.

        Args:
          read: Function that reads bytes; e.g serial_device.read. Exceptions
              it raises are discarded and reading continues.
          paths_or_modules: paths to .proto files or proto modules, or an
              existing python_protos.Library
          channels: RPC channels to use for output
          output: where to write "stdout" output from the device
          client_impl: pw_rpc client implementation; defaults to a new
              callback_client.Impl()
          _incoming_packet_filter_for_testing: test-only ChannelManipulator.
              Incoming RPC packets are passed to it, and its send_packet is
              set to this client's packet handler.
        """
        if isinstance(paths_or_modules, python_protos.Library):
            self.protos = paths_or_modules
        else:
            self.protos = python_protos.Library.from_paths(paths_or_modules)

        if client_impl is None:
            client_impl = callback_client.Impl()

        self.client = pw_rpc.Client.from_modules(client_impl, channels,
                                                 self.protos.modules())

        rpc_output: Callable[[bytes], Any] = self._handle_rpc_packet
        if _incoming_packet_filter_for_testing is not None:
            # Route incoming packets through the filter; it is expected to
            # forward the packets it keeps via send_packet.
            _incoming_packet_filter_for_testing.send_packet = rpc_output
            rpc_output = _incoming_packet_filter_for_testing

        frame_handlers: FrameHandlers = {
            DEFAULT_ADDRESS: lambda frame: rpc_output(frame.data),
            STDOUT_ADDRESS: lambda frame: output(frame.data),
        }

        # Start background thread that reads and processes RPC packets.
        # The on_read_error callback discards read exceptions.
        threading.Thread(target=read_and_process_data,
                         daemon=True,
                         args=(read, lambda exc: None,
                               frame_handlers)).start()

    def rpcs(self, channel_id: Optional[int] = None) -> Any:
        """Returns object for accessing services on the specified channel.

        This skips some intermediate layers to make it simpler to invoke RPCs
        from an HdlcRpcClient. If only one channel is in use, the channel ID is
        not necessary; when channel_id is None, the first channel is used.
        """
        if channel_id is None:
            return next(iter(self.client.channels())).rpcs

        return self.client.channel(channel_id).rpcs

    def _handle_rpc_packet(self, packet: bytes) -> None:
        """Passes a packet to the RPC client, logging it if it is unhandled."""
        if not self.client.process_packet(packet):
            _LOG.error('Packet not handled by RPC client: %s', packet)
194
195
def _try_connect(port: int,
                 attempts: int = 10,
                 delay_s: float = 0.001) -> socket.socket:
    """Tries to connect to the specified port up to the given number of times.

    This is helpful when connecting to a process that was started by this
    script. The process may need time to start listening for connections, and
    that length of time can vary. This retries with a short delay rather than
    having to wait for the worst case delay every time.

    Args:
      port: TCP port on localhost to connect to.
      attempts: Maximum number of attempts; at least one is always made.
      delay_s: Seconds to sleep before each attempt.

    Returns:
      A connected IPv4 TCP socket.

    Raises:
      ConnectionRefusedError: if every attempt was refused.
      OSError: any other connection error, raised immediately (not retried).
    """
    while True:
        attempts -= 1
        time.sleep(delay_s)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(('localhost', port))
        except ConnectionRefusedError:
            sock.close()
            if attempts <= 0:
                raise
        except BaseException:
            # Close the socket on unexpected errors (timeouts, resets,
            # interrupts) too; previously it was leaked.
            sock.close()
            raise
        else:
            return sock
216
217
class SocketSubprocess:
    """Executes a subprocess and connects to it with a socket."""
    def __init__(self, command: Sequence, port: int) -> None:
        """Starts ``command`` and connects a TCP socket to localhost:port.

        The subprocess's stdin is a pipe, exposed as ``self.stdin``. If the
        connection cannot be established, the subprocess is terminated before
        the exception propagates.
        """
        self._server_process = subprocess.Popen(command, stdin=subprocess.PIPE)
        self.stdin = self._server_process.stdin

        try:
            self.socket: socket.socket = _try_connect(port)
        except BaseException:
            # Don't leave an orphaned server process if connecting fails or
            # is interrupted.
            self._stop_server()
            raise

    def _stop_server(self) -> None:
        """Terminates the subprocess and waits for it to exit."""
        self._server_process.terminate()
        self._server_process.communicate()

    def close(self) -> None:
        """Closes the socket, then terminates the subprocess (even on error)."""
        try:
            self.socket.close()
        finally:
            self._stop_server()

    def __enter__(self) -> 'SocketSubprocess':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()
243
244
class PacketFilter(pw_rpc.ChannelManipulator):
    """Determines if a packet should be kept or dropped for testing purposes.

    Actions are queued with keep(), drop(), drop_every() and randomly_drop()
    and only the action at the front of the queue is consulted for each
    packet. keep() and drop() actions apply to one packet each and are then
    removed; drop_every() and randomly_drop() actions repeat forever, so any
    actions queued after them never take effect. With no queued actions,
    every packet is kept.
    """
    # An action takes the 1-based packet count and returns (keep, repeat):
    # whether to keep this packet, and whether the action stays queued.
    _Action = Callable[[int], Tuple[bool, bool]]
    _KEEP = lambda _: (True, False)
    _DROP = lambda _: (False, False)

    def __init__(self, name: str) -> None:
        """Creates a filter; ``name`` identifies it in log messages."""
        super().__init__()
        self.name = name
        self.packet_count = 0
        self._actions: Deque[PacketFilter._Action] = collections.deque()

    def process_and_send(self, packet: bytes):
        """Forwards packet via send_packet unless the filter drops it."""
        if self.keep_packet(packet):
            self.send_packet(packet)

    def reset(self) -> None:
        """Resets the packet count and clears all queued actions."""
        self.packet_count = 0
        self._actions.clear()

    def keep(self, count: int) -> None:
        """Keeps the next count packets."""
        self._actions.extend(PacketFilter._KEEP for _ in range(count))

    def drop(self, count: int) -> None:
        """Drops the next count packets."""
        self._actions.extend(PacketFilter._DROP for _ in range(count))

    def drop_every(self, every: int) -> None:
        """Drops every Nth packet forever.

        N is applied to the filter's total packet count (since creation or the
        last reset()), not to the packets seen after this call.

        Raises:
          ValueError: if every is less than 1 (it would otherwise fail later,
              inside the packet path, with ZeroDivisionError).
        """
        if every < 1:
            raise ValueError(f'every must be at least 1, got {every}')
        self._actions.append(lambda count: (count % every != 0, True))

    def randomly_drop(self, one_in: int, gen: random.Random) -> None:
        """Drops packets randomly forever, each with probability 1/one_in.

        Raises:
          ValueError: if one_in is less than 1 (gen.randrange would otherwise
              fail later, inside the packet path).
        """
        if one_in < 1:
            raise ValueError(f'one_in must be at least 1, got {one_in}')
        self._actions.append(lambda _: (gen.randrange(one_in) != 0, True))

    def keep_packet(self, packet: bytes) -> bool:
        """Returns whether the provided packet should be kept or dropped."""
        self.packet_count += 1

        if not self._actions:
            return True

        keep, repeat = self._actions[0](self.packet_count)

        if not repeat:
            self._actions.popleft()

        if not keep:
            _LOG.debug('Dropping %s packet %d for testing: %s', self.name,
                       self.packet_count, packet)
        return keep
297
298
class HdlcRpcLocalServerAndClient:
    """Runs an RPC server in a subprocess and connects to it over a socket.

    This can be used to run a local RPC server in an integration test.
    """
    def __init__(
        self,
        server_command: Sequence,
        port: int,
        protos: PathsModulesOrProtoLibrary,
        *,
        incoming_processor: Optional[pw_rpc.ChannelManipulator] = None,
        outgoing_processor: Optional[pw_rpc.ChannelManipulator] = None
    ) -> None:
        """Creates a new HdlcRpcLocalServerAndClient.

        Args:
          server_command: command that starts the RPC server subprocess
          port: localhost TCP port to connect to
          protos: .proto paths/modules or a python_protos.Library
          incoming_processor: optional ChannelManipulator that RPC packets
              received from the server pass through before the client
              processes them
          outgoing_processor: optional ChannelManipulator that receives the
              HDLC-encoded frames written by the client; its send_packet is
              set to the socket's sendall
        """
        self.server = SocketSubprocess(server_command, port)

        # Bytes read from the socket by a background thread; consumed by the
        # HdlcRpcClient's read loop.
        self._bytes_queue: 'SimpleQueue[bytes]' = SimpleQueue()
        self._read_thread = threading.Thread(target=self._read_from_socket)
        self._read_thread.start()

        # Collects data from frames sent to STDOUT_ADDRESS.
        self.output = io.BytesIO()

        try:
            self.channel_output: Any = self.server.socket.sendall

            self._incoming_processor = incoming_processor
            if outgoing_processor is not None:
                outgoing_processor.send_packet = self.channel_output
                self.channel_output = outgoing_processor

            self.client = HdlcRpcClient(
                self._bytes_queue.get,
                protos,
                default_channels(self.channel_output),
                self.output.write,
                _incoming_packet_filter_for_testing=incoming_processor).client
        except BaseException:
            # Without this, a setup failure (e.g. bad protos) would leave the
            # server process running and the non-daemon read thread blocked.
            self.close()
            raise

    def _read_from_socket(self):
        """Forwards bytes received from the server socket into _bytes_queue.

        Puts b'' on the queue and returns when the socket reaches end of
        stream or a socket error occurs (e.g. because close() closed it).
        """
        while True:
            try:
                data = self.server.socket.recv(4096)
            except OSError as exc:
                # The socket was closed or reset; treat it as end of stream
                # instead of letting the thread die with a traceback.
                _LOG.debug('Socket read ended: %s', exc)
                data = b''
            self._bytes_queue.put(data)
            if not data:
                return

    def close(self):
        """Stops the server, closes output, and waits for the read thread."""
        try:
            self.server.close()
        finally:
            self.output.close()
            self._read_thread.join()

    def __enter__(self) -> 'HdlcRpcLocalServerAndClient':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()
354