#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Test fixture for pw_transfer integration tests."""

import argparse
import asyncio
from dataclasses import dataclass
import logging
import pathlib
from pathlib import Path
import sys
import tempfile
from typing import BinaryIO, Iterable, List, NamedTuple, Optional
import unittest

from google.protobuf import text_format

from pigweed.pw_protobuf.pw_protobuf_protos import status_pb2
from pigweed.pw_transfer.integration_test import config_pb2
from rules_python.python.runfiles import runfiles

_LOG = logging.getLogger('pw_transfer_integration_test_proxy')
_LOG.level = logging.DEBUG
_LOG.addHandler(logging.StreamHandler(sys.stdout))


class LogMonitor:
    """Monitors lines read from the reader, and logs them."""

    class Error(Exception):
        """Raised if wait_for_line reaches EOF before expected line."""

        pass

    def __init__(self, prefix: str, reader: asyncio.StreamReader):
        """Initializer.

        Args:
          prefix: Prepended to read lines before they are logged.
          reader: StreamReader to read lines from.
        """
        self._prefix = prefix
        self._reader = reader

        # Queue of messages waiting to be monitored.
        self._queue = asyncio.Queue()
        # Relog any messages read from the reader, and enqueue them for
        # monitoring.
        self._relog_and_enqueue_task = asyncio.create_task(
            self._relog_and_enqueue()
        )

    async def wait_for_line(self, msg: str):
        """Wait for a line containing msg to be read from the reader."""
        while True:
            line = await self._queue.get()
            if not line:
                raise LogMonitor.Error(
                    f"Reached EOF before getting line matching {msg}"
                )
            if msg in line.decode():
                return

    async def wait_for_eof(self):
        """Wait for the reader to reach EOF, relogging any lines read."""
        # Drain the queue, since we're not monitoring it any more.
        drain_queue = asyncio.create_task(self._drain_queue())
        await asyncio.gather(drain_queue, self._relog_and_enqueue_task)

    async def _relog_and_enqueue(self):
        """Reads lines from the reader, logs them, and puts them in queue."""
        while True:
            line = await self._reader.readline()
            await self._queue.put(line)
            if line:
                _LOG.info(f"{self._prefix} {line.decode().rstrip()}")
            else:
                # EOF. Note, we still put the EOF in the queue, so that the
                # queue reader can process it appropriately.
                return

    async def _drain_queue(self):
        while True:
            line = await self._queue.get()
            if not line:
                # EOF.
                return


class MonitoredSubprocess:
    """A subprocess with monitored asynchronous communication."""

    @staticmethod
    async def create(cmd: List[str], prefix: str, stdinput: bytes):
        """Starts the subprocess and writes stdinput to stdin.

        This method returns once stdinput has been written to stdin. The
        MonitoredSubprocess continues to log the process's stderr and stdout
        (with the prefix) until it terminates.

        Args:
          cmd: Command line to execute.
          prefix: Prepended to process logs.
          stdinput: Written to stdin on process startup.
        """
        self = MonitoredSubprocess()
        self._process = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        self._stderr_monitor = LogMonitor(
            f"{prefix} ERR:", self._process.stderr
        )
        self._stdout_monitor = LogMonitor(
            f"{prefix} OUT:", self._process.stdout
        )

        self._process.stdin.write(stdinput)
        await self._process.stdin.drain()
        self._process.stdin.close()
        await self._process.stdin.wait_closed()
        return self
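
    # Illustrative usage (a sketch; "cat" is an arbitrary example command,
    # not something the tests actually run):
    #
    #   proc = await MonitoredSubprocess.create(["cat"], "CAT", b"hello")
    #   await proc.wait_for_line("stdout", "hello", timeout=5)
    #   await proc.terminate_and_wait(timeout=5)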

    async def wait_for_line(self, stream: str, msg: str, timeout: float):
        """Wait for a line containing msg to be read on the stream."""
        if stream == "stdout":
            monitor = self._stdout_monitor
        elif stream == "stderr":
            monitor = self._stderr_monitor
        else:
            raise ValueError(
                f"Stream must be 'stdout' or 'stderr', got {stream}"
            )

        await asyncio.wait_for(monitor.wait_for_line(msg), timeout)

    def returncode(self):
        return self._process.returncode

    def terminate(self):
        """Terminate the process."""
        self._process.terminate()

    async def wait_for_termination(self, timeout: Optional[float]):
        """Wait for the process to terminate."""
        await asyncio.wait_for(
            asyncio.gather(
                self._process.wait(),
                self._stdout_monitor.wait_for_eof(),
                self._stderr_monitor.wait_for_eof(),
            ),
            timeout,
        )

    async def terminate_and_wait(self, timeout: float):
        """Terminate the process and wait for it to exit."""
        if self.returncode() is not None:
            # Process already terminated.
            return
        self.terminate()
        await self.wait_for_termination(timeout)


class TransferConfig(NamedTuple):
    """A simple tuple to collect configs for test binaries."""

    server: config_pb2.ServerConfig
    client: config_pb2.ClientConfig
    proxy: config_pb2.ProxyConfig


class TransferIntegrationTestHarness:
    """A class to manage transfer integration tests."""

    # Prefix for log messages coming from the harness (as opposed to the
    # server, client, or proxy processes). Padded so that the length is the
    # same as "SERVER OUT:".
    _PREFIX = "HARNESS:   "

    @dataclass
    class Config:
        server_port: int = 3300
        client_port: int = 3301
        java_client_binary: Optional[Path] = None
        cpp_client_binary: Optional[Path] = None
        python_client_binary: Optional[Path] = None
        proxy_binary: Optional[Path] = None
        server_binary: Optional[Path] = None

    class TransferExitCodes(NamedTuple):
        client: int
        server: int

    def __init__(self, harness_config: Config) -> None:
        # TODO(tpudlik): This is Bazel-only. Support gn, too.
        r = runfiles.Create()

        # Set defaults.
        self._JAVA_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/java_client"
        )
        self._CPP_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/cpp_client"
        )
        self._PYTHON_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/python_client"
        )
        self._PROXY_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/proxy"
        )
        self._SERVER_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/server"
        )

        # Server/client ports are non-optional, so use those.
        self._CLIENT_PORT = harness_config.client_port
        self._SERVER_PORT = harness_config.server_port

        # If the harness configuration specifies overrides, use those.
        if harness_config.java_client_binary is not None:
            self._JAVA_CLIENT_BINARY = harness_config.java_client_binary
        if harness_config.cpp_client_binary is not None:
            self._CPP_CLIENT_BINARY = harness_config.cpp_client_binary
        if harness_config.python_client_binary is not None:
            self._PYTHON_CLIENT_BINARY = harness_config.python_client_binary
        if harness_config.proxy_binary is not None:
            self._PROXY_BINARY = harness_config.proxy_binary
        if harness_config.server_binary is not None:
            self._SERVER_BINARY = harness_config.server_binary

        self._CLIENT_BINARY = {
            "cpp": self._CPP_CLIENT_BINARY,
            "java": self._JAVA_CLIENT_BINARY,
            "python": self._PYTHON_CLIENT_BINARY,
        }

    async def _start_client(
        self, client_type: str, config: config_pb2.ClientConfig
    ):
        _LOG.info(f"{self._PREFIX} Starting client with config\n{config}")
        self._client = await MonitoredSubprocess.create(
            [self._CLIENT_BINARY[client_type], str(self._CLIENT_PORT)],
            "CLIENT",
            str(config).encode('ascii'),
        )

    async def _start_server(self, config: config_pb2.ServerConfig):
        _LOG.info(f"{self._PREFIX} Starting server with config\n{config}")
        self._server = await MonitoredSubprocess.create(
            [self._SERVER_BINARY, str(self._SERVER_PORT)],
            "SERVER",
            str(config).encode('ascii'),
        )

    async def _start_proxy(self, config: config_pb2.ProxyConfig):
        _LOG.info(f"{self._PREFIX} Starting proxy with config\n{config}")
        self._proxy = await MonitoredSubprocess.create(
            [
                self._PROXY_BINARY,
                "--server-port",
                str(self._SERVER_PORT),
                "--client-port",
                str(self._CLIENT_PORT),
            ],
            # Extra space in "PROXY " so that it lines up with "SERVER".
            "PROXY ",
            str(config).encode('ascii'),
        )

    async def perform_transfers(
        self,
        server_config: config_pb2.ServerConfig,
        client_type: str,
        client_config: config_pb2.ClientConfig,
        proxy_config: config_pb2.ProxyConfig,
    ) -> TransferExitCodes:
        """Performs the transfers specified in the client config.

        Args:
          server_config: Server configuration.
          client_type: Either "cpp", "java", or "python".
          client_config: Client configuration.
          proxy_config: Proxy configuration.

        Returns:
          Exit codes of the client and server as a tuple.
        """
        # Timeout for components (server, proxy) to come up or shut down
        # after the transfers finish or a signal is sent. Approximately
        # arbitrary; it should not be too long, so that we catch bugs that
        # prevent the server from shutting down.
        TIMEOUT = 5  # seconds

        try:
            await self._start_proxy(proxy_config)
            await self._proxy.wait_for_line(
                "stderr", "Listening for client connection", TIMEOUT
            )

            await self._start_server(server_config)
            await self._server.wait_for_line(
                "stderr", "Starting pw_rpc server on port", TIMEOUT
            )

            await self._start_client(client_type, client_config)
            # No timeout: the client will only exit once the transfer
            # completes, and this can take a long time for large payloads.
            await self._client.wait_for_termination(None)

            # Wait for the server to exit.
            await self._server.wait_for_termination(TIMEOUT)

        finally:
            # Stop the server, if still running. (Only expected if the
            # wait_for above timed out.)
            if self._server:
                await self._server.terminate_and_wait(TIMEOUT)
            # Stop the proxy. Unlike the server, we expect it to still be
            # running at this stage.
            if self._proxy:
                await self._proxy.terminate_and_wait(TIMEOUT)

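            # Note: because this return sits inside the finally block, it
            # suppresses any exception raised in the try block (including
            # timeouts), so callers see exit codes rather than the exception.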
            return self.TransferExitCodes(
                self._client.returncode(), self._server.returncode()
            )


class BasicTransfer(NamedTuple):
    id: int
    type: config_pb2.TransferAction.TransferType.ValueType
    data: bytes
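
# Illustrative example (hypothetical values): a client-to-server write request
# suitable for TransferIntegrationTest.do_basic_transfer_sequence():
#
#   transfer = BasicTransfer(
#       id=42,
#       type=config_pb2.TransferAction.TransferType.WRITE_TO_SERVER,
#       data=b"hello pw_transfer",
#   )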


class TransferIntegrationTest(unittest.TestCase):
    """A base class for transfer integration tests.

    This class does not include any tests itself; it bundles together much of
    the boilerplate required to build integration test cases for pw_transfer
    using this test fixture.
    """
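
    # A minimal illustrative subclass (hypothetical; not part of this module),
    # showing the intended pattern for building tests on this fixture:
    #
    #   class ExampleTransferTest(TransferIntegrationTest):
    #       def test_small_client_write(self):
    #           config = self.default_config()
    #           self.do_single_write("cpp", config, 42, b"example data")
    #
    #   if __name__ == '__main__':
    #       run_tests_for(ExampleTransferTest)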

    HARNESS_CONFIG = TransferIntegrationTestHarness.Config()

    @classmethod
    def setUpClass(cls):
        cls.harness = TransferIntegrationTestHarness(cls.HARNESS_CONFIG)

    @staticmethod
    def default_server_config() -> config_pb2.ServerConfig:
        return config_pb2.ServerConfig(
            chunk_size_bytes=216,
            pending_bytes=32 * 1024,
            chunk_timeout_seconds=5,
            transfer_service_retries=4,
            extend_window_divisor=32,
        )

    @staticmethod
    def default_client_config() -> config_pb2.ClientConfig:
        return config_pb2.ClientConfig(
            max_retries=5,
            max_lifetime_retries=1500,
            initial_chunk_timeout_ms=4000,
            chunk_timeout_ms=4000,
        )

    @staticmethod
    def default_proxy_config() -> config_pb2.ProxyConfig:
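        # A best-effort reading of the defaults below (suggested by the filter
        # names, not verified here): hdlc_packetizer splits the raw stream
        # into HDLC frames, and data_dropper drops packets at the given rate
        # (~1%) with a fixed seed so that failures are reproducible.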
        return text_format.Parse(
            """
                client_filter_stack: [
                    { hdlc_packetizer: {} },
                    { data_dropper: {rate: 0.01, seed: 1649963713563718435} }
                ]

                server_filter_stack: [
                    { hdlc_packetizer: {} },
                    { data_dropper: {rate: 0.01, seed: 1649963713563718436} }
                ]""",
            config_pb2.ProxyConfig(),
        )

    @staticmethod
    def default_config() -> TransferConfig:
        """Returns a new transfer config with default options."""
        return TransferConfig(
            TransferIntegrationTest.default_server_config(),
            TransferIntegrationTest.default_client_config(),
            TransferIntegrationTest.default_proxy_config(),
        )

    def do_single_write(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
    ) -> None:
        """Performs a single client-to-server write of the provided data."""
        with tempfile.NamedTemporaryFile() as f_payload, \
                tempfile.NamedTemporaryFile() as f_server_output:
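            # f_payload holds the data the client will send; f_server_output
            # is the file the server is configured to write received data to.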
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_destination_path = f_server_output.name
            else:
                config.server.resources[resource_id].destination_paths.append(
                    f_server_output.name
                )
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_payload.name,
                    transfer_type=config_pb2.TransferAction.TransferType.WRITE_TO_SERVER,
                    protocol_version=protocol_version,
                    expected_status=int(expected_status),
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )

            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                self.assertEqual(f_server_output.read(), data)

    def do_single_read(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
    ) -> None:
        """Performs a single server-to-client read of the provided data."""
        with tempfile.NamedTemporaryFile() as f_payload, \
                tempfile.NamedTemporaryFile() as f_client_output:
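            # f_payload is the file the server serves to the client;
            # f_client_output is where the client writes the received data.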
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_source_path = f_payload.name
            else:
                config.server.resources[resource_id].source_paths.append(
                    f_payload.name
                )
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_client_output.name,
                    transfer_type=config_pb2.TransferAction.TransferType.READ_FROM_SERVER,
                    protocol_version=protocol_version,
                    expected_status=int(expected_status),
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )
            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                self.assertEqual(f_client_output.read(), data)

    def do_basic_transfer_sequence(
        self,
        client_type: str,
        config: TransferConfig,
        transfers: Iterable[BasicTransfer],
    ) -> None:
        """Performs multiple reads/writes in a single client/server session."""

        class ReadbackSet(NamedTuple):
            server_file: BinaryIO
            client_file: BinaryIO
            expected_data: bytes

        transfer_results: List[ReadbackSet] = []
        for transfer in transfers:
            server_file = tempfile.NamedTemporaryFile()
            client_file = tempfile.NamedTemporaryFile()

            if (
                transfer.type
                == config_pb2.TransferAction.TransferType.READ_FROM_SERVER
            ):
                server_file.write(transfer.data)
                server_file.flush()
                config.server.resources[transfer.id].source_paths.append(
                    server_file.name
                )
            elif (
                transfer.type
                == config_pb2.TransferAction.TransferType.WRITE_TO_SERVER
            ):
                client_file.write(transfer.data)
                client_file.flush()
                config.server.resources[transfer.id].destination_paths.append(
                    server_file.name
                )
            else:
                raise ValueError('Unknown TransferType')

            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=transfer.id,
                    file_path=client_file.name,
                    transfer_type=transfer.type,
                )
            )

            transfer_results.append(
                ReadbackSet(server_file, client_file, transfer.data)
            )

        exit_codes = asyncio.run(
            self.harness.perform_transfers(
                config.server, client_type, config.client, config.proxy
            )
        )

        for i, result in enumerate(transfer_results):
            with self.subTest(i=i):
                # Seek to the beginning of each file to read back the written
                # data.
                result.client_file.seek(0, 0)
                result.server_file.seek(0, 0)
                self.assertEqual(
                    result.client_file.read(), result.expected_data
                )
                self.assertEqual(
                    result.server_file.read(), result.expected_data
                )

        # Check exit codes last, as they provide less useful debugging
        # information than the data comparisons above.
        self.assertEqual(exit_codes.client, 0)
        self.assertEqual(exit_codes.server, 0)


def run_tests_for(test_class):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--server-port',
        type=int,
        help=(
            'Port of the integration test server. The proxy will forward '
            'connections to this port.'
        ),
    )
    parser.add_argument(
        '--client-port',
        type=int,
        help=(
            'Port on which to listen for connections from the integration '
            'test client.'
        ),
    )
    parser.add_argument(
        '--java-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Java transfer client to use in tests.',
    )
    parser.add_argument(
        '--cpp-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the C++ transfer client to use in tests.',
    )
    parser.add_argument(
        '--python-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Python transfer client to use in tests.',
    )
    parser.add_argument(
        '--server-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the transfer server to use in tests.',
    )
    parser.add_argument(
        '--proxy-binary',
        type=pathlib.Path,
        default=None,
        help=(
            'Path to the proxy binary to use in tests to allow interception '
            'of client/server data.'
        ),
    )

    (args, passthrough_args) = parser.parse_known_args()

    # Inherit the default configuration from the class being tested, and only
    # override explicitly provided arguments.
    for arg in vars(args):
        val = getattr(args, arg)
        if val is not None:
            setattr(test_class.HARNESS_CONFIG, arg, val)

    unittest_args = [sys.argv[0]] + passthrough_args
    unittest.main(argv=unittest_args)
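

# Illustrative invocation of a test module built on this fixture (the file
# name and flag values are placeholders; 3300/3301 are the Config defaults):
#
#   python my_transfer_test.py --server-port 3300 --client-port 3301 \
#       --cpp-client-binary /path/to/cpp_client
#
# Unrecognized arguments are passed through to unittest.main().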