1#!/usr/bin/env python3 2# Copyright 2022 The Pigweed Authors 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); you may not 5# use this file except in compliance with the License. You may obtain a copy of 6# the License at 7# 8# https://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13# License for the specific language governing permissions and limitations under 14# the License. 15"""Proxy for transfer integration testing. 16 17This module contains a proxy for transfer intergation testing. It is capable 18of introducing various link failures into the connection between the client and 19server. 20""" 21 22import abc 23import argparse 24import asyncio 25from enum import Enum 26import logging 27import random 28import socket 29import sys 30import time 31from typing import Awaitable, Callable, Iterable, NamedTuple 32 33from google.protobuf import text_format 34 35from pigweed.pw_rpc.internal import packet_pb2 36from pigweed.pw_transfer import transfer_pb2 37from pigweed.pw_transfer.integration_test import config_pb2 38from pw_hdlc import decode 39from pw_transfer import ProtocolVersion 40from pw_transfer.chunk import Chunk 41 42_LOG = logging.getLogger('pw_transfer_intergration_test_proxy') 43 44# This is the maximum size of the socket receive buffers. Ideally, this is set 45# to the lowest allowed value to minimize buffering between the proxy and 46# clients so rate limiting causes the client to block and wait for the 47# integration test proxy to drain rather than allowing OS buffers to backlog 48# large quantities of data. 49# 50# Note that the OS may chose to not strictly follow this requested buffer size. 51# Still, setting this value to be relatively small does reduce bufer sizes 52# significantly enough to better reflect typical inter-device communication. 53# 54# For this to be effective, clients should also configure their sockets to a 55# smaller send buffer size. 56_RECEIVE_BUFFER_SIZE = 2048 57 58 59class EventType(Enum): 60 TRANSFER_START = 1 61 PARAMETERS_RETRANSMIT = 2 62 PARAMETERS_CONTINUE = 3 63 START_ACK_CONFIRMATION = 4 64 65 66class Event(NamedTuple): 67 type: EventType 68 chunk: Chunk 69 70 71class Filter(abc.ABC): 72 """An abstract interface for manipulating a stream of data. 73 74 ``Filter``s are used to implement various transforms to simulate real 75 world link properties. Some examples include: data corruption, 76 packet loss, packet reordering, rate limiting, latency modeling. 77 78 A ``Filter`` implementation should implement the ``process`` method 79 and call ``self.send_data()`` when it has data to send. 80 """ 81 82 def __init__(self, send_data: Callable[[bytes], Awaitable[None]]): 83 self.send_data = send_data 84 85 @abc.abstractmethod 86 async def process(self, data: bytes) -> None: 87 """Processes incoming data. 88 89 Implementations of this method may send arbitrary data, or none, using 90 the ``self.send_data()`` handler. 91 """ 92 93 async def __call__(self, data: bytes) -> None: 94 await self.process(data) 95 96 97class HdlcPacketizer(Filter): 98 """A filter which aggregates data into complete HDLC packets. 99 100 Since the proxy transport (SOCK_STREAM) has no framing and we want some 101 filters to operates on whole frames, this filter can be used so that 102 downstream filters see whole frames. 103 """ 104 105 def __init__(self, send_data: Callable[[bytes], Awaitable[None]]): 106 super().__init__(send_data) 107 self.decoder = decode.FrameDecoder() 108 109 async def process(self, data: bytes) -> None: 110 for frame in self.decoder.process(data): 111 await self.send_data(frame.raw_encoded) 112 113 114class DataDropper(Filter): 115 """A filter which drops some data. 116 117 DataDropper will drop data passed through ``process()`` at the 118 specified ``rate``. 119 """ 120 121 def __init__( 122 self, 123 send_data: Callable[[bytes], Awaitable[None]], 124 name: str, 125 rate: float, 126 seed: int | None = None, 127 ): 128 super().__init__(send_data) 129 self._rate = rate 130 self._name = name 131 if seed == None: 132 seed = time.time_ns() 133 self._rng = random.Random(seed) 134 _LOG.info(f'{name} DataDropper initialized with seed {seed}') 135 136 async def process(self, data: bytes) -> None: 137 if self._rng.uniform(0.0, 1.0) < self._rate: 138 _LOG.info(f'{self._name} dropped {len(data)} bytes of data') 139 else: 140 await self.send_data(data) 141 142 143class KeepDropQueue(Filter): 144 """A filter which alternates between sending packets and dropping packets. 145 146 A KeepDropQueue filter will alternate between keeping packets and dropping 147 chunks of data based on a keep/drop queue provided during its creation. The 148 queue is looped over unless a negative element is found. A negative number 149 is effectively the same as a value of infinity. 150 151 This filter is typically most practical when used with a packetizer so data 152 can be dropped as distinct packets. 153 154 Examples: 155 156 keep_drop_queue = [3, 2]: 157 Keeps 3 packets, 158 Drops 2 packets, 159 Keeps 3 packets, 160 Drops 2 packets, 161 ... [loops indefinitely] 162 163 keep_drop_queue = [5, 99, 1, -1]: 164 Keeps 5 packets, 165 Drops 99 packets, 166 Keeps 1 packet, 167 Drops all further packets. 168 """ 169 170 def __init__( 171 self, 172 send_data: Callable[[bytes], Awaitable[None]], 173 name: str, 174 keep_drop_queue: Iterable[int], 175 only_consider_transfer_chunks: bool = False, 176 ): 177 super().__init__(send_data) 178 self._keep_drop_queue = list(keep_drop_queue) 179 self._loop_idx = 0 180 self._current_count = self._keep_drop_queue[0] 181 self._keep = True 182 self._name = name 183 self._only_consider_transfer_chunks = only_consider_transfer_chunks 184 185 async def process(self, data: bytes) -> None: 186 if self._only_consider_transfer_chunks: 187 try: 188 _extract_transfer_chunk(data) 189 except Exception: 190 await self.send_data(data) 191 return 192 193 # Move forward through the queue if needed. 194 while self._current_count == 0: 195 self._loop_idx += 1 196 self._current_count = self._keep_drop_queue[ 197 self._loop_idx % len(self._keep_drop_queue) 198 ] 199 self._keep = not self._keep 200 201 if self._current_count > 0: 202 self._current_count -= 1 203 204 if self._keep: 205 await self.send_data(data) 206 _LOG.info(f'{self._name} forwarded {len(data)} bytes of data') 207 else: 208 _LOG.info(f'{self._name} dropped {len(data)} bytes of data') 209 210 211class RateLimiter(Filter): 212 """A filter which limits transmission rate. 213 214 This filter delays transmission of data by len(data)/rate. 215 """ 216 217 def __init__( 218 self, send_data: Callable[[bytes], Awaitable[None]], rate: float 219 ): 220 super().__init__(send_data) 221 self._rate = rate 222 223 async def process(self, data: bytes) -> None: 224 delay = len(data) / self._rate 225 await asyncio.sleep(delay) 226 await self.send_data(data) 227 228 229class DataTransposer(Filter): 230 """A filter which occasionally transposes two chunks of data. 231 232 This filter transposes data at the specified rate. It does this by 233 holding a chunk to transpose until another chunk arrives. The filter 234 will not hold a chunk longer than ``timeout`` seconds. 235 """ 236 237 def __init__( 238 self, 239 send_data: Callable[[bytes], Awaitable[None]], 240 name: str, 241 rate: float, 242 timeout: float, 243 seed: int, 244 ): 245 super().__init__(send_data) 246 self._name = name 247 self._rate = rate 248 self._timeout = timeout 249 self._data_queue = asyncio.Queue() 250 self._rng = random.Random(seed) 251 self._transpose_task = asyncio.create_task(self._transpose_handler()) 252 253 _LOG.info(f'{name} DataTranspose initialized with seed {seed}') 254 255 def __del__(self): 256 _LOG.info(f'{self._name} cleaning up transpose task.') 257 self._transpose_task.cancel() 258 259 async def _transpose_handler(self): 260 """Async task that handles the packet transposition and timeouts""" 261 held_data: bytes | None = None 262 while True: 263 # Only use timeout if we have data held for transposition 264 timeout = None if held_data is None else self._timeout 265 try: 266 data = await asyncio.wait_for( 267 self._data_queue.get(), timeout=timeout 268 ) 269 270 if held_data is not None: 271 # If we have held data, send it out of order. 272 await self.send_data(data) 273 await self.send_data(held_data) 274 held_data = None 275 else: 276 # Otherwise decide if we should transpose the current data. 277 if self._rng.uniform(0.0, 1.0) < self._rate: 278 _LOG.info( 279 f'{self._name} transposing {len(data)} bytes of data' 280 ) 281 held_data = data 282 else: 283 await self.send_data(data) 284 285 except asyncio.TimeoutError: 286 _LOG.info(f'{self._name} sending data in order due to timeout') 287 await self.send_data(held_data) 288 held_data = None 289 290 async def process(self, data: bytes) -> None: 291 # Queue data for processing by the transpose task. 292 await self._data_queue.put(data) 293 294 295class ServerFailure(Filter): 296 """A filter to simulate the server stopping sending packets. 297 298 ServerFailure takes a list of numbers of packets to send before 299 dropping all subsequent packets until a TRANSFER_START packet 300 is seen. This process is repeated for each element in 301 packets_before_failure. After that list is exhausted, ServerFailure 302 will send all packets. 303 304 This filter should be instantiated in the same filter stack as an 305 HdlcPacketizer so that EventFilter can decode complete packets. 306 """ 307 308 def __init__( 309 self, 310 send_data: Callable[[bytes], Awaitable[None]], 311 name: str, 312 packets_before_failure_list: list[int], 313 start_immediately: bool = False, 314 only_consider_transfer_chunks: bool = False, 315 ): 316 super().__init__(send_data) 317 self._name = name 318 self._relay_packets = True 319 self._packets_before_failure_list = packets_before_failure_list 320 self._packets_before_failure = None 321 self._only_consider_transfer_chunks = only_consider_transfer_chunks 322 if start_immediately: 323 self.advance_packets_before_failure() 324 325 def advance_packets_before_failure(self): 326 if len(self._packets_before_failure_list) > 0: 327 self._packets_before_failure = ( 328 self._packets_before_failure_list.pop(0) 329 ) 330 else: 331 self._packets_before_failure = None 332 333 async def process(self, data: bytes) -> None: 334 if self._only_consider_transfer_chunks: 335 try: 336 _extract_transfer_chunk(data) 337 except Exception: 338 await self.send_data(data) 339 return 340 341 if self._packets_before_failure is None: 342 await self.send_data(data) 343 elif self._packets_before_failure > 0: 344 self._packets_before_failure -= 1 345 await self.send_data(data) 346 347 def handle_event(self, event: Event) -> None: 348 if event.type is EventType.TRANSFER_START: 349 self.advance_packets_before_failure() 350 351 352class WindowPacketDropper(Filter): 353 """A filter to allow the same packet in each window to be dropped. 354 355 WindowPacketDropper with drop the nth packet in each window as 356 specified by window_packet_to_drop. This process will happen 357 indefinitely for each window. 358 359 This filter should be instantiated in the same filter stack as an 360 HdlcPacketizer so that EventFilter can decode complete packets. 361 """ 362 363 def __init__( 364 self, 365 send_data: Callable[[bytes], Awaitable[None]], 366 name: str, 367 window_packet_to_drop: int, 368 ): 369 super().__init__(send_data) 370 self._name = name 371 self._relay_packets = True 372 self._window_packet_to_drop = window_packet_to_drop 373 self._next_window_start_offset: int | None = 0 374 self._window_packet = 0 375 376 async def process(self, data: bytes) -> None: 377 data_chunk = None 378 try: 379 chunk = _extract_transfer_chunk(data) 380 if chunk.type is Chunk.Type.DATA: 381 data_chunk = chunk 382 except Exception: 383 # Invalid / non-chunk data (e.g. text logs); ignore. 384 pass 385 386 # Only count transfer data chunks as part of a window. 387 if data_chunk is not None: 388 if data_chunk.offset == self._next_window_start_offset: 389 # If a new window has been requested, wait until the first 390 # chunk matching its requested offset to begin counting window 391 # chunks. Any in-flight chunks from the previous window are 392 # allowed through. 393 self._window_packet = 0 394 self._next_window_start_offset = None 395 396 if self._window_packet != self._window_packet_to_drop: 397 await self.send_data(data) 398 399 self._window_packet += 1 400 else: 401 await self.send_data(data) 402 403 def handle_event(self, event: Event) -> None: 404 if event.type in ( 405 EventType.PARAMETERS_RETRANSMIT, 406 EventType.PARAMETERS_CONTINUE, 407 EventType.START_ACK_CONFIRMATION, 408 ): 409 # A new transmission window has been requested, starting at the 410 # offset specified in the chunk. The receiver may already have data 411 # from the previous window in-flight, so don't immediately reset 412 # the window packet counter. 413 self._next_window_start_offset = event.chunk.offset 414 415 416class EventFilter(Filter): 417 """A filter that inspects packets and send events to other filters. 418 419 This filter should be instantiated in the same filter stack as an 420 HdlcPacketizer so that it can decode complete packets. 421 """ 422 423 def __init__( 424 self, 425 send_data: Callable[[bytes], Awaitable[None]], 426 name: str, 427 event_queue: asyncio.Queue, 428 ): 429 super().__init__(send_data) 430 self._name = name 431 self._queue = event_queue 432 433 async def process(self, data: bytes) -> None: 434 try: 435 chunk = _extract_transfer_chunk(data) 436 if chunk.type is Chunk.Type.START: 437 await self._queue.put(Event(EventType.TRANSFER_START, chunk)) 438 if chunk.type is Chunk.Type.START_ACK_CONFIRMATION: 439 await self._queue.put( 440 Event(EventType.START_ACK_CONFIRMATION, chunk) 441 ) 442 elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT: 443 await self._queue.put( 444 Event(EventType.PARAMETERS_RETRANSMIT, chunk) 445 ) 446 elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE: 447 await self._queue.put( 448 Event(EventType.PARAMETERS_CONTINUE, chunk) 449 ) 450 except: 451 # Silently ignore invalid packets 452 pass 453 454 await self.send_data(data) 455 456 457def _extract_transfer_chunk(data: bytes) -> Chunk: 458 """Gets a transfer Chunk from an HDLC frame containing an RPC packet. 459 460 Raises an exception if a valid chunk does not exist. 461 """ 462 463 decoder = decode.FrameDecoder() 464 for frame in decoder.process(data): 465 packet = packet_pb2.RpcPacket() 466 packet.ParseFromString(frame.data) 467 468 if packet.payload: 469 raw_chunk = transfer_pb2.Chunk() 470 raw_chunk.ParseFromString(packet.payload) 471 return Chunk.from_message(raw_chunk) 472 473 # The incoming data is expected to be HDLC-packetized, so only one 474 # frame should exist. 475 break 476 477 raise ValueError("Invalid transfer chunk frame") 478 479 480async def _handle_simplex_events( 481 event_queue: asyncio.Queue, handlers: list[Callable[[Event], None]] 482): 483 while True: 484 event = await event_queue.get() 485 for handler in handlers: 486 handler(event) 487 488 489async def _handle_simplex_connection( 490 name: str, 491 filter_stack_config: list[config_pb2.FilterConfig], 492 reader: asyncio.StreamReader, 493 writer: asyncio.StreamWriter, 494 inbound_event_queue: asyncio.Queue, 495 outbound_event_queue: asyncio.Queue, 496) -> None: 497 """Handle a single direction of a bidirectional connection between 498 server and client.""" 499 500 async def send(data: bytes): 501 writer.write(data) 502 await writer.drain() 503 504 filter_stack = EventFilter(send, name, outbound_event_queue) 505 506 event_handlers: list[Callable[[Event], None]] = [] 507 508 # Build the filter stack from the bottom up 509 for config in reversed(filter_stack_config): 510 filter_name = config.WhichOneof("filter") 511 if filter_name == "hdlc_packetizer": 512 filter_stack = HdlcPacketizer(filter_stack) 513 elif filter_name == "data_dropper": 514 data_dropper = config.data_dropper 515 filter_stack = DataDropper( 516 filter_stack, name, data_dropper.rate, data_dropper.seed 517 ) 518 elif filter_name == "rate_limiter": 519 filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate) 520 elif filter_name == "data_transposer": 521 transposer = config.data_transposer 522 filter_stack = DataTransposer( 523 filter_stack, 524 name, 525 transposer.rate, 526 transposer.timeout, 527 transposer.seed, 528 ) 529 elif filter_name == "server_failure": 530 server_failure = config.server_failure 531 filter_stack = ServerFailure( 532 filter_stack, 533 name, 534 server_failure.packets_before_failure, 535 server_failure.start_immediately, 536 server_failure.only_consider_transfer_chunks, 537 ) 538 event_handlers.append(filter_stack.handle_event) 539 elif filter_name == "keep_drop_queue": 540 keep_drop_queue = config.keep_drop_queue 541 filter_stack = KeepDropQueue( 542 filter_stack, 543 name, 544 keep_drop_queue.keep_drop_queue, 545 keep_drop_queue.only_consider_transfer_chunks, 546 ) 547 elif filter_name == "window_packet_dropper": 548 window_packet_dropper = config.window_packet_dropper 549 filter_stack = WindowPacketDropper( 550 filter_stack, name, window_packet_dropper.window_packet_to_drop 551 ) 552 event_handlers.append(filter_stack.handle_event) 553 else: 554 sys.exit(f'Unknown filter {filter_name}') 555 556 event_task = asyncio.create_task( 557 _handle_simplex_events(inbound_event_queue, event_handlers) 558 ) 559 560 while True: 561 # Arbitrarily chosen "page sized" read. 562 data = await reader.read(4096) 563 564 # An empty data indicates that the connection is closed. 565 if not data: 566 _LOG.info(f'{name} connection closed.') 567 return 568 569 await filter_stack.process(data) 570 571 572async def _handle_connection( 573 server_port: int, 574 config: config_pb2.ProxyConfig, 575 client_reader: asyncio.StreamReader, 576 client_writer: asyncio.StreamWriter, 577) -> None: 578 """Handle a connection between server and client.""" 579 580 client_addr = client_writer.get_extra_info('peername') 581 _LOG.info(f'New client connection from {client_addr}') 582 583 # Open a new connection to the server for each client connection. 584 # 585 # TODO(konkers): catch exception and close client writer 586 server_reader, server_writer = await asyncio.open_connection( 587 'localhost', server_port 588 ) 589 _LOG.info(f'New connection opened to server') 590 591 # Queues for the simplex connections to pass events to each other. 592 server_event_queue = asyncio.Queue() 593 client_event_queue = asyncio.Queue() 594 595 # Instantiate two simplex handler one for each direction of the connection. 596 _, pending = await asyncio.wait( 597 [ 598 asyncio.create_task( 599 _handle_simplex_connection( 600 "client", 601 config.client_filter_stack, 602 client_reader, 603 server_writer, 604 server_event_queue, 605 client_event_queue, 606 ) 607 ), 608 asyncio.create_task( 609 _handle_simplex_connection( 610 "server", 611 config.server_filter_stack, 612 server_reader, 613 client_writer, 614 client_event_queue, 615 server_event_queue, 616 ) 617 ), 618 ], 619 return_when=asyncio.FIRST_COMPLETED, 620 ) 621 622 # When one side terminates the connection, also terminate the other side 623 for task in pending: 624 task.cancel() 625 626 for stream in [client_writer, server_writer]: 627 stream.close() 628 629 630def _parse_args() -> argparse.Namespace: 631 parser = argparse.ArgumentParser( 632 description=__doc__, 633 formatter_class=argparse.RawDescriptionHelpFormatter, 634 ) 635 636 parser.add_argument( 637 '--server-port', 638 type=int, 639 required=True, 640 help='Port of the integration test server. The proxy will forward connections to this port', 641 ) 642 parser.add_argument( 643 '--client-port', 644 type=int, 645 required=True, 646 help='Port on which to listen for connections from integration test client.', 647 ) 648 649 return parser.parse_args() 650 651 652def _init_logging(level: int) -> None: 653 _LOG.setLevel(logging.DEBUG) 654 log_to_stderr = logging.StreamHandler() 655 log_to_stderr.setLevel(level) 656 log_to_stderr.setFormatter( 657 logging.Formatter( 658 fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s', 659 datefmt='%H:%M:%S', 660 ) 661 ) 662 663 _LOG.addHandler(log_to_stderr) 664 665 666async def _main(server_port: int, client_port: int) -> None: 667 _init_logging(logging.DEBUG) 668 669 # Load config from stdin using synchronous IO 670 text_config = sys.stdin.buffer.read() 671 672 config = text_format.Parse(text_config, config_pb2.ProxyConfig()) 673 674 # Instantiate the TCP server. 675 server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) 676 server_socket.setsockopt( 677 socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE 678 ) 679 server_socket.bind(('', client_port)) 680 server = await asyncio.start_server( 681 lambda reader, writer: _handle_connection( 682 server_port, config, reader, writer 683 ), 684 limit=_RECEIVE_BUFFER_SIZE, 685 sock=server_socket, 686 ) 687 688 addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets) 689 _LOG.info(f'Listening for client connection on {addrs}') 690 691 # Run the TCP server. 692 async with server: 693 await server.serve_forever() 694 695 696if __name__ == '__main__': 697 asyncio.run(_main(**vars(_parse_args()))) 698