• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2022 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests for the transfer service client."""
16
17import enum
18import math
19import unittest
20from typing import Iterable, List
21
22from pw_status import Status
23from pw_rpc import callback_client, client, ids, packets
24from pw_rpc.internal import packet_pb2
25
26import pw_transfer
27from pw_transfer.transfer_pb2 import Chunk
28
# ID of the pw.transfer.Transfer service as computed by ids.calculate(); the
# packet-building helpers below put it in the service_id field of every
# packet they fabricate.
_TRANSFER_SERVICE_ID = ids.calculate('pw.transfer.Transfer')

# Response timeout, in seconds, handed to most pw_transfer.Manager instances
# created by these tests.
# If the default timeout is too short, some tests become flaky on Windows.
DEFAULT_TIMEOUT_S = 0.3
33
34
class _Method(enum.Enum):
    """Transfer service methods; each value is the ids.calculate() result.

    The packet-building helpers in this file use ``.value`` as the
    ``method_id`` of the RPC packets they create.
    """
    READ = ids.calculate('Read')
    WRITE = ids.calculate('Write')
38
39
40class TransferManagerTest(unittest.TestCase):
41    """Tests for the transfer manager."""
42    def setUp(self) -> None:
43        self._client = client.Client.from_modules(
44            callback_client.Impl(), [client.Channel(1, self._handle_request)],
45            (pw_transfer.transfer_pb2, ))
46        self._service = self._client.channel(1).rpcs.pw.transfer.Transfer
47
48        self._sent_chunks: List[Chunk] = []
49        self._packets_to_send: List[List[bytes]] = []
50
51    def _enqueue_server_responses(
52            self, method: _Method,
53            responses: Iterable[Iterable[Chunk]]) -> None:
54        for group in responses:
55            serialized_group = []
56            for response in group:
57                serialized_group.append(
58                    packet_pb2.RpcPacket(
59                        type=packet_pb2.PacketType.SERVER_STREAM,
60                        channel_id=1,
61                        service_id=_TRANSFER_SERVICE_ID,
62                        method_id=method.value,
63                        status=Status.OK.value,
64                        payload=response.SerializeToString()).
65                    SerializeToString())
66            self._packets_to_send.append(serialized_group)
67
68    def _enqueue_server_error(self, method: _Method, error: Status) -> None:
69        self._packets_to_send.append([
70            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
71                                 channel_id=1,
72                                 service_id=_TRANSFER_SERVICE_ID,
73                                 method_id=method.value,
74                                 status=error.value).SerializeToString()
75        ])
76
77    def _handle_request(self, data: bytes) -> None:
78        packet = packets.decode(data)
79        if packet.type is not packet_pb2.PacketType.CLIENT_STREAM:
80            return
81
82        chunk = Chunk()
83        chunk.MergeFromString(packet.payload)
84        self._sent_chunks.append(chunk)
85
86        if self._packets_to_send:
87            responses = self._packets_to_send.pop(0)
88            for response in responses:
89                self._client.process_packet(response)
90
91    def _received_data(self) -> bytearray:
92        data = bytearray()
93        for chunk in self._sent_chunks:
94            data.extend(chunk.data)
95        return data
96
97    def test_read_transfer_basic(self):
98        manager = pw_transfer.Manager(
99            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
100
101        self._enqueue_server_responses(
102            _Method.READ,
103            ((Chunk(transfer_id=3, offset=0, data=b'abc',
104                    remaining_bytes=0), ), ),
105        )
106
107        data = manager.read(3)
108        self.assertEqual(data, b'abc')
109        self.assertEqual(len(self._sent_chunks), 2)
110        self.assertTrue(self._sent_chunks[-1].HasField('status'))
111        self.assertEqual(self._sent_chunks[-1].status, 0)
112
113    def test_read_transfer_multichunk(self) -> None:
114        manager = pw_transfer.Manager(
115            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
116
117        self._enqueue_server_responses(
118            _Method.READ,
119            ((
120                Chunk(transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
121                Chunk(transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
122            ), ),
123        )
124
125        data = manager.read(3)
126        self.assertEqual(data, b'abcdef')
127        self.assertEqual(len(self._sent_chunks), 2)
128        self.assertTrue(self._sent_chunks[-1].HasField('status'))
129        self.assertEqual(self._sent_chunks[-1].status, 0)
130
131    def test_read_transfer_progress_callback(self) -> None:
132        manager = pw_transfer.Manager(
133            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
134
135        self._enqueue_server_responses(
136            _Method.READ,
137            ((
138                Chunk(transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
139                Chunk(transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
140            ), ),
141        )
142
143        progress: List[pw_transfer.ProgressStats] = []
144
145        data = manager.read(3, progress.append)
146        self.assertEqual(data, b'abcdef')
147        self.assertEqual(len(self._sent_chunks), 2)
148        self.assertTrue(self._sent_chunks[-1].HasField('status'))
149        self.assertEqual(self._sent_chunks[-1].status, 0)
150        self.assertEqual(progress, [
151            pw_transfer.ProgressStats(3, 3, 6),
152            pw_transfer.ProgressStats(6, 6, 6),
153        ])
154
155    def test_read_transfer_retry_bad_offset(self) -> None:
156        """Server responds with an unexpected offset in a read transfer."""
157        manager = pw_transfer.Manager(
158            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
159
160        self._enqueue_server_responses(
161            _Method.READ,
162            (
163                (
164                    Chunk(transfer_id=3,
165                          offset=0,
166                          data=b'123',
167                          remaining_bytes=6),
168
169                    # Incorrect offset; expecting 3.
170                    Chunk(transfer_id=3,
171                          offset=1,
172                          data=b'456',
173                          remaining_bytes=3),
174                ),
175                (
176                    Chunk(transfer_id=3,
177                          offset=3,
178                          data=b'456',
179                          remaining_bytes=3),
180                    Chunk(transfer_id=3,
181                          offset=6,
182                          data=b'789',
183                          remaining_bytes=0),
184                ),
185            ))
186
187        data = manager.read(3)
188        self.assertEqual(data, b'123456789')
189
190        # Two transfer parameter requests should have been sent.
191        self.assertEqual(len(self._sent_chunks), 3)
192        self.assertTrue(self._sent_chunks[-1].HasField('status'))
193        self.assertEqual(self._sent_chunks[-1].status, 0)
194
195    def test_read_transfer_retry_timeout(self) -> None:
196        """Server doesn't respond to read transfer parameters."""
197        manager = pw_transfer.Manager(
198            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
199
200        self._enqueue_server_responses(
201            _Method.READ,
202            (
203                (),  # Send nothing in response to the initial parameters.
204                (Chunk(transfer_id=3, offset=0, data=b'xyz',
205                       remaining_bytes=0), ),
206            ))
207
208        data = manager.read(3)
209        self.assertEqual(data, b'xyz')
210
211        # Two transfer parameter requests should have been sent.
212        self.assertEqual(len(self._sent_chunks), 3)
213        self.assertTrue(self._sent_chunks[-1].HasField('status'))
214        self.assertEqual(self._sent_chunks[-1].status, 0)
215
216    def test_read_transfer_timeout(self) -> None:
217        manager = pw_transfer.Manager(
218            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
219
220        with self.assertRaises(pw_transfer.Error) as context:
221            manager.read(27)
222
223        exception = context.exception
224        self.assertEqual(exception.transfer_id, 27)
225        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)
226
227        # The client should have sent four transfer parameters requests: one
228        # initial, and three retries.
229        self.assertEqual(len(self._sent_chunks), 4)
230
231    def test_read_transfer_error(self) -> None:
232        manager = pw_transfer.Manager(
233            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
234
235        self._enqueue_server_responses(
236            _Method.READ,
237            ((Chunk(transfer_id=31, status=Status.NOT_FOUND.value), ), ),
238        )
239
240        with self.assertRaises(pw_transfer.Error) as context:
241            manager.read(31)
242
243        exception = context.exception
244        self.assertEqual(exception.transfer_id, 31)
245        self.assertEqual(exception.status, Status.NOT_FOUND)
246
247    def test_read_transfer_server_error(self) -> None:
248        manager = pw_transfer.Manager(
249            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
250
251        self._enqueue_server_error(_Method.READ, Status.NOT_FOUND)
252
253        with self.assertRaises(pw_transfer.Error) as context:
254            manager.read(31)
255
256        exception = context.exception
257        self.assertEqual(exception.transfer_id, 31)
258        self.assertEqual(exception.status, Status.INTERNAL)
259
260    def test_write_transfer_basic(self) -> None:
261        manager = pw_transfer.Manager(
262            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
263
264        self._enqueue_server_responses(
265            _Method.WRITE,
266            (
267                (Chunk(transfer_id=4,
268                       offset=0,
269                       pending_bytes=32,
270                       max_chunk_size_bytes=8), ),
271                (Chunk(transfer_id=4, status=Status.OK.value), ),
272            ),
273        )
274
275        manager.write(4, b'hello')
276        self.assertEqual(len(self._sent_chunks), 2)
277        self.assertEqual(self._received_data(), b'hello')
278
279    def test_write_transfer_max_chunk_size(self) -> None:
280        manager = pw_transfer.Manager(
281            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
282
283        self._enqueue_server_responses(
284            _Method.WRITE,
285            (
286                (Chunk(transfer_id=4,
287                       offset=0,
288                       pending_bytes=32,
289                       max_chunk_size_bytes=8), ),
290                (),
291                (Chunk(transfer_id=4, status=Status.OK.value), ),
292            ),
293        )
294
295        manager.write(4, b'hello world')
296        self.assertEqual(len(self._sent_chunks), 3)
297        self.assertEqual(self._received_data(), b'hello world')
298        self.assertEqual(self._sent_chunks[1].data, b'hello wo')
299        self.assertEqual(self._sent_chunks[2].data, b'rld')
300
301    def test_write_transfer_multiple_parameters(self) -> None:
302        manager = pw_transfer.Manager(
303            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
304
305        self._enqueue_server_responses(
306            _Method.WRITE,
307            (
308                (Chunk(transfer_id=4,
309                       offset=0,
310                       pending_bytes=8,
311                       max_chunk_size_bytes=8), ),
312                (Chunk(transfer_id=4,
313                       offset=8,
314                       pending_bytes=8,
315                       max_chunk_size_bytes=8), ),
316                (Chunk(transfer_id=4, status=Status.OK.value), ),
317            ),
318        )
319
320        manager.write(4, b'data to write')
321        self.assertEqual(len(self._sent_chunks), 3)
322        self.assertEqual(self._received_data(), b'data to write')
323        self.assertEqual(self._sent_chunks[1].data, b'data to ')
324        self.assertEqual(self._sent_chunks[2].data, b'write')
325
326    def test_write_transfer_progress_callback(self) -> None:
327        manager = pw_transfer.Manager(
328            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
329
330        self._enqueue_server_responses(
331            _Method.WRITE,
332            (
333                (Chunk(transfer_id=4,
334                       offset=0,
335                       pending_bytes=8,
336                       max_chunk_size_bytes=8), ),
337                (Chunk(transfer_id=4,
338                       offset=8,
339                       pending_bytes=8,
340                       max_chunk_size_bytes=8), ),
341                (Chunk(transfer_id=4, status=Status.OK.value), ),
342            ),
343        )
344
345        progress: List[pw_transfer.ProgressStats] = []
346
347        manager.write(4, b'data to write', progress.append)
348        self.assertEqual(len(self._sent_chunks), 3)
349        self.assertEqual(self._received_data(), b'data to write')
350        self.assertEqual(self._sent_chunks[1].data, b'data to ')
351        self.assertEqual(self._sent_chunks[2].data, b'write')
352        self.assertEqual(progress, [
353            pw_transfer.ProgressStats(8, 0, 13),
354            pw_transfer.ProgressStats(13, 8, 13),
355            pw_transfer.ProgressStats(13, 13, 13)
356        ])
357
358    def test_write_transfer_rewind(self) -> None:
359        """Write transfer in which the server re-requests an earlier offset."""
360        manager = pw_transfer.Manager(
361            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
362
363        self._enqueue_server_responses(
364            _Method.WRITE,
365            (
366                (Chunk(transfer_id=4,
367                       offset=0,
368                       pending_bytes=8,
369                       max_chunk_size_bytes=8), ),
370                (Chunk(transfer_id=4,
371                       offset=8,
372                       pending_bytes=8,
373                       max_chunk_size_bytes=8), ),
374                (
375                    Chunk(
376                        transfer_id=4,
377                        offset=4,  # rewind
378                        pending_bytes=8,
379                        max_chunk_size_bytes=8), ),
380                (
381                    Chunk(
382                        transfer_id=4,
383                        offset=12,
384                        pending_bytes=16,  # update max size
385                        max_chunk_size_bytes=16), ),
386                (Chunk(transfer_id=4, status=Status.OK.value), ),
387            ),
388        )
389
390        manager.write(4, b'pigweed data transfer')
391        self.assertEqual(len(self._sent_chunks), 5)
392        self.assertEqual(self._sent_chunks[1].data, b'pigweed ')
393        self.assertEqual(self._sent_chunks[2].data, b'data tra')
394        self.assertEqual(self._sent_chunks[3].data, b'eed data')
395        self.assertEqual(self._sent_chunks[4].data, b' transfer')
396
397    def test_write_transfer_bad_offset(self) -> None:
398        manager = pw_transfer.Manager(
399            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
400
401        self._enqueue_server_responses(
402            _Method.WRITE,
403            (
404                (Chunk(transfer_id=4,
405                       offset=0,
406                       pending_bytes=8,
407                       max_chunk_size_bytes=8), ),
408                (
409                    Chunk(
410                        transfer_id=4,
411                        offset=100,  # larger offset than data
412                        pending_bytes=8,
413                        max_chunk_size_bytes=8), ),
414                (Chunk(transfer_id=4, status=Status.OK.value), ),
415            ),
416        )
417
418        with self.assertRaises(pw_transfer.Error) as context:
419            manager.write(4, b'small data')
420
421        exception = context.exception
422        self.assertEqual(exception.transfer_id, 4)
423        self.assertEqual(exception.status, Status.OUT_OF_RANGE)
424
425    def test_write_transfer_error(self) -> None:
426        manager = pw_transfer.Manager(
427            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
428
429        self._enqueue_server_responses(
430            _Method.WRITE,
431            ((Chunk(transfer_id=21, status=Status.UNAVAILABLE.value), ), ),
432        )
433
434        with self.assertRaises(pw_transfer.Error) as context:
435            manager.write(21, b'no write')
436
437        exception = context.exception
438        self.assertEqual(exception.transfer_id, 21)
439        self.assertEqual(exception.status, Status.UNAVAILABLE)
440
441    def test_write_transfer_server_error(self) -> None:
442        manager = pw_transfer.Manager(
443            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
444
445        self._enqueue_server_error(_Method.WRITE, Status.NOT_FOUND)
446
447        with self.assertRaises(pw_transfer.Error) as context:
448            manager.write(21, b'server error')
449
450        exception = context.exception
451        self.assertEqual(exception.transfer_id, 21)
452        self.assertEqual(exception.status, Status.INTERNAL)
453
454    def test_write_transfer_timeout_after_initial_chunk(self) -> None:
455        manager = pw_transfer.Manager(self._service,
456                                      default_response_timeout_s=0.001,
457                                      max_retries=2)
458
459        with self.assertRaises(pw_transfer.Error) as context:
460            manager.write(22, b'no server response!')
461
462        self.assertEqual(
463            self._sent_chunks,
464            [
465                Chunk(transfer_id=22,
466                      type=Chunk.Type.TRANSFER_START),  # initial chunk
467                Chunk(transfer_id=22,
468                      type=Chunk.Type.TRANSFER_START),  # retry 1
469                Chunk(transfer_id=22,
470                      type=Chunk.Type.TRANSFER_START),  # retry 2
471            ])
472
473        exception = context.exception
474        self.assertEqual(exception.transfer_id, 22)
475        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)
476
477    def test_write_transfer_timeout_after_intermediate_chunk(self) -> None:
478        """Tests write transfers that timeout after the initial chunk."""
479        manager = pw_transfer.Manager(
480            self._service,
481            default_response_timeout_s=DEFAULT_TIMEOUT_S,
482            max_retries=2)
483
484        self._enqueue_server_responses(
485            _Method.WRITE,
486            [[Chunk(transfer_id=22, pending_bytes=10, max_chunk_size_bytes=5)]
487             ])
488
489        with self.assertRaises(pw_transfer.Error) as context:
490            manager.write(22, b'0123456789')
491
492        last_data_chunk = Chunk(transfer_id=22,
493                                data=b'56789',
494                                offset=5,
495                                remaining_bytes=0,
496                                type=Chunk.Type.TRANSFER_DATA)
497
498        self.assertEqual(
499            self._sent_chunks,
500            [
501                Chunk(transfer_id=22, type=Chunk.Type.TRANSFER_START),
502                Chunk(transfer_id=22,
503                      data=b'01234',
504                      type=Chunk.Type.TRANSFER_DATA),
505                last_data_chunk,  # last chunk
506                last_data_chunk,  # retry 1
507                last_data_chunk,  # retry 2
508            ])
509
510        exception = context.exception
511        self.assertEqual(exception.transfer_id, 22)
512        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)
513
514    def test_write_zero_pending_bytes_is_internal_error(self) -> None:
515        manager = pw_transfer.Manager(
516            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)
517
518        self._enqueue_server_responses(
519            _Method.WRITE,
520            ((Chunk(transfer_id=23, pending_bytes=0), ), ),
521        )
522
523        with self.assertRaises(pw_transfer.Error) as context:
524            manager.write(23, b'no write')
525
526        exception = context.exception
527        self.assertEqual(exception.transfer_id, 23)
528        self.assertEqual(exception.status, Status.INTERNAL)
529
530
class ProgressStatsTest(unittest.TestCase):
    """Tests for pw_transfer.ProgressStats."""
    def test_received_percent_known_total(self) -> None:
        """percent_received() gives exact percentages for a total of 100."""
        cases = (
            (pw_transfer.ProgressStats(75, 0, 100), 0.0),
            (pw_transfer.ProgressStats(75, 50, 100), 50.0),
            (pw_transfer.ProgressStats(100, 100, 100), 100.0),
        )
        for stats, expected_percent in cases:
            self.assertEqual(stats.percent_received(), expected_percent)

    def test_received_percent_unknown_total(self) -> None:
        """percent_received() is NaN when the third argument is None."""
        cases = (
            pw_transfer.ProgressStats(75, 50, None),
            pw_transfer.ProgressStats(100, 100, None),
        )
        for stats in cases:
            self.assertTrue(math.isnan(stats.percent_received()))

    def test_str_known_total(self) -> None:
        """str() of the stats mentions all three numbers."""
        text = str(pw_transfer.ProgressStats(75, 50, 100))
        for expected in ('75', '50', '100'):
            self.assertIn(expected, text)

    def test_str_unknown_total(self) -> None:
        """str() of the stats says 'unknown' when the third argument is None."""
        text = str(pw_transfer.ProgressStats(75, 50, None))
        for expected in ('75', '50', 'unknown'):
            self.assertIn(expected, text)
559
560
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
563