#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Test fixture for pw_transfer integration tests."""

import argparse
import asyncio
import contextlib
from dataclasses import dataclass
import logging
import pathlib
from pathlib import Path
import sys
import tempfile
from typing import BinaryIO, Iterable, List, NamedTuple, Optional
import unittest

from google.protobuf import text_format

from pigweed.pw_protobuf.pw_protobuf_protos import status_pb2
from pigweed.pw_transfer.integration_test import config_pb2
from rules_python.python.runfiles import runfiles

# NOTE(review): "intergration" is misspelled, but the logger name is a runtime
# identifier that other logging configuration may refer to, so it is left
# unchanged here.
_LOG = logging.getLogger('pw_transfer_intergration_test_proxy')
_LOG.setLevel(logging.DEBUG)
_LOG.addHandler(logging.StreamHandler(sys.stdout))


class LogMonitor:
    """Monitors lines read from the reader, and logs them."""

    class Error(Exception):
        """Raised if wait_for_line reaches EOF before expected line."""

    def __init__(self, prefix: str, reader: asyncio.StreamReader):
        """Initializer.

        Must be called while an event loop is running, because it schedules
        the background relogging task with asyncio.create_task.

        Args:
          prefix: Prepended to read lines before they are logged.
          reader: StreamReader to read lines from.
        """
        self._prefix = prefix
        self._reader = reader

        # Queue of messages waiting to be monitored.
        self._queue: asyncio.Queue = asyncio.Queue()

        # Relog any messages read from the reader, and enqueue them for
        # monitoring.
        self._relog_and_enqueue_task = asyncio.create_task(
            self._relog_and_enqueue()
        )

    async def wait_for_line(self, msg: str) -> None:
        """Wait for a line containing msg to be read from the reader.

        Args:
          msg: Substring to look for in each line.

        Raises:
          LogMonitor.Error: EOF was reached before a matching line was seen.
        """
        while True:
            line = await self._queue.get()
            if not line:
                raise LogMonitor.Error(
                    f"Reached EOF before getting line matching {msg}"
                )
            # Subprocess output is not guaranteed to be valid UTF-8; never let
            # a stray byte abort the search.
            if msg in line.decode(errors="replace"):
                return

    async def wait_for_eof(self) -> None:
        """Wait for the reader to reach EOF, relogging any lines read."""
        # Drain the queue, since we're not monitoring it any more.
        drain_queue = asyncio.create_task(self._drain_queue())
        await asyncio.gather(drain_queue, self._relog_and_enqueue_task)

    async def _relog_and_enqueue(self) -> None:
        """Reads lines from the reader, logs them, and puts them in queue."""
        while True:
            line = await self._reader.readline()
            await self._queue.put(line)
            if not line:
                # EOF. Note, we still put the EOF in the queue, so that the
                # queue reader can process it appropriately.
                return
            # errors="replace" keeps this task alive on undecodable output;
            # if it died, EOF would never be enqueued and waiters would hang.
            _LOG.info(
                "%s %s",
                self._prefix,
                line.decode(errors="replace").rstrip(),
            )

    async def _drain_queue(self) -> None:
        """Discards queued lines until the EOF marker is seen."""
        while True:
            line = await self._queue.get()
            if not line:
                # EOF.
                return


class MonitoredSubprocess:
    """A subprocess with monitored asynchronous communication."""

    @staticmethod
    async def create(
        cmd: List[str], prefix: str, stdinput: bytes
    ) -> "MonitoredSubprocess":
        """Starts the subprocess and writes stdinput to stdin.

        This method returns once stdinput has been written to stdin. The
        MonitoredSubprocess continues to log the process's stderr and stdout
        (with the prefix) until it terminates.

        Args:
          cmd: Command line to execute.
          prefix: Prepended to process logs.
          stdinput: Written to stdin on process startup.

        Returns:
          The MonitoredSubprocess wrapping the started process.
        """
        monitored = MonitoredSubprocess()
        monitored._process = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        monitored._stderr_monitor = LogMonitor(
            f"{prefix} ERR:", monitored._process.stderr
        )
        monitored._stdout_monitor = LogMonitor(
            f"{prefix} OUT:", monitored._process.stdout
        )

        # Write the whole input, then close stdin so the child sees EOF.
        monitored._process.stdin.write(stdinput)
        await monitored._process.stdin.drain()
        monitored._process.stdin.close()
        await monitored._process.stdin.wait_closed()
        return monitored

    async def wait_for_line(
        self, stream: str, msg: str, timeout: Optional[float]
    ) -> None:
        """Wait for a line containing msg to be read on the stream.

        Args:
          stream: Either "stdout" or "stderr".
          msg: Substring to look for.
          timeout: Seconds to wait, or None to wait indefinitely.

        Raises:
          ValueError: stream is neither "stdout" nor "stderr".
          asyncio.TimeoutError: No matching line arrived within timeout.
          LogMonitor.Error: The stream hit EOF before a matching line.
        """
        if stream == "stdout":
            monitor = self._stdout_monitor
        elif stream == "stderr":
            monitor = self._stderr_monitor
        else:
            raise ValueError(
                f"Stream must be 'stdout' or 'stderr', got {stream}"
            )

        await asyncio.wait_for(monitor.wait_for_line(msg), timeout)

    def returncode(self) -> Optional[int]:
        """Returns the process's exit code, or None if it is still running."""
        return self._process.returncode

    def terminate(self) -> None:
        """Terminate the process."""
        self._process.terminate()

    async def wait_for_termination(self, timeout: Optional[float]) -> None:
        """Wait for the process to terminate.

        Also waits for both output streams to reach EOF so that all output is
        relogged before returning.

        Args:
          timeout: Seconds to wait, or None to wait indefinitely.
        """
        await asyncio.wait_for(
            asyncio.gather(
                self._process.wait(),
                self._stdout_monitor.wait_for_eof(),
                self._stderr_monitor.wait_for_eof(),
            ),
            timeout,
        )

    async def terminate_and_wait(self, timeout: Optional[float]) -> None:
        """Terminate the process and wait for it to exit."""
        if self.returncode() is not None:
            # Process already terminated
            return
        self.terminate()
        await self.wait_for_termination(timeout)


class TransferConfig(NamedTuple):
    """A simple tuple to collect configs for test binaries."""

    server: config_pb2.ServerConfig
    client: config_pb2.ClientConfig
    proxy: config_pb2.ProxyConfig


class TransferIntegrationTestHarness:
    """A class to manage transfer integration tests"""

    # Prefix for log messages coming from the harness (as opposed to the server,
    # client, or proxy processes). Padded so that the length is the same as
    # "SERVER OUT:".
    _PREFIX = "HARNESS:".ljust(len("SERVER OUT:"))

    @dataclass
    class Config:
        """Ports and optional binary overrides for the harness."""

        server_port: int = 3300
        client_port: int = 3301
        java_client_binary: Optional[Path] = None
        cpp_client_binary: Optional[Path] = None
        python_client_binary: Optional[Path] = None
        proxy_binary: Optional[Path] = None
        server_binary: Optional[Path] = None

    class TransferExitCodes(NamedTuple):
        """Exit codes of the client and server processes."""

        client: int
        server: int

    def __init__(self, harness_config: Config) -> None:
        """Resolves binary locations and ports from harness_config.

        Args:
          harness_config: Ports to use, plus optional paths that override the
            runfiles-provided default binaries.
        """
        # TODO(tpudlik): This is Bazel-only. Support gn, too.
        r = runfiles.Create()

        # Set defaults.
        self._JAVA_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/java_client"
        )
        self._CPP_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/cpp_client"
        )
        self._PYTHON_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/python_client"
        )
        self._PROXY_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/proxy"
        )
        self._SERVER_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/server"
        )

        # Server/client ports are non-optional, so use those.
        self._CLIENT_PORT = harness_config.client_port
        self._SERVER_PORT = harness_config.server_port

        # If the harness configuration specifies overrides, use those.
        if harness_config.java_client_binary is not None:
            self._JAVA_CLIENT_BINARY = harness_config.java_client_binary
        if harness_config.cpp_client_binary is not None:
            self._CPP_CLIENT_BINARY = harness_config.cpp_client_binary
        if harness_config.python_client_binary is not None:
            self._PYTHON_CLIENT_BINARY = harness_config.python_client_binary
        if harness_config.proxy_binary is not None:
            self._PROXY_BINARY = harness_config.proxy_binary
        if harness_config.server_binary is not None:
            self._SERVER_BINARY = harness_config.server_binary

        self._CLIENT_BINARY = {
            "cpp": self._CPP_CLIENT_BINARY,
            "java": self._JAVA_CLIENT_BINARY,
            "python": self._PYTHON_CLIENT_BINARY,
        }

        # Process handles. They stay None until the matching _start_* call
        # succeeds, so cleanup code can always inspect them safely.
        self._client: Optional[MonitoredSubprocess] = None
        self._server: Optional[MonitoredSubprocess] = None
        self._proxy: Optional[MonitoredSubprocess] = None

    async def _start_client(
        self, client_type: str, config: config_pb2.ClientConfig
    ) -> None:
        """Launches the client of the given type, feeding config on stdin."""
        _LOG.info(
            "%s Starting client with config\n%s", self._PREFIX, config
        )
        self._client = await MonitoredSubprocess.create(
            [self._CLIENT_BINARY[client_type], str(self._CLIENT_PORT)],
            "CLIENT",
            str(config).encode('ascii'),
        )

    async def _start_server(self, config: config_pb2.ServerConfig) -> None:
        """Launches the server, feeding config on stdin."""
        _LOG.info(
            "%s Starting server with config\n%s", self._PREFIX, config
        )
        self._server = await MonitoredSubprocess.create(
            [self._SERVER_BINARY, str(self._SERVER_PORT)],
            "SERVER",
            str(config).encode('ascii'),
        )

    async def _start_proxy(self, config: config_pb2.ProxyConfig) -> None:
        """Launches the proxy between the client and server ports."""
        _LOG.info("%s Starting proxy with config\n%s", self._PREFIX, config)
        self._proxy = await MonitoredSubprocess.create(
            [
                self._PROXY_BINARY,
                "--server-port",
                str(self._SERVER_PORT),
                "--client-port",
                str(self._CLIENT_PORT),
            ],
            # Extra space in "PROXY " so that it lines up with "SERVER".
            "PROXY ",
            str(config).encode('ascii'),
        )

    async def perform_transfers(
        self,
        server_config: config_pb2.ServerConfig,
        client_type: str,
        client_config: config_pb2.ClientConfig,
        proxy_config: config_pb2.ProxyConfig,
    ) -> TransferExitCodes:
        """Runs the proxy, server and client and performs the transfers.

        Args:
          server_config: Server configuration.
          client_type: Either "cpp", "java", or "python".
          client_config: Client configuration.
          proxy_config: Proxy configuration.

        Returns:
          Exit code of the client and server as a tuple.
        """
        # Timeout for components (server, proxy) to come up or shut down after
        # write is finished or a signal is sent. Approximately arbitrary. Should
        # not be too long so that we catch bugs in the server that prevent it
        # from shutting down.
        TIMEOUT = 5  # seconds

        # Forget handles from any previous run so that the cleanup below only
        # ever touches processes started by this call. Without this, a failure
        # before a process was started would raise AttributeError in the
        # finally block and hide the real error.
        self._client = None
        self._server = None
        self._proxy = None

        try:
            await self._start_proxy(proxy_config)
            await self._proxy.wait_for_line(
                "stderr", "Listening for client connection", TIMEOUT
            )

            await self._start_server(server_config)
            await self._server.wait_for_line(
                "stderr", "Starting pw_rpc server on port", TIMEOUT
            )

            await self._start_client(client_type, client_config)
            # No timeout: the client will only exit once the transfer
            # completes, and this can take a long time for large payloads.
            await self._client.wait_for_termination(None)

            # Wait for the server to exit.
            await self._server.wait_for_termination(TIMEOUT)

        finally:
            # Stop the client, if still running. (Only expected if waiting on
            # it was interrupted by an exception.)
            if self._client is not None:
                await self._client.terminate_and_wait(TIMEOUT)

            # Stop the server, if still running. (Only expected if the
            # wait_for above timed out.)
            if self._server is not None:
                await self._server.terminate_and_wait(TIMEOUT)

            # Stop the proxy. Unlike the server, we expect it to still be
            # running at this stage.
            if self._proxy is not None:
                await self._proxy.terminate_and_wait(TIMEOUT)

        return self.TransferExitCodes(
            self._client.returncode(), self._server.returncode()
        )


class BasicTransfer(NamedTuple):
    """One transfer in a sequence: resource ID, direction, and payload."""

    id: int
    type: config_pb2.TransferAction.TransferType.ValueType
    data: bytes


class TransferIntegrationTest(unittest.TestCase):
    """A base class for transfer integration tests.

    This significantly reduces the boiler plate required for building
    integration test cases for pw_transfer. This class does not include any
    tests itself, but instead bundles together much of the boiler plate required
    for making an integration test for pw_transfer using this test fixture.
    """

    HARNESS_CONFIG = TransferIntegrationTestHarness.Config()

    @classmethod
    def setUpClass(cls):
        """Creates the harness shared by every test in the class."""
        cls.harness = TransferIntegrationTestHarness(cls.HARNESS_CONFIG)

    @staticmethod
    def default_server_config() -> config_pb2.ServerConfig:
        """Returns a new server config with default options."""
        return config_pb2.ServerConfig(
            chunk_size_bytes=216,
            pending_bytes=32 * 1024,
            chunk_timeout_seconds=5,
            transfer_service_retries=4,
            extend_window_divisor=32,
        )

    @staticmethod
    def default_client_config() -> config_pb2.ClientConfig:
        """Returns a new client config with default options."""
        return config_pb2.ClientConfig(
            max_retries=5,
            max_lifetime_retries=1500,
            initial_chunk_timeout_ms=4000,
            chunk_timeout_ms=4000,
        )

    @staticmethod
    def default_proxy_config() -> config_pb2.ProxyConfig:
        """Returns a new proxy config with default options.

        Both directions use an HDLC packetizer followed by a data dropper with
        a rate of 0.01 and a fixed seed.
        """
        return text_format.Parse(
            " client_filter_stack: ["
            " { hdlc_packetizer: {} },"
            " { data_dropper: {rate: 0.01, seed: 1649963713563718435} }"
            " ]"
            " server_filter_stack: ["
            " { hdlc_packetizer: {} },"
            " { data_dropper: {rate: 0.01, seed: 1649963713563718436} }"
            " ]",
            config_pb2.ProxyConfig(),
        )

    @staticmethod
    def default_config() -> TransferConfig:
        """Returns a new transfer config with default options."""
        return TransferConfig(
            TransferIntegrationTest.default_server_config(),
            TransferIntegrationTest.default_client_config(),
            TransferIntegrationTest.default_proxy_config(),
        )

    def do_single_write(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
    ) -> None:
        """Performs a single client-to-server write of the provided data.

        Note that config is modified in place: the resource and the transfer
        action for this write are added to it.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Server, client and proxy configuration to extend and run.
          resource_id: ID of the resource to write.
          data: Payload the client sends.
          protocol_version: Transfer protocol version the client uses.
          permanent_resource_id: If True, register the output file as the
            resource's default destination path instead of appending it to
            the list of destination paths.
          expected_status: Status the client is told to expect.
        """
        with tempfile.NamedTemporaryFile() as f_payload, \
                tempfile.NamedTemporaryFile() as f_server_output:
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_destination_path = f_server_output.name
            else:
                config.server.resources[resource_id].destination_paths.append(
                    f_server_output.name
                )
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_payload.name,
                    transfer_type=(
                        config_pb2.TransferAction.TransferType.WRITE_TO_SERVER
                    ),
                    protocol_version=protocol_version,
                    expected_status=int(expected_status),
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )

            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                self.assertEqual(f_server_output.read(), data)

    def do_single_read(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
    ) -> None:
        """Performs a single server-to-client read of the provided data.

        Note that config is modified in place: the resource and the transfer
        action for this read are added to it.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Server, client and proxy configuration to extend and run.
          resource_id: ID of the resource to read.
          data: Payload the server serves.
          protocol_version: Transfer protocol version the client uses.
          permanent_resource_id: If True, register the payload file as the
            resource's default source path instead of appending it to the
            list of source paths.
          expected_status: Status the client is told to expect.
        """
        with tempfile.NamedTemporaryFile() as f_payload, \
                tempfile.NamedTemporaryFile() as f_client_output:
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_source_path = f_payload.name
            else:
                config.server.resources[resource_id].source_paths.append(
                    f_payload.name
                )
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_client_output.name,
                    transfer_type=(
                        config_pb2.TransferAction.TransferType.READ_FROM_SERVER
                    ),
                    protocol_version=protocol_version,
                    expected_status=int(expected_status),
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )
            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                self.assertEqual(f_client_output.read(), data)

    def do_basic_transfer_sequence(
        self,
        client_type: str,
        config: TransferConfig,
        transfers: Iterable[BasicTransfer],
    ) -> None:
        """Performs multiple reads/writes in a single client/server session.

        Note that config is modified in place: one resource path and one
        transfer action are added per transfer.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Server, client and proxy configuration to extend and run.
          transfers: The transfers to perform, in order.

        Raises:
          ValueError: A transfer has a type other than READ_FROM_SERVER or
            WRITE_TO_SERVER.
        """

        class ReadbackSet(NamedTuple):
            server_file: BinaryIO
            client_file: BinaryIO
            expected_data: bytes

        transfer_results: List[ReadbackSet] = []
        transfer_type = config_pb2.TransferAction.TransferType

        # The ExitStack closes (and thereby deletes) every temporary file once
        # the readback is done, including when a transfer or assertion fails.
        with contextlib.ExitStack() as temp_files:
            for transfer in transfers:
                server_file = temp_files.enter_context(
                    tempfile.NamedTemporaryFile()
                )
                client_file = temp_files.enter_context(
                    tempfile.NamedTemporaryFile()
                )

                if transfer.type == transfer_type.READ_FROM_SERVER:
                    server_file.write(transfer.data)
                    server_file.flush()
                    config.server.resources[transfer.id].source_paths.append(
                        server_file.name
                    )
                elif transfer.type == transfer_type.WRITE_TO_SERVER:
                    client_file.write(transfer.data)
                    client_file.flush()
                    config.server.resources[
                        transfer.id
                    ].destination_paths.append(server_file.name)
                else:
                    raise ValueError('Unknown TransferType')

                config.client.transfer_actions.append(
                    config_pb2.TransferAction(
                        resource_id=transfer.id,
                        file_path=client_file.name,
                        transfer_type=transfer.type,
                    )
                )

                transfer_results.append(
                    ReadbackSet(server_file, client_file, transfer.data)
                )

            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )

            for i, result in enumerate(transfer_results):
                with self.subTest(i=i):
                    # Need to seek to the beginning of the file to read written
                    # data.
                    result.client_file.seek(0, 0)
                    result.server_file.seek(0, 0)
                    self.assertEqual(
                        result.client_file.read(), result.expected_data
                    )
                    self.assertEqual(
                        result.server_file.read(), result.expected_data
                    )

        # Check exit codes at the end as they provide less useful info.
        self.assertEqual(exit_codes.client, 0)
        self.assertEqual(exit_codes.server, 0)


def run_tests_for(test_class_name):
    """Parses harness flags, applies them, and runs unittest.

    Recognized flags override the matching fields of
    test_class_name.HARNESS_CONFIG; every other command-line argument is
    passed through to unittest.main.

    Args:
      test_class_name: The test class (not its name) whose HARNESS_CONFIG
        should receive the overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--server-port',
        type=int,
        help=(
            'Port of the integration test server. '
            'The proxy will forward connections to this port'
        ),
    )
    parser.add_argument(
        '--client-port',
        type=int,
        help=(
            'Port on which to listen for connections from integration test '
            'client.'
        ),
    )
    parser.add_argument(
        '--java-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Java transfer client to use in tests',
    )
    parser.add_argument(
        '--cpp-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the C++ transfer client to use in tests',
    )
    parser.add_argument(
        '--python-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Python transfer client to use in tests',
    )
    parser.add_argument(
        '--server-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the transfer server to use in tests',
    )
    parser.add_argument(
        '--proxy-binary',
        type=pathlib.Path,
        default=None,
        help=(
            'Path to the proxy binary to use in tests to allow interception '
            'of client/server data'
        ),
    )

    (args, passthrough_args) = parser.parse_known_args()

    # Inherit the default configuration from the class being tested, and only
    # override provided arguments. Unprovided arguments are None; compare
    # against None rather than truthiness so that falsy values are honored.
    for arg, val in vars(args).items():
        if val is not None:
            setattr(test_class_name.HARNESS_CONFIG, arg, val)

    unittest_args = [sys.argv[0]] + passthrough_args
    unittest.main(argv=unittest_args)