• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests transfers between the Python client and C++ service."""
16
17from pathlib import Path
18import random
19import tempfile
20from typing import List, Tuple, Union
21import unittest
22
23from pw_hdlc import rpc
24from pw_rpc import testing
25from pw_status import Status
26import pw_transfer
27from pw_transfer import transfer_pb2
28from pw_transfer_test import test_server_pb2
29
30ITERATIONS = 5
31TIMEOUT_S = 0.05
32
33_DATA_4096B = b'SPAM' * (4096 // len('SPAM'))
34
35
class TransferServiceIntegrationTest(unittest.TestCase):
    """Tests transfers between the Python client and C++ service.

    Every test gets its own temporary directory and a fresh
    ``rpc.HdlcRpcLocalServerAndClient``. That client is built from
    ``test_server_command`` (with the directory appended as the final
    argument) and ``port``. Each transfer's data lives in a file named after
    its transfer ID inside that directory (see ``transfer_file_path``).
    """
    # Command used to start the test server; assigned by _main.
    test_server_command: Tuple[str, ...] = ()
    # Port handed to HdlcRpcLocalServerAndClient; assigned by _main.
    port: int

    def setUp(self) -> None:
        # Resources are released via addCleanup instead of tearDown.
        # unittest skips tearDown when setUp raises, but it always runs
        # registered cleanups, in LIFO order. So the temporary directory is
        # removed even if the server context fails to start, and the server
        # context is closed before its working directory is deleted.
        self._tempdir = tempfile.TemporaryDirectory(
            prefix=f'pw_transfer_{self.id().rsplit(".", 1)[-1]}_')
        self.addCleanup(self._tempdir.cleanup)
        self.directory = Path(self._tempdir.name)

        command = (*self.test_server_command, str(self.directory))
        # Packet filters let individual tests drop outgoing/incoming packets.
        self._outgoing_filter = rpc.PacketFilter('outgoing RPC')
        self._incoming_filter = rpc.PacketFilter('incoming RPC')
        self._context = rpc.HdlcRpcLocalServerAndClient(
            command,
            self.port, [transfer_pb2, test_server_pb2],
            outgoing_processor=self._outgoing_filter,
            incoming_processor=self._incoming_filter)
        self.addCleanup(self._context.close)

        service = self._context.client.channel(1).rpcs.pw.transfer.Transfer
        self.manager = pw_transfer.Manager(
            service, default_response_timeout_s=TIMEOUT_S)

        self._test_server = self._context.client.channel(
            1).rpcs.pw.transfer.TestServer

    def transfer_file_path(self, transfer_id: int) -> Path:
        """Returns the path of the file backing ``transfer_id``."""
        return self.directory / str(transfer_id)

    def set_content(self, transfer_id: int, data: Union[bytes, str]) -> None:
        """Writes ``data`` to the file for ``transfer_id``.

        ``str`` data is encoded with ``str.encode()`` (UTF-8) before writing.
        Afterwards the server's ``ReloadTransferFiles`` RPC is invoked.
        """
        self.transfer_file_path(transfer_id).write_bytes(
            data.encode() if isinstance(data, str) else data)
        # Presumably makes the server re-read its files so the new contents
        # are visible to the next transfer.
        self._test_server.ReloadTransferFiles()

    def get_content(self, transfer_id: int) -> bytes:
        """Returns the current bytes stored in the file for ``transfer_id``."""
        return self.transfer_file_path(transfer_id).read_bytes()

    def test_read_unknown_id(self) -> None:
        with self.assertRaises(pw_transfer.Error) as ctx:
            self.manager.read(99)
        self.assertEqual(ctx.exception.status, Status.NOT_FOUND)

    def test_read_empty(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(24, '')
            self.assertEqual(self.manager.read(24), b'')

    def test_read_single_byte(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(25, '0')
            self.assertEqual(self.manager.read(25), b'0')

    def test_read_small_amount_of_data(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(26, 'hunter2')
            self.assertEqual(self.manager.read(26), b'hunter2')

    def test_read_large_amount_of_data(self) -> None:
        size = 2**13  # TODO(hepler): Increase to 2**14 when it passes.
        for _ in range(ITERATIONS):
            self.set_content(27, '~' * size)
            self.assertEqual(self.manager.read(27), b'~' * size)

    def test_write_unknown_id(self) -> None:
        with self.assertRaises(pw_transfer.Error) as ctx:
            self.manager.write(99, '')
        self.assertEqual(ctx.exception.status, Status.NOT_FOUND)

    def test_write_empty(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(28, 'junk')
            self.manager.write(28, b'')
            self.assertEqual(self.get_content(28), b'')

    def test_write_single_byte(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(29, 'junk')
            self.manager.write(29, b'$')
            self.assertEqual(self.get_content(29), b'$')

    def test_write_small_amount_of_data(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(30, 'junk')
            self.manager.write(30, b'file transfer')
            self.assertEqual(self.get_content(30), b'file transfer')

    def test_write_large_amount_of_data(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(31, 'junk')
            self.manager.write(31, b'*' * 512)
            self.assertEqual(self.get_content(31), b'*' * 512)

    def test_write_very_large_amount_of_data(self) -> None:
        for _ in range(ITERATIONS):
            self.set_content(32, 'junk')

            # Larger than the transfer service's configured pending_bytes.
            self.manager.write(32, _DATA_4096B)
            self.assertEqual(self.get_content(32), _DATA_4096B)

    def test_write_string(self) -> None:
        for _ in range(ITERATIONS):
            # Write a string instead of bytes.
            self.set_content(33, 'junk')
            self.manager.write(33, 'hello world')
            self.assertEqual(self.get_content(33), b'hello world')

    def test_write_drop_data_chunks_and_transfer_parameters(self) -> None:
        self.set_content(34, 'junk')

        # Allow the initial packet and first chunk, then drop the second chunk.
        self._outgoing_filter.keep(2)
        self._outgoing_filter.drop(1)

        # Allow the initial transfer parameters update, then drop the next two.
        self._incoming_filter.keep(1)
        self._incoming_filter.drop(2)

        with self.assertLogs('pw_transfer', 'DEBUG') as logs:
            self.manager.write(34, _DATA_4096B)

        self.assertEqual(self.get_content(34), _DATA_4096B)

        # Verify that the client retried twice. The '/3' presumably mirrors
        # the manager's default retry limit.
        messages = [r.getMessage() for r in logs.records]
        retry = f'Received no responses for {TIMEOUT_S:.3f}s; retrying {{}}/3'
        self.assertIn(retry.format(1), messages)
        self.assertIn(retry.format(2), messages)

    def test_write_regularly_drop_packets(self) -> None:
        self.set_content(35, 'junk')

        self._outgoing_filter.drop_every(5)  # drop one per window
        self._incoming_filter.drop_every(3)

        self.manager.write(35, _DATA_4096B)

        self.assertEqual(self.get_content(35), _DATA_4096B)

    def test_write_randomly_drop_packets(self) -> None:
        # Allow lots of retries since there are lots of drops.
        self.manager.max_retries = 9

        for seed in [1, 5678, 600613]:
            self.set_content(seed, 'junk')

            # Seeded RNG drives both the drop pattern and the payload, so
            # each iteration is reproducible.
            rand = random.Random(seed)
            self._incoming_filter.randomly_drop(3, rand)
            self._outgoing_filter.randomly_drop(3, rand)

            data = bytes(
                rand.randrange(256) for _ in range(rand.randrange(16384)))
            self.manager.write(seed, data)
            self.assertEqual(self.get_content(seed), data)

            self._incoming_filter.reset()
            self._outgoing_filter.reset()
201
def _main(test_server_command: List[str], port: int,
          unittest_args: List[str]) -> None:
    """Configures the integration test class, then runs unittest.

    Args:
      test_server_command: Command (as a list of arguments) for starting the
          test server; stored on the test class as a tuple.
      port: Port stored on the test class for the RPC client to use.
      unittest_args: argv forwarded to unittest.main.
    """
    test_class = TransferServiceIntegrationTest
    test_class.test_server_command = tuple(test_server_command)
    test_class.port = port

    unittest.main(argv=unittest_args)
209
210
if __name__ == '__main__':
    # The parsed namespace's attributes are forwarded as keyword arguments,
    # so they must match _main's parameter names.
    _parsed_args = testing.parse_test_server_args()
    _main(**vars(_parsed_args))
213