#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Proxy for transfer integration testing.

This module contains a proxy for transfer integration testing. It is capable
of introducing various link failures into the connection between the client and
server.
"""

import abc
import argparse
import asyncio
from enum import Enum
import logging
import random
import socket
import sys
import time
from typing import Any, Awaitable, Callable, Iterable, List, Optional

from google.protobuf import text_format

from pigweed.pw_rpc.internal import packet_pb2
from pigweed.pw_transfer import transfer_pb2
from pigweed.pw_transfer.integration_test import config_pb2
from pw_hdlc import decode
from pw_transfer.chunk import Chunk

_LOG = logging.getLogger('pw_transfer_intergration_test_proxy')

# This is the maximum size of the socket receive buffers. Ideally, this is set
# to the lowest allowed value to minimize buffering between the proxy and
# clients so rate limiting causes the client to block and wait for the
# integration test proxy to drain rather than allowing OS buffers to backlog
# large quantities of data.
#
# Note that the OS may choose to not strictly follow this requested buffer size.
# Still, setting this value to be relatively small does reduce buffer sizes
# significantly enough to better reflect typical inter-device communication.
#
# For this to be effective, clients should also configure their sockets to a
# smaller send buffer size.
_RECEIVE_BUFFER_SIZE = 2048


class Event(Enum):
    """Transfer events that ``EventFilter`` publishes to other filters.

    Each value is emitted when ``EventFilter`` sees a transfer chunk of the
    corresponding type.
    """

    TRANSFER_START = 1
    PARAMETERS_RETRANSMIT = 2
    PARAMETERS_CONTINUE = 3
    START_ACK_CONFIRMATION = 4


class Filter(abc.ABC):
    """An abstract interface for manipulating a stream of data.

    ``Filter``s are used to implement various transforms to simulate real
    world link properties. Some examples include: data corruption,
    packet loss, packet reordering, rate limiting, latency modeling.

    A ``Filter`` implementation should implement the ``process`` method
    and call ``self.send_data()`` when it has data to send. Implementations
    that own background resources (e.g. asyncio tasks) should override
    ``close()`` to release them.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        # Downstream sink: the next Filter in the stack, or the function that
        # writes to the outgoing socket.
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Processes incoming data.

        Implementations of this method may send arbitrary data, or none, using
        the ``self.send_data()`` handler.
        """

    def close(self) -> None:
        """Releases resources held by this filter.

        The default implementation does nothing.
        """

    async def __call__(self, data: bytes) -> None:
        # Makes a Filter usable directly as another Filter's ``send_data``.
        await self.process(data)


class HdlcPacketizer(Filter):
    """A filter which aggregates data into complete HDLC packets.

    Since the proxy transport (SOCK_STREAM) has no framing and we want some
    filters to operate on whole frames, this filter can be used so that
    downstream filters see whole frames.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        # A single decoder is kept for the filter's lifetime so that frames
        # split across several reads can be reassembled.
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        # Forward each decoded frame in its original encoded (wire) form.
        for frame in self.decoder.process(data):
            await self.send_data(frame.raw_encoded)


class DataDropper(Filter):
    """A filter which drops some data.

    DataDropper will drop data passed through ``process()`` at the
    specified ``rate``: each call independently drops its data when a uniform
    random draw in [0.0, 1.0] is below ``rate``.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: Optional[int] = None,
    ):
        """Creates a DataDropper.

        Args:
          send_data: Downstream sink for data that is not dropped.
          name: Label used in log messages.
          rate: Probability that a given ``process()`` call drops its data.
          seed: Random seed. When None, the current time in nanoseconds is
            used. The chosen seed is logged so that a run can be reproduced.
        """
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)


class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter will alternate between keeping packets and dropping
    chunks of data based on a keep/drop queue provided during its creation. The
    queue is looped over unless a negative element is found. A negative number
    is effectively the same as a value of infinity.

    This filter is typically most practical when used with a packetizer so data
    can be dropped as distinct packets.

    Examples:

      keep_drop_queue = [3, 2]:
        Keeps 3 packets,
        Drops 2 packets,
        Keeps 3 packets,
        Drops 2 packets,
        ... [loops indefinitely]

      keep_drop_queue = [5, 99, 1, -1]:
        Keeps 5 packets,
        Drops 99 packets,
        Keeps 1 packet,
        Drops all further packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
    ):
        """Creates a KeepDropQueue.

        Args:
          send_data: Downstream sink for kept packets.
          name: Label used in log messages.
          keep_drop_queue: Alternating keep/drop packet counts, starting with
            a keep count.

        Raises:
          ValueError: ``keep_drop_queue`` is empty or contains only zeros.
        """
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        if not self._keep_drop_queue:
            raise ValueError(f'{name} KeepDropQueue requires a non-empty queue')
        if not any(self._keep_drop_queue):
            # With only zero counts, process() would spin forever looking for
            # a non-zero count and block the event loop.
            raise ValueError(
                f'{name} KeepDropQueue requires at least one non-zero count'
            )
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        self._keep = True
        self._name = name

    async def process(self, data: bytes) -> None:
        # Move forward through the queue if needed. Each step flips between
        # keeping and dropping; zero-length phases are skipped.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        # Negative counts are never decremented, so that phase lasts forever.
        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')


class RateLimiter(Filter):
    """A filter which limits transmission rate.

    This filter delays transmission of data by len(data)/rate seconds, i.e.
    ``rate`` is expressed in bytes per second.
    """

    def __init__(
        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
    ):
        """Creates a RateLimiter.

        Raises:
          ValueError: ``rate`` is not positive (a zero rate would otherwise
            raise ZeroDivisionError on every packet).
        """
        super().__init__(send_data)
        if rate <= 0:
            raise ValueError(f'RateLimiter rate must be positive, got {rate}')
        self._rate = rate

    async def process(self, data: bytes) -> None:
        delay = len(data) / self._rate
        await asyncio.sleep(delay)
        await self.send_data(data)


class DataTransposer(Filter):
    """A filter which occasionally transposes two chunks of data.

    This filter transposes data at the specified rate. It does this by
    holding a chunk to transpose until another chunk arrives. The filter
    will not hold a chunk longer than ``timeout`` seconds.

    A background task does the actual forwarding; call ``close()`` to stop it.
    Any chunk held at that point is discarded.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        timeout: float,
        seed: Optional[int] = None,
    ):
        """Creates a DataTransposer. Must be called with a running event loop.

        Args:
          send_data: Downstream sink.
          name: Label used in log messages.
          rate: Probability that a chunk is held back to be transposed.
          timeout: Maximum number of seconds a chunk is held.
          seed: Random seed. When None, the current time in nanoseconds is
            used. The chosen seed is logged so that a run can be reproduced.
        """
        super().__init__(send_data)
        self._name = name
        self._rate = rate
        self._timeout = timeout
        self._data_queue: asyncio.Queue = asyncio.Queue()
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        self._transpose_task = asyncio.create_task(self._transpose_handler())

        _LOG.info(f'{name} DataTranspose initialized with seed {seed}')

    def close(self) -> None:
        """Cancels the background transpose task if it is still running."""
        task = getattr(self, '_transpose_task', None)
        if task is not None and not task.done():
            _LOG.info(f'{self._name} cleaning up transpose task.')
            task.cancel()

    def __del__(self):
        # Safety net only: the running task references this object, so
        # close() must be called explicitly to stop the task reliably.
        self.close()

    async def _transpose_handler(self):
        """Async task that handles the packet transposition and timeouts"""
        held_data: Optional[bytes] = None
        while True:
            # Only use timeout if we have data held for transposition
            timeout = None if held_data is None else self._timeout
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )

                if held_data is not None:
                    # If we have held data, send it out of order.
                    await self.send_data(data)
                    await self.send_data(held_data)
                    held_data = None
                else:
                    # Otherwise decide if we should transpose the current data.
                    if self._rng.uniform(0.0, 1.0) < self._rate:
                        _LOG.info(
                            f'{self._name} transposing {len(data)} bytes of data'
                        )
                        held_data = data
                    else:
                        await self.send_data(data)

            except asyncio.TimeoutError:
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None

    async def process(self, data: bytes) -> None:
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)


class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of numbers of packets to send before
    dropping all subsequent packets until a TRANSFER_START packet
    is seen. This process is repeated for each element in
    packets_before_failure. After that list is exhausted, ServerFailure
    will send all packets.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: List[int],
    ):
        """Creates a ServerFailure filter.

        Args:
          send_data: Downstream sink.
          name: Label for this filter.
          packets_before_failure_list: Packet budgets, one per transfer. The
            sequence is copied and is never modified.
        """
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        # Copy: the caller typically passes a repeated field of the shared
        # proxy config, and popping from it would silently consume the
        # budgets of every later connection.
        self._packets_before_failure_list = list(packets_before_failure_list)
        self._packets_before_failure: Optional[int] = None
        self.advance_packets_before_failure()

    def advance_packets_before_failure(self):
        """Moves to the next packet budget; None means "never fail"."""
        if len(self._packets_before_failure_list) > 0:
            self._packets_before_failure = (
                self._packets_before_failure_list.pop(0)
            )
        else:
            self._packets_before_failure = None

    async def process(self, data: bytes) -> None:
        if self._packets_before_failure is None:
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)
        # Otherwise the budget is spent: drop until the next TRANSFER_START.

    def handle_event(self, event: Event) -> None:
        if event is Event.TRANSFER_START:
            self.advance_packets_before_failure()


class WindowPacketDropper(Filter):
    """A filter to allow the same packet in each window to be dropped.

    WindowPacketDropper will drop the nth (zero-based) data chunk in each
    window as specified by window_packet_to_drop. This process will happen
    indefinitely for each window. Non-data traffic is always forwarded. The
    window counter restarts on PARAMETERS_RETRANSMIT, PARAMETERS_CONTINUE and
    START_ACK_CONFIRMATION events.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        self._window_packet = 0

    async def process(self, data: bytes) -> None:
        try:
            is_data_chunk = (
                _extract_transfer_chunk(data).type is Chunk.Type.DATA
            )
        except Exception:
            # Invalid / non-chunk data (e.g. text logs); ignore.
            is_data_chunk = False

        # Only count transfer data chunks as part of a window.
        if is_data_chunk:
            if self._window_packet != self._window_packet_to_drop:
                await self.send_data(data)

            self._window_packet += 1
        else:
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event in (
            Event.PARAMETERS_RETRANSMIT,
            Event.PARAMETERS_CONTINUE,
            Event.START_ACK_CONFIRMATION,
        ):
            self._window_packet = 0


class EventFilter(Filter):
    """A filter that inspects packets and sends events to other filters.

    Every packet is forwarded unchanged. Packets that decode as transfer
    chunks of an interesting type additionally produce an ``Event`` on
    ``event_queue``.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        super().__init__(send_data)
        self._name = name
        self._queue = event_queue

    async def process(self, data: bytes) -> None:
        # Only the decoding is guarded: a bare ``except`` around the whole body
        # would also swallow asyncio.CancelledError (a BaseException).
        try:
            chunk = _extract_transfer_chunk(data)
        except Exception:
            # Silently ignore invalid packets (e.g. text logs).
            chunk = None

        if chunk is not None:
            if chunk.type is Chunk.Type.START:
                await self._queue.put(Event.TRANSFER_START)
            elif chunk.type is Chunk.Type.START_ACK_CONFIRMATION:
                await self._queue.put(Event.START_ACK_CONFIRMATION)
            elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT:
                await self._queue.put(Event.PARAMETERS_RETRANSMIT)
            elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE:
                await self._queue.put(Event.PARAMETERS_CONTINUE)

        await self.send_data(data)


def _extract_transfer_chunk(data: bytes) -> 'Chunk':
    """Gets a transfer Chunk from an HDLC frame containing an RPC packet.

    Only the first frame found in ``data`` is examined.

    Raises:
      ValueError: ``data`` contains no HDLC frame.
      Other errors raised while parsing the frame as an RPC packet / transfer
      chunk propagate unchanged.
    """

    decoder = decode.FrameDecoder()
    for frame in decoder.process(data):
        packet = packet_pb2.RpcPacket()
        packet.ParseFromString(frame.data)
        raw_chunk = transfer_pb2.Chunk()
        raw_chunk.ParseFromString(packet.payload)
        return Chunk.from_message(raw_chunk)

    raise ValueError("Invalid transfer frame")


async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: List[Callable[[Event], None]]
):
    """Dispatches every event from ``event_queue`` to all ``handlers``.

    Runs until cancelled.
    """
    while True:
        event = await event_queue.get()
        for handler in handlers:
            handler(event)


async def _handle_simplex_connection(
    name: str,
    filter_stack_config: List['config_pb2.FilterConfig'],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client.

    Builds the configured filter stack, pipes everything read from ``reader``
    through it into ``writer``, and returns when ``reader`` reaches EOF.
    Events observed on this direction are published to
    ``outbound_event_queue``. Events from ``inbound_event_queue`` are
    delivered to this direction's event-aware filters. The event task and
    all filters are cleaned up on return, error, or cancellation.
    """

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filters: List[Filter] = []
    event_task: Optional[asyncio.Task] = None
    try:
        # The EventFilter sits at the bottom, just above the socket writer.
        filter_stack: Filter = EventFilter(send, name, outbound_event_queue)
        filters.append(filter_stack)

        event_handlers: List[Callable[[Event], None]] = []

        # Build the filter stack from the bottom up
        for config in reversed(filter_stack_config):
            filter_name = config.WhichOneof("filter")
            if filter_name == "hdlc_packetizer":
                filter_stack = HdlcPacketizer(filter_stack)
            elif filter_name == "data_dropper":
                data_dropper = config.data_dropper
                filter_stack = DataDropper(
                    filter_stack, name, data_dropper.rate, data_dropper.seed
                )
            elif filter_name == "rate_limiter":
                filter_stack = RateLimiter(
                    filter_stack, config.rate_limiter.rate
                )
            elif filter_name == "data_transposer":
                transposer = config.data_transposer
                filter_stack = DataTransposer(
                    filter_stack,
                    name,
                    transposer.rate,
                    transposer.timeout,
                    transposer.seed,
                )
            elif filter_name == "server_failure":
                server_failure = ServerFailure(
                    filter_stack,
                    name,
                    config.server_failure.packets_before_failure,
                )
                event_handlers.append(server_failure.handle_event)
                filter_stack = server_failure
            elif filter_name == "keep_drop_queue":
                keep_drop_queue = config.keep_drop_queue
                filter_stack = KeepDropQueue(
                    filter_stack, name, keep_drop_queue.keep_drop_queue
                )
            elif filter_name == "window_packet_dropper":
                window_packet_dropper = WindowPacketDropper(
                    filter_stack,
                    name,
                    config.window_packet_dropper.window_packet_to_drop,
                )
                event_handlers.append(window_packet_dropper.handle_event)
                filter_stack = window_packet_dropper
            else:
                sys.exit(f'Unknown filter {filter_name}')
            filters.append(filter_stack)

        event_task = asyncio.create_task(
            _handle_simplex_events(inbound_event_queue, event_handlers)
        )

        while True:
            # Arbitrarily chosen "page sized" read.
            data = await reader.read(4096)

            # An empty data indicates that the connection is closed.
            if not data:
                _LOG.info(f'{name} connection closed.')
                return

            await filter_stack.process(data)
    finally:
        # Without this, the event task and any DataTransposer tasks would
        # outlive the connection and wait on their queues forever.
        if event_task is not None:
            event_task.cancel()
        for stage in filters:
            stage.close()


async def _handle_connection(
    server_port: int,
    config: 'config_pb2.ProxyConfig',
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client.

    Opens a dedicated connection to the server on ``server_port`` and runs one
    simplex handler per direction. When either direction finishes (or fails),
    the other is cancelled, failures are logged, and both sockets are closed.
    """

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    try:
        server_reader, server_writer = await asyncio.open_connection(
            'localhost', server_port
        )
    except OSError as err:
        _LOG.error(
            f'Failed to connect to server on port {server_port}: {err}'
        )
        client_writer.close()
        return
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other: events
    # seen in one direction are delivered to the filters of the other.
    server_event_queue: asyncio.Queue = asyncio.Queue()
    client_event_queue: asyncio.Queue = asyncio.Queue()

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    tasks = [
        asyncio.create_task(
            _handle_simplex_connection(
                "client",
                config.client_filter_stack,
                client_reader,
                server_writer,
                server_event_queue,
                client_event_queue,
            )
        ),
        asyncio.create_task(
            _handle_simplex_connection(
                "server",
                config.server_filter_stack,
                server_reader,
                client_writer,
                client_event_queue,
                server_event_queue,
            )
        ),
    ]

    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # When one side terminates the connection, also terminate the other
        # side, then collect results so no exception goes unreported.
        for task in tasks:
            task.cancel()
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for side, result in zip(('client', 'server'), results):
            if isinstance(result, Exception) and not isinstance(
                result, asyncio.CancelledError
            ):
                _LOG.error(f'{side} connection handler failed: {result!r}')

        for stream in (client_writer, server_writer):
            stream.close()


def _parse_args() -> argparse.Namespace:
    """Parses the proxy's command-line arguments."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        '--server-port',
        type=int,
        required=True,
        help=(
            'Port of the integration test server. The proxy will forward '
            'connections to this port'
        ),
    )
    parser.add_argument(
        '--client-port',
        type=int,
        required=True,
        help=(
            'Port on which to listen for connections from integration test '
            'client.'
        ),
    )

    return parser.parse_args()


def _init_logging(level: int) -> None:
    """Attaches a stderr handler filtered at ``level`` to the module logger."""
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


async def _main(server_port: int, client_port: int) -> None:
    """Runs the proxy until the process is terminated.

    Reads a text-format ``ProxyConfig`` from stdin, listens on
    ``client_port``, and forwards each client connection to ``server_port``.
    """
    _init_logging(logging.DEBUG)

    # Load config from stdin using synchronous IO
    text_config = sys.stdin.buffer.read()

    config = text_format.Parse(text_config, config_pb2.ProxyConfig())

    # Instantiate the TCP server. The socket is configured by hand so that the
    # reduced receive buffer size (see _RECEIVE_BUFFER_SIZE) can be applied.
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        server_socket.setsockopt(
            socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
        )
        server_socket.bind(('localhost', client_port))
    except OSError:
        # Don't leak the descriptor if the port is unavailable.
        server_socket.close()
        raise
    server = await asyncio.start_server(
        lambda reader, writer: _handle_connection(
            server_port, config, reader, writer
        ),
        limit=_RECEIVE_BUFFER_SIZE,
        sock=server_socket,
    )

    addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
    _LOG.info(f'Listening for client connection on {addrs}')

    # Run the TCP server.
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    asyncio.run(_main(**vars(_parse_args())))