#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
15"""Proxy for transfer integration testing.
16
17This module contains a proxy for transfer intergation testing.  It is capable
18of introducing various link failures into the connection between the client and
19server.
20"""
21
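# The proxy reads a text-format ProxyConfig proto from stdin and forwards
# traffic between the client and server ports. An illustrative invocation
# (the script name and ports are examples; in practice the integration test
# harness launches the proxy):
#
#   python proxy.py --server-port 3300 --client-port 3301 < proxy_config.txt
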
import abc
import argparse
import asyncio
from enum import Enum
import logging
import random
import socket
import sys
import time
from typing import Awaitable, Callable, Iterable, NamedTuple

from google.protobuf import text_format

from pigweed.pw_rpc.internal import packet_pb2
from pigweed.pw_transfer import transfer_pb2
from pigweed.pw_transfer.integration_test import config_pb2
from pw_hdlc import decode
from pw_transfer import ProtocolVersion
from pw_transfer.chunk import Chunk

_LOG = logging.getLogger('pw_transfer_integration_test_proxy')

# This is the maximum size of the socket receive buffers. Ideally, this is set
# to the lowest allowed value to minimize buffering between the proxy and
# clients so rate limiting causes the client to block and wait for the
# integration test proxy to drain rather than allowing OS buffers to backlog
# large quantities of data.
#
# Note that the OS may choose to not strictly follow this requested buffer
# size. Still, setting this value to be relatively small does reduce buffer
# sizes significantly enough to better reflect typical inter-device
# communication.
#
# For this to be effective, clients should also configure their sockets to a
# smaller send buffer size.
_RECEIVE_BUFFER_SIZE = 2048
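
# For example, a client might shrink its send buffer like this (a sketch;
# the OS treats the requested size as a hint and may round it):
#
#   sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
#   sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, _RECEIVE_BUFFER_SIZE)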


class EventType(Enum):
    TRANSFER_START = 1
    PARAMETERS_RETRANSMIT = 2
    PARAMETERS_CONTINUE = 3
    START_ACK_CONFIRMATION = 4


class Event(NamedTuple):
    type: EventType
    chunk: Chunk


class Filter(abc.ABC):
    """An abstract interface for manipulating a stream of data.

    ``Filter``s are used to implement various transforms to simulate real
    world link properties.  Some examples include: data corruption,
    packet loss, packet reordering, rate limiting, latency modeling.

    A ``Filter`` implementation should implement the ``process`` method
    and call ``self.send_data()`` when it has data to send.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Processes incoming data.

        Implementations of this method may send arbitrary data, or none, using
        the ``self.send_data()`` handler.
        """

    async def __call__(self, data: bytes) -> None:
        await self.process(data)

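
# A minimal sketch of a custom ``Filter``: a pass-through which logs the size
# of each chunk of data before forwarding it unchanged. Real filters follow
# the same shape: implement ``process()`` and forward (or withhold) data via
# ``self.send_data()``. This example class is illustrative only and is not
# used by the proxy.
class LoggingPassthrough(Filter):
    async def process(self, data: bytes) -> None:
        _LOG.info(f'forwarding {len(data)} bytes')
        await self.send_data(data)
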

class HdlcPacketizer(Filter):
    """A filter which aggregates data into complete HDLC packets.

    Since the proxy transport (SOCK_STREAM) has no framing and we want some
    filters to operate on whole frames, this filter can be used so that
    downstream filters see whole frames.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        for frame in self.decoder.process(data):
            await self.send_data(frame.raw_encoded)


class DataDropper(Filter):
    """A filter which drops some data.

    DataDropper will drop data passed through ``process()`` at the
    specified ``rate``.
    """
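    # For example (illustrative), a dropper that discards roughly 1% of
    # writes, seeded for reproducibility:
    #
    #   dropper = DataDropper(send_data, 'client', rate=0.01, seed=1234)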

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: int | None = None,
    ):
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)


class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter will alternate between keeping packets and dropping
    chunks of data based on a keep/drop queue provided during its creation. The
    queue is looped over unless a negative element is found. A negative number
    is effectively the same as a value of infinity.

    If ``only_consider_transfer_chunks`` is set, only valid transfer chunks
    count toward the keep/drop tally; all other data is forwarded untouched.

    This filter is typically most practical when used with a packetizer so data
    can be dropped as distinct packets.

    Examples:

      keep_drop_queue = [3, 2]:
        Keeps 3 packets,
        Drops 2 packets,
        Keeps 3 packets,
        Drops 2 packets,
        ... [loops indefinitely]

      keep_drop_queue = [5, 99, 1, -1]:
        Keeps 5 packets,
        Drops 99 packets,
        Keeps 1 packet,
        Drops all further packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
        only_consider_transfer_chunks: bool = False,
    ):
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        self._keep = True
        self._name = name
        self._only_consider_transfer_chunks = only_consider_transfer_chunks

    async def process(self, data: bytes) -> None:
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                await self.send_data(data)
                return

        # Move forward through the queue if needed.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')


class RateLimiter(Filter):
    """A filter which limits transmission rate.

    This filter delays transmission of data by len(data)/rate.
    """
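    # For example, with rate=14400 (bytes per second), a 1024-byte chunk is
    # delayed 1024 / 14400 ≈ 71 ms before being forwarded.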

    def __init__(
        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
    ):
        super().__init__(send_data)
        self._rate = rate

    async def process(self, data: bytes) -> None:
        delay = len(data) / self._rate
        await asyncio.sleep(delay)
        await self.send_data(data)


class DataTransposer(Filter):
    """A filter which occasionally transposes two chunks of data.

    This filter transposes data at the specified rate.  It does this by
    holding a chunk to transpose until another chunk arrives. The filter
    will not hold a chunk longer than ``timeout`` seconds.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        timeout: float,
        seed: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._rate = rate
        self._timeout = timeout
        self._data_queue = asyncio.Queue()
        self._rng = random.Random(seed)
        self._transpose_task = asyncio.create_task(self._transpose_handler())

        _LOG.info(f'{name} DataTransposer initialized with seed {seed}')

    def __del__(self):
        _LOG.info(f'{self._name} cleaning up transpose task.')
        self._transpose_task.cancel()

    async def _transpose_handler(self):
        """Async task that handles the packet transposition and timeouts."""
        held_data: bytes | None = None
        while True:
            # Only use a timeout if we have data held for transposition.
            timeout = None if held_data is None else self._timeout
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )

                if held_data is not None:
                    # If we have held data, send it out of order.
                    await self.send_data(data)
                    await self.send_data(held_data)
                    held_data = None
                else:
                    # Otherwise decide if we should transpose the current data.
                    if self._rng.uniform(0.0, 1.0) < self._rate:
                        _LOG.info(
                            f'{self._name} transposing {len(data)} bytes of data'
                        )
                        held_data = data
                    else:
                        await self.send_data(data)

            except asyncio.TimeoutError:
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None

    async def process(self, data: bytes) -> None:
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)


class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of numbers of packets to send before
    dropping all subsequent packets until a TRANSFER_START packet
    is seen.  This process is repeated for each element in
    packets_before_failure_list.  After that list is exhausted,
    ServerFailure will send all packets.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """
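    # For example (illustrative), packets_before_failure_list=[3, 5] sends 3
    # packets, drops everything until the next TRANSFER_START, sends 5
    # packets, drops again until another TRANSFER_START, and then forwards
    # all subsequent packets.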

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: list[int],
        start_immediately: bool = False,
        only_consider_transfer_chunks: bool = False,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._packets_before_failure_list = packets_before_failure_list
        self._packets_before_failure = None
        self._only_consider_transfer_chunks = only_consider_transfer_chunks
        if start_immediately:
            self.advance_packets_before_failure()

    def advance_packets_before_failure(self):
        if self._packets_before_failure_list:
            self._packets_before_failure = (
                self._packets_before_failure_list.pop(0)
            )
        else:
            self._packets_before_failure = None

    async def process(self, data: bytes) -> None:
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                await self.send_data(data)
                return

        if self._packets_before_failure is None:
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event.type is EventType.TRANSFER_START:
            self.advance_packets_before_failure()


class WindowPacketDropper(Filter):
    """A filter to allow the same packet in each window to be dropped.

    WindowPacketDropper will drop the nth packet in each window as
    specified by window_packet_to_drop.  This process will happen
    indefinitely for each window.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """
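    # For example (illustrative), window_packet_to_drop=0 drops the first
    # data chunk of every transmission window; only transfer data chunks
    # count toward the window position.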

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        self._next_window_start_offset: int | None = 0
        self._window_packet = 0

    async def process(self, data: bytes) -> None:
        data_chunk = None
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.DATA:
                data_chunk = chunk
        except Exception:
            # Invalid / non-chunk data (e.g. text logs); ignore.
            pass

        # Only count transfer data chunks as part of a window.
        if data_chunk is not None:
            if data_chunk.offset == self._next_window_start_offset:
                # If a new window has been requested, wait until the first
                # chunk matching its requested offset to begin counting window
                # chunks. Any in-flight chunks from the previous window are
                # allowed through.
                self._window_packet = 0
                self._next_window_start_offset = None

            if self._window_packet != self._window_packet_to_drop:
                await self.send_data(data)

            self._window_packet += 1
        else:
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event.type in (
            EventType.PARAMETERS_RETRANSMIT,
            EventType.PARAMETERS_CONTINUE,
            EventType.START_ACK_CONFIRMATION,
        ):
            # A new transmission window has been requested, starting at the
            # offset specified in the chunk. The receiver may already have data
            # from the previous window in-flight, so don't immediately reset
            # the window packet counter.
            self._next_window_start_offset = event.chunk.offset


class EventFilter(Filter):
    """A filter that inspects packets and sends events to other filters.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        super().__init__(send_data)
        self._name = name
        self._queue = event_queue

    async def process(self, data: bytes) -> None:
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.START:
                await self._queue.put(Event(EventType.TRANSFER_START, chunk))
            elif chunk.type is Chunk.Type.START_ACK_CONFIRMATION:
                await self._queue.put(
                    Event(EventType.START_ACK_CONFIRMATION, chunk)
                )
            elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT:
                await self._queue.put(
                    Event(EventType.PARAMETERS_RETRANSMIT, chunk)
                )
            elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE:
                await self._queue.put(
                    Event(EventType.PARAMETERS_CONTINUE, chunk)
                )
        except Exception:
            # Silently ignore invalid packets.
            pass

        await self.send_data(data)


def _extract_transfer_chunk(data: bytes) -> Chunk:
    """Gets a transfer Chunk from an HDLC frame containing an RPC packet.

    Raises an exception if a valid chunk does not exist.
    """

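    # Layering, outermost first: HDLC frame -> RpcPacket -> transfer Chunk.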
    decoder = decode.FrameDecoder()
    for frame in decoder.process(data):
        packet = packet_pb2.RpcPacket()
        packet.ParseFromString(frame.data)

        if packet.payload:
            raw_chunk = transfer_pb2.Chunk()
            raw_chunk.ParseFromString(packet.payload)
            return Chunk.from_message(raw_chunk)

        # The incoming data is expected to be HDLC-packetized, so only one
        # frame should exist.
        break

    raise ValueError("Invalid transfer chunk frame")


async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: list[Callable[[Event], None]]
):
    while True:
        event = await event_queue.get()
        for handler in handlers:
            handler(event)


async def _handle_simplex_connection(
    name: str,
    filter_stack_config: list[config_pb2.FilterConfig],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client."""

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filter_stack = EventFilter(send, name, outbound_event_queue)

    event_handlers: list[Callable[[Event], None]] = []

    # Build the filter stack from the bottom up.
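    # For example, a config listing [hdlc_packetizer, data_dropper] yields
    # HdlcPacketizer(DataDropper(EventFilter(send))): incoming data is framed
    # first, then possibly dropped, and finally inspected for events and
    # written to the socket.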
    for config in reversed(filter_stack_config):
        filter_name = config.WhichOneof("filter")
        if filter_name == "hdlc_packetizer":
            filter_stack = HdlcPacketizer(filter_stack)
        elif filter_name == "data_dropper":
            data_dropper = config.data_dropper
            filter_stack = DataDropper(
                filter_stack, name, data_dropper.rate, data_dropper.seed
            )
        elif filter_name == "rate_limiter":
            filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate)
        elif filter_name == "data_transposer":
            transposer = config.data_transposer
            filter_stack = DataTransposer(
                filter_stack,
                name,
                transposer.rate,
                transposer.timeout,
                transposer.seed,
            )
        elif filter_name == "server_failure":
            server_failure = config.server_failure
            filter_stack = ServerFailure(
                filter_stack,
                name,
                server_failure.packets_before_failure,
                server_failure.start_immediately,
                server_failure.only_consider_transfer_chunks,
            )
            event_handlers.append(filter_stack.handle_event)
        elif filter_name == "keep_drop_queue":
            keep_drop_queue = config.keep_drop_queue
            filter_stack = KeepDropQueue(
                filter_stack,
                name,
                keep_drop_queue.keep_drop_queue,
                keep_drop_queue.only_consider_transfer_chunks,
            )
        elif filter_name == "window_packet_dropper":
            window_packet_dropper = config.window_packet_dropper
            filter_stack = WindowPacketDropper(
                filter_stack, name, window_packet_dropper.window_packet_to_drop
            )
            event_handlers.append(filter_stack.handle_event)
        else:
            sys.exit(f'Unknown filter {filter_name}')

    event_task = asyncio.create_task(
        _handle_simplex_events(inbound_event_queue, event_handlers)
    )

    while True:
        # Arbitrarily chosen "page sized" read.
        data = await reader.read(4096)

        # Empty data indicates that the connection is closed.
        if not data:
            _LOG.info(f'{name} connection closed.')
            return

        await filter_stack.process(data)


async def _handle_connection(
    server_port: int,
    config: config_pb2.ProxyConfig,
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client."""

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    #
    # TODO(konkers): catch exception and close client writer
    server_reader, server_writer = await asyncio.open_connection(
        'localhost', server_port
    )
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other.
    server_event_queue = asyncio.Queue()
    client_event_queue = asyncio.Queue()
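    # The queues are crosswired: events detected by the EventFilter on one
    # direction of the connection are handled by the filters on the opposite
    # direction (e.g. parameters chunks seen going to the server drive the
    # WindowPacketDropper on the server-to-client data path).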

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    _, pending = await asyncio.wait(
        [
            asyncio.create_task(
                _handle_simplex_connection(
                    "client",
                    config.client_filter_stack,
                    client_reader,
                    server_writer,
                    server_event_queue,
                    client_event_queue,
                )
            ),
            asyncio.create_task(
                _handle_simplex_connection(
                    "server",
                    config.server_filter_stack,
                    server_reader,
                    client_writer,
                    client_event_queue,
                    server_event_queue,
                )
            ),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # When one side terminates the connection, also terminate the other side.
    for task in pending:
        task.cancel()

    for stream in [client_writer, server_writer]:
        stream.close()


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        '--server-port',
        type=int,
        required=True,
        help='Port of the integration test server.  The proxy will forward '
        'connections to this port.',
    )
    parser.add_argument(
        '--client-port',
        type=int,
        required=True,
        help='Port on which to listen for connections from the integration '
        'test client.',
    )

    return parser.parse_args()


def _init_logging(level: int) -> None:
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


async def _main(server_port: int, client_port: int) -> None:
    _init_logging(logging.DEBUG)

    # Load config from stdin using synchronous IO.
    text_config = sys.stdin.buffer.read()

    config = text_format.Parse(text_config, config_pb2.ProxyConfig())
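    # An illustrative text-format config (a sketch; see config_pb2 for the
    # full schema). Filters are applied in the order listed:
    #
    #   client_filter_stack { hdlc_packetizer {} }
    #   client_filter_stack { data_dropper { rate: 0.01 seed: 1234 } }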

    # Instantiate the TCP server.
    server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    server_socket.setsockopt(
        socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
    )
    server_socket.bind(('', client_port))
    server = await asyncio.start_server(
        lambda reader, writer: _handle_connection(
            server_port, config, reader, writer
        ),
        limit=_RECEIVE_BUFFER_SIZE,
        sock=server_socket,
    )

    addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
    _LOG.info(f'Listening for client connection on {addrs}')

    # Run the TCP server.
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    asyncio.run(_main(**vars(_parse_args())))
