#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the transfer service client."""

import enum
import math
import unittest
from typing import Iterable, List

from pw_status import Status
from pw_rpc import callback_client, client, ids, packets
from pw_rpc.internal import packet_pb2

import pw_transfer
from pw_transfer.transfer_pb2 import Chunk

_TRANSFER_SERVICE_ID = ids.calculate('pw.transfer.Transfer')

# If the default timeout is too short, some tests become flaky on Windows.
DEFAULT_TIMEOUT_S = 0.3


class _Method(enum.Enum):
    """RPC method IDs of the transfer service used by these tests."""
    READ = ids.calculate('Read')
    WRITE = ids.calculate('Write')


class TransferManagerTest(unittest.TestCase):
    """Tests for the transfer manager."""
    def setUp(self) -> None:
        """Creates an RPC client whose channel output is a fake server."""
        # Everything the client writes to channel 1 is handed to
        # _handle_request, which plays the role of the server.
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            (pw_transfer.transfer_pb2, ))
        self._service = self._client.channel(1).rpcs.pw.transfer.Transfer

        # Chunks the client has sent, in order.
        self._sent_chunks: List[Chunk] = []
        # Groups of serialized response packets. One group is delivered for
        # each chunk received from the client.
        self._packets_to_send: List[List[bytes]] = []

    def _enqueue_server_responses(
            self, method: _Method,
            responses: Iterable[Iterable[Chunk]]) -> None:
        """Queues groups of chunks to send back as server stream packets.

        Args:
          method: The transfer RPC method the responses belong to.
          responses: Groups of chunks. Each group is queued as one entry, and
              one entry is delivered per chunk received from the client. An
              empty group means nothing is sent for that client chunk.
        """
        for group in responses:
            self._packets_to_send.append([
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=1,
                    service_id=_TRANSFER_SERVICE_ID,
                    method_id=method.value,
                    status=Status.OK.value,
                    payload=response.SerializeToString(),
                ).SerializeToString() for response in group
            ])

    def _enqueue_server_error(self, method: _Method, error: Status) -> None:
        """Queues a single SERVER_ERROR packet carrying the given status."""
        self._packets_to_send.append([
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
                                 channel_id=1,
                                 service_id=_TRANSFER_SERVICE_ID,
                                 method_id=method.value,
                                 status=error.value).SerializeToString()
        ])

    def _handle_request(self, data: bytes) -> None:
        """Records a chunk sent by the client and replies with queued packets.

        Packets that are not CLIENT_STREAM packets are ignored. For every
        client chunk, the next queued group of responses (if any) is fed
        back into the client.
        """
        packet = packets.decode(data)
        # Protobuf enum values are plain integers, so compare by value;
        # an identity check only works by accident of small-int caching.
        if packet.type != packet_pb2.PacketType.CLIENT_STREAM:
            return

        chunk = Chunk()
        chunk.MergeFromString(packet.payload)
        self._sent_chunks.append(chunk)

        if self._packets_to_send:
            responses = self._packets_to_send.pop(0)
            for response in responses:
                self._client.process_packet(response)

    def _received_data(self) -> bytearray:
        """Returns the concatenated data of every chunk the client sent."""
        data = bytearray()
        for chunk in self._sent_chunks:
            data.extend(chunk.data)
        return data

    def test_read_transfer_basic(self) -> None:
        """Reads a transfer that the server completes in a single chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((Chunk(transfer_id=3, offset=0, data=b'abc',
                    remaining_bytes=0), ), ),
        )

        data = manager.read(3)
        self.assertEqual(data, b'abc')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_multichunk(self) -> None:
        """Reads a transfer whose data arrives in two consecutive chunks."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((
                Chunk(transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
                Chunk(transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
            ), ),
        )

        data = manager.read(3)
        self.assertEqual(data, b'abcdef')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_progress_callback(self) -> None:
        """The progress callback is invoked once per received read chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((
                Chunk(transfer_id=3, offset=0, data=b'abc', remaining_bytes=3),
                Chunk(transfer_id=3, offset=3, data=b'def', remaining_bytes=0),
            ), ),
        )

        progress: List[pw_transfer.ProgressStats] = []

        data = manager.read(3, progress.append)
        self.assertEqual(data, b'abcdef')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)
        self.assertEqual(progress, [
            pw_transfer.ProgressStats(3, 3, 6),
            pw_transfer.ProgressStats(6, 6, 6),
        ])

    def test_read_transfer_retry_bad_offset(self) -> None:
        """Server responds with an unexpected offset in a read transfer."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            (
                (
                    Chunk(transfer_id=3,
                          offset=0,
                          data=b'123',
                          remaining_bytes=6),

                    # Incorrect offset; expecting 3.
                    Chunk(transfer_id=3,
                          offset=1,
                          data=b'456',
                          remaining_bytes=3),
                ),
                (
                    Chunk(transfer_id=3,
                          offset=3,
                          data=b'456',
                          remaining_bytes=3),
                    Chunk(transfer_id=3,
                          offset=6,
                          data=b'789',
                          remaining_bytes=0),
                ),
            ))

        data = manager.read(3)
        self.assertEqual(data, b'123456789')

        # Two transfer parameter requests should have been sent, followed by
        # the final status chunk.
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_retry_timeout(self) -> None:
        """Server doesn't respond to read transfer parameters."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            (
                (),  # Send nothing in response to the initial parameters.
                (Chunk(transfer_id=3, offset=0, data=b'xyz',
                       remaining_bytes=0), ),
            ))

        data = manager.read(3)
        self.assertEqual(data, b'xyz')

        # Two transfer parameter requests should have been sent, followed by
        # the final status chunk.
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_timeout(self) -> None:
        """A read with no server response fails with DEADLINE_EXCEEDED."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(27)

        exception = context.exception
        self.assertEqual(exception.transfer_id, 27)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

        # The client should have sent four transfer parameters requests: one
        # initial, and three retries.
        self.assertEqual(len(self._sent_chunks), 4)

    def test_read_transfer_error(self) -> None:
        """A status chunk from the server is raised as a transfer error."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((Chunk(transfer_id=31, status=Status.NOT_FOUND.value), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(31)

        exception = context.exception
        self.assertEqual(exception.transfer_id, 31)
        self.assertEqual(exception.status, Status.NOT_FOUND)

    def test_read_transfer_server_error(self) -> None:
        """An RPC SERVER_ERROR during a read is reported as INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_error(_Method.READ, Status.NOT_FOUND)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(31)

        exception = context.exception
        self.assertEqual(exception.transfer_id, 31)
        self.assertEqual(exception.status, Status.INTERNAL)

    def test_write_transfer_basic(self) -> None:
        """Writes data that fits in a single chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=32,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'hello')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertEqual(self._received_data(), b'hello')

    def test_write_transfer_max_chunk_size(self) -> None:
        """Data longer than max_chunk_size_bytes is split across chunks."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=32,
                       max_chunk_size_bytes=8), ),
                (),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'hello world')
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'hello world')
        self.assertEqual(self._sent_chunks[1].data, b'hello wo')
        self.assertEqual(self._sent_chunks[2].data, b'rld')

    def test_write_transfer_multiple_parameters(self) -> None:
        """The server sends new transfer parameters partway through a write."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4,
                       offset=8,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'data to write')
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'data to write')
        self.assertEqual(self._sent_chunks[1].data, b'data to ')
        self.assertEqual(self._sent_chunks[2].data, b'write')

    def test_write_transfer_progress_callback(self) -> None:
        """The progress callback is invoked as a write transfer advances."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4,
                       offset=8,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        progress: List[pw_transfer.ProgressStats] = []

        manager.write(4, b'data to write', progress.append)
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'data to write')
        self.assertEqual(self._sent_chunks[1].data, b'data to ')
        self.assertEqual(self._sent_chunks[2].data, b'write')
        self.assertEqual(progress, [
            pw_transfer.ProgressStats(8, 0, 13),
            pw_transfer.ProgressStats(13, 8, 13),
            pw_transfer.ProgressStats(13, 13, 13)
        ])

    def test_write_transfer_rewind(self) -> None:
        """Write transfer in which the server re-requests an earlier offset."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4,
                       offset=8,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (
                    Chunk(
                        transfer_id=4,
                        offset=4,  # rewind
                        pending_bytes=8,
                        max_chunk_size_bytes=8), ),
                (
                    Chunk(
                        transfer_id=4,
                        offset=12,
                        pending_bytes=16,  # update max size
                        max_chunk_size_bytes=16), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'pigweed data transfer')
        self.assertEqual(len(self._sent_chunks), 5)
        self.assertEqual(self._sent_chunks[1].data, b'pigweed ')
        self.assertEqual(self._sent_chunks[2].data, b'data tra')
        self.assertEqual(self._sent_chunks[3].data, b'eed data')
        self.assertEqual(self._sent_chunks[4].data, b' transfer')

    def test_write_transfer_bad_offset(self) -> None:
        """Server requests an offset beyond the end of the data."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (
                    Chunk(
                        transfer_id=4,
                        offset=100,  # larger offset than data
                        pending_bytes=8,
                        max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(4, b'small data')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 4)
        self.assertEqual(exception.status, Status.OUT_OF_RANGE)

    def test_write_transfer_error(self) -> None:
        """A status chunk from the server is raised as a transfer error."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            ((Chunk(transfer_id=21, status=Status.UNAVAILABLE.value), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(21, b'no write')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 21)
        self.assertEqual(exception.status, Status.UNAVAILABLE)

    def test_write_transfer_server_error(self) -> None:
        """An RPC SERVER_ERROR during a write is reported as INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_error(_Method.WRITE, Status.NOT_FOUND)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(21, b'server error')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 21)
        self.assertEqual(exception.status, Status.INTERNAL)

    def test_write_transfer_timeout_after_initial_chunk(self) -> None:
        """Tests write transfers that never get a reply to the first chunk."""
        # No responses are queued, so every attempt times out; the very short
        # timeout keeps the test fast.
        manager = pw_transfer.Manager(self._service,
                                      default_response_timeout_s=0.001,
                                      max_retries=2)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(22, b'no server response!')

        self.assertEqual(
            self._sent_chunks,
            [
                Chunk(transfer_id=22,
                      type=Chunk.Type.TRANSFER_START),  # initial chunk
                Chunk(transfer_id=22,
                      type=Chunk.Type.TRANSFER_START),  # retry 1
                Chunk(transfer_id=22,
                      type=Chunk.Type.TRANSFER_START),  # retry 2
            ])

        exception = context.exception
        self.assertEqual(exception.transfer_id, 22)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

    def test_write_transfer_timeout_after_intermediate_chunk(self) -> None:
        """Tests write transfers that time out after data has been sent."""
        manager = pw_transfer.Manager(
            self._service,
            default_response_timeout_s=DEFAULT_TIMEOUT_S,
            max_retries=2)

        # Only the initial chunk gets a response; the data chunks do not.
        self._enqueue_server_responses(
            _Method.WRITE,
            [[Chunk(transfer_id=22, pending_bytes=10, max_chunk_size_bytes=5)]
             ])

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(22, b'0123456789')

        last_data_chunk = Chunk(transfer_id=22,
                                data=b'56789',
                                offset=5,
                                remaining_bytes=0,
                                type=Chunk.Type.TRANSFER_DATA)

        self.assertEqual(
            self._sent_chunks,
            [
                Chunk(transfer_id=22, type=Chunk.Type.TRANSFER_START),
                Chunk(transfer_id=22,
                      data=b'01234',
                      type=Chunk.Type.TRANSFER_DATA),
                last_data_chunk,  # last chunk
                last_data_chunk,  # retry 1
                last_data_chunk,  # retry 2
            ])

        exception = context.exception
        self.assertEqual(exception.transfer_id, 22)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

    def test_write_zero_pending_bytes_is_internal_error(self) -> None:
        """Transfer parameters with pending_bytes=0 fail with INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            ((Chunk(transfer_id=23, pending_bytes=0), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(23, b'no write')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 23)
        self.assertEqual(exception.status, Status.INTERNAL)


class ProgressStatsTest(unittest.TestCase):
    """Tests for pw_transfer.ProgressStats."""
    def test_received_percent_known_total(self) -> None:
        """percent_received() returns a percentage when the total is set."""
        self.assertEqual(
            pw_transfer.ProgressStats(75, 0, 100).percent_received(), 0.0)
        self.assertEqual(
            pw_transfer.ProgressStats(75, 50, 100).percent_received(), 50.0)
        self.assertEqual(
            pw_transfer.ProgressStats(100, 100, 100).percent_received(), 100.0)

    def test_received_percent_unknown_total(self) -> None:
        """percent_received() returns NaN when the total is None."""
        self.assertTrue(
            math.isnan(
                pw_transfer.ProgressStats(75, 50, None).percent_received()))
        self.assertTrue(
            math.isnan(
                pw_transfer.ProgressStats(100, 100, None).percent_received()))

    def test_str_known_total(self) -> None:
        """str() includes all three values when the total is set."""
        stats = str(pw_transfer.ProgressStats(75, 50, 100))
        self.assertIn('75', stats)
        self.assertIn('50', stats)
        self.assertIn('100', stats)

    def test_str_unknown_total(self) -> None:
        """str() contains 'unknown' when the total is None."""
        stats = str(pw_transfer.ProgressStats(75, 50, None))
        self.assertIn('75', stats)
        self.assertIn('50', stats)
        self.assertIn('unknown', stats)


if __name__ == '__main__':
    unittest.main()