#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Proxy for transfer integration testing.

This module contains a proxy for transfer integration testing.  It is capable
of introducing various link failures into the connection between the client and
server.
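
Example usage (illustrative; the script and config file names here are
placeholders, and port values depend on the test setup; the proxy reads its
config from stdin):

    proxy.py --server-port 8080 --client-port 8081 < proxy.config.txt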
20"""
21
22import abc
23import argparse
24import asyncio
25from enum import Enum
26import logging
27import random
28import socket
29import sys
30import time
31from typing import Any, Awaitable, Callable, Iterable, List, Optional
32
33from google.protobuf import text_format
34
35from pigweed.pw_rpc.internal import packet_pb2
36from pigweed.pw_transfer import transfer_pb2
37from pigweed.pw_transfer.integration_test import config_pb2
38from pw_hdlc import decode
39from pw_transfer.chunk import Chunk
40
_LOG = logging.getLogger('pw_transfer_integration_test_proxy')

# This is the maximum size of the socket receive buffers. Ideally, this is set
# to the lowest allowed value to minimize buffering between the proxy and
# clients so rate limiting causes the client to block and wait for the
# integration test proxy to drain rather than allowing OS buffers to backlog
# large quantities of data.
#
# Note that the OS may choose to not strictly follow this requested buffer
# size.  Still, setting this value to be relatively small does reduce buffer
# sizes significantly enough to better reflect typical inter-device
# communication.
#
# For this to be effective, clients should also configure their sockets to a
# smaller send buffer size.
_RECEIVE_BUFFER_SIZE = 2048
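
# As an illustrative sketch (not part of the proxy itself), a client could
# shrink its send buffer like so:
#
#   sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#   sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, _RECEIVE_BUFFER_SIZE)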


class Event(Enum):
    TRANSFER_START = 1
    PARAMETERS_RETRANSMIT = 2
    PARAMETERS_CONTINUE = 3
    START_ACK_CONFIRMATION = 4


class Filter(abc.ABC):
    """An abstract interface for manipulating a stream of data.

    ``Filter``s are used to implement various transforms to simulate real
    world link properties.  Some examples include: data corruption,
    packet loss, packet reordering, rate limiting, latency modeling.

    A ``Filter`` implementation should implement the ``process`` method
    and call ``self.send_data()`` when it has data to send.
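
    Example (an illustrative sketch, not used by the proxy): a pass-through
    filter that logs the size of each chunk it forwards::

        class LoggingFilter(Filter):
            async def process(self, data: bytes) -> None:
                _LOG.info(f'forwarding {len(data)} bytes')
                await self.send_data(data)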
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Processes incoming data.

        Implementations of this method may send arbitrary data, or none, using
        the ``self.send_data()`` handler.
        """

    async def __call__(self, data: bytes) -> None:
        await self.process(data)


class HdlcPacketizer(Filter):
    """A filter which aggregates data into complete HDLC packets.

    Since the proxy transport (SOCK_STREAM) has no framing and we want some
    filters to operate on whole frames, this filter can be used so that
    downstream filters see whole frames.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        for frame in self.decoder.process(data):
            await self.send_data(frame.raw_encoded)


class DataDropper(Filter):
    """A filter which drops some data.

    DataDropper will drop data passed through ``process()`` at the
    specified ``rate``.
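
    For example, ``rate=0.01`` drops each incoming chunk of data with
    probability 0.01, i.e. roughly 1% of chunks on average.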
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: Optional[int] = None,
    ):
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)


class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter will alternate between keeping packets and dropping
    chunks of data based on a keep/drop queue provided during its creation. The
    queue is looped over unless a negative element is found. A negative number
    is effectively the same as a value of infinity.

    This filter is typically most practical when used with a packetizer so
    data can be dropped as distinct packets.

    Examples:

      keep_drop_queue = [3, 2]:
        Keeps 3 packets,
        Drops 2 packets,
        Keeps 3 packets,
        Drops 2 packets,
        ... [loops indefinitely]

      keep_drop_queue = [5, 99, 1, -1]:
        Keeps 5 packets,
        Drops 99 packets,
        Keeps 1 packet,
        Drops all further packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
    ):
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        self._keep = True
        self._name = name

    async def process(self, data: bytes) -> None:
        # Move forward through the queue if needed.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')


class RateLimiter(Filter):
    """A filter which limits transmission rate.

    This filter delays transmission of data by ``len(data) / rate`` seconds.
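
    For example, with ``rate=1000`` (bytes per second), a 500-byte chunk is
    delayed by 0.5 seconds before being forwarded.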
    """

    def __init__(
        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
    ):
        super().__init__(send_data)
        self._rate = rate

    async def process(self, data: bytes) -> None:
        delay = len(data) / self._rate
        await asyncio.sleep(delay)
        await self.send_data(data)


class DataTransposer(Filter):
    """A filter which occasionally transposes two chunks of data.

    This filter transposes data at the specified rate.  It does this by
    holding a chunk to transpose until another chunk arrives. The filter
    will not hold a chunk longer than ``timeout`` seconds.
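
    For example, with ``rate=0.2``, roughly 20% of chunks are held back and
    sent after the chunk that follows them (or sent in order once ``timeout``
    expires with no follow-up chunk).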
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        timeout: float,
        seed: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._rate = rate
        self._timeout = timeout
        self._data_queue = asyncio.Queue()
        self._rng = random.Random(seed)
        self._transpose_task = asyncio.create_task(self._transpose_handler())

        _LOG.info(f'{name} DataTransposer initialized with seed {seed}')

    def __del__(self):
        _LOG.info(f'{self._name} cleaning up transpose task.')
        self._transpose_task.cancel()

    async def _transpose_handler(self):
        """Async task that handles the packet transposition and timeouts."""
        held_data: Optional[bytes] = None
        while True:
            # Only use a timeout if we have data held for transposition.
            timeout = None if held_data is None else self._timeout
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )

                if held_data is not None:
                    # If we have held data, send it out of order.
                    await self.send_data(data)
                    await self.send_data(held_data)
                    held_data = None
                else:
                    # Otherwise decide if we should transpose the current data.
                    if self._rng.uniform(0.0, 1.0) < self._rate:
                        _LOG.info(
                            f'{self._name} transposing {len(data)} bytes of data'
                        )
                        held_data = data
                    else:
                        await self.send_data(data)

            except asyncio.TimeoutError:
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None

    async def process(self, data: bytes) -> None:
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)


class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of numbers of packets to send before
    dropping all subsequent packets until a TRANSFER_START packet
    is seen.  This process is repeated for each element in
    packets_before_failure.  After that list is exhausted, ServerFailure
    will send all packets.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
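
    Example:

      packets_before_failure = [5, 10]:
        Sends 5 packets, then drops packets until a TRANSFER_START is seen,
        sends 10 packets, then drops packets until a TRANSFER_START is seen,
        then sends all subsequent packets.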
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: List[int],
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._packets_before_failure_list = packets_before_failure_list
        self.advance_packets_before_failure()

    def advance_packets_before_failure(self):
        if len(self._packets_before_failure_list) > 0:
            self._packets_before_failure = (
                self._packets_before_failure_list.pop(0)
            )
        else:
            self._packets_before_failure = None

    async def process(self, data: bytes) -> None:
        if self._packets_before_failure is None:
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event is Event.TRANSFER_START:
            self.advance_packets_before_failure()


class WindowPacketDropper(Filter):
    """A filter to allow the same packet in each window to be dropped.

    WindowPacketDropper will drop the nth packet in each window, as
    specified by window_packet_to_drop.  This process repeats
    indefinitely for each window.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
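
    For example, ``window_packet_to_drop=0`` drops the first data chunk of
    every window.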
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        self._window_packet = 0

    async def process(self, data: bytes) -> None:
        try:
            is_data_chunk = (
                _extract_transfer_chunk(data).type is Chunk.Type.DATA
            )
        except Exception:
            # Invalid / non-chunk data (e.g. text logs); ignore.
            is_data_chunk = False

        # Only count transfer data chunks as part of a window.
        if is_data_chunk:
            if self._window_packet != self._window_packet_to_drop:
                await self.send_data(data)

            self._window_packet += 1
        else:
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event in (
            Event.PARAMETERS_RETRANSMIT,
            Event.PARAMETERS_CONTINUE,
            Event.START_ACK_CONFIRMATION,
        ):
            self._window_packet = 0


class EventFilter(Filter):
    """A filter that inspects packets and sends events to other filters.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        super().__init__(send_data)
        self._queue = event_queue

    async def process(self, data: bytes) -> None:
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.START:
                await self._queue.put(Event.TRANSFER_START)
            if chunk.type is Chunk.Type.START_ACK_CONFIRMATION:
                await self._queue.put(Event.START_ACK_CONFIRMATION)
            elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT:
                await self._queue.put(Event.PARAMETERS_RETRANSMIT)
            elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE:
                await self._queue.put(Event.PARAMETERS_CONTINUE)
        except Exception:
            # Silently ignore invalid packets.
            pass

        await self.send_data(data)


def _extract_transfer_chunk(data: bytes) -> Chunk:
    """Gets a transfer Chunk from an HDLC frame containing an RPC packet.

    Raises an exception if a valid chunk does not exist.
    """

    decoder = decode.FrameDecoder()
    for frame in decoder.process(data):
        packet = packet_pb2.RpcPacket()
        packet.ParseFromString(frame.data)
        raw_chunk = transfer_pb2.Chunk()
        raw_chunk.ParseFromString(packet.payload)
        return Chunk.from_message(raw_chunk)

    raise ValueError("Invalid transfer frame")


async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: List[Callable[[Event], None]]
):
    while True:
        event = await event_queue.get()
        for handler in handlers:
            handler(event)


async def _handle_simplex_connection(
    name: str,
    filter_stack_config: List[config_pb2.FilterConfig],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client."""

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filter_stack = EventFilter(send, name, outbound_event_queue)

    event_handlers: List[Callable[[Event], None]] = []

    # Build the filter stack from the bottom up, so the first config entry
    # becomes the outermost filter and is the first to process incoming data.
    for config in reversed(filter_stack_config):
        filter_name = config.WhichOneof("filter")
        if filter_name == "hdlc_packetizer":
            filter_stack = HdlcPacketizer(filter_stack)
        elif filter_name == "data_dropper":
            data_dropper = config.data_dropper
            filter_stack = DataDropper(
                filter_stack, name, data_dropper.rate, data_dropper.seed
            )
        elif filter_name == "rate_limiter":
            filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate)
        elif filter_name == "data_transposer":
            transposer = config.data_transposer
            filter_stack = DataTransposer(
                filter_stack,
                name,
                transposer.rate,
                transposer.timeout,
                transposer.seed,
            )
        elif filter_name == "server_failure":
            server_failure = config.server_failure
            filter_stack = ServerFailure(
                filter_stack, name, server_failure.packets_before_failure
            )
            event_handlers.append(filter_stack.handle_event)
        elif filter_name == "keep_drop_queue":
            keep_drop_queue = config.keep_drop_queue
            filter_stack = KeepDropQueue(
                filter_stack, name, keep_drop_queue.keep_drop_queue
            )
        elif filter_name == "window_packet_dropper":
            window_packet_dropper = config.window_packet_dropper
            filter_stack = WindowPacketDropper(
                filter_stack, name, window_packet_dropper.window_packet_to_drop
            )
            event_handlers.append(filter_stack.handle_event)
        else:
            sys.exit(f'Unknown filter {filter_name}')

    event_task = asyncio.create_task(
        _handle_simplex_events(inbound_event_queue, event_handlers)
    )

    while True:
        # Arbitrarily chosen "page sized" read.
        data = await reader.read(4096)

        # An empty read indicates that the connection is closed.
        if not data:
            _LOG.info(f'{name} connection closed.')
            return

        await filter_stack.process(data)


async def _handle_connection(
    server_port: int,
    config: config_pb2.ProxyConfig,
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client."""

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    #
    # TODO(konkers): catch exception and close client writer
    server_reader, server_writer = await asyncio.open_connection(
        'localhost', server_port
    )
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other.
    server_event_queue = asyncio.Queue()
    client_event_queue = asyncio.Queue()

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    _, pending = await asyncio.wait(
        [
            asyncio.create_task(
                _handle_simplex_connection(
                    "client",
                    config.client_filter_stack,
                    client_reader,
                    server_writer,
                    server_event_queue,
                    client_event_queue,
                )
            ),
            asyncio.create_task(
                _handle_simplex_connection(
                    "server",
                    config.server_filter_stack,
                    server_reader,
                    client_writer,
                    client_event_queue,
                    server_event_queue,
                )
            ),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # When one side terminates the connection, also terminate the other side.
    for task in pending:
        task.cancel()

    for stream in [client_writer, server_writer]:
        stream.close()


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        '--server-port',
        type=int,
        required=True,
        help='Port of the integration test server.  The proxy will forward '
        'connections to this port.',
    )
    parser.add_argument(
        '--client-port',
        type=int,
        required=True,
        help='Port on which to listen for connections from the integration '
        'test client.',
    )

    return parser.parse_args()


def _init_logging(level: int) -> None:
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


async def _main(server_port: int, client_port: int) -> None:
    _init_logging(logging.DEBUG)

    # Load config from stdin using synchronous IO.
    text_config = sys.stdin.buffer.read()

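    # The config is a ProxyConfig protobuf in text format. An illustrative
    # sketch (the values here are made up for the example):
    #
    #   client_filter_stack: { hdlc_packetizer: {} }
    #   client_filter_stack: { data_dropper: { rate: 0.01 seed: 1 } }
    #   server_filter_stack: { hdlc_packetizer: {} }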
    config = text_format.Parse(text_config, config_pb2.ProxyConfig())

    # Instantiate the TCP server.
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.setsockopt(
        socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
    )
    server_socket.bind(('localhost', client_port))
    server = await asyncio.start_server(
        lambda reader, writer: _handle_connection(
            server_port, config, reader, writer
        ),
        limit=_RECEIVE_BUFFER_SIZE,
        sock=server_socket,
    )

    addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
    _LOG.info(f'Listening for client connection on {addrs}')

    # Run the TCP server.
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    asyncio.run(_main(**vars(_parse_args())))
638