# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Client for the pw_transfer service, which transmits data over pw_rpc."""

import asyncio
import ctypes
import logging
import threading
from typing import Any, Callable

from pw_rpc.callback_client import BidirectionalStreamingCall
from pw_status import Status

from pw_transfer.transfer import (
    ProgressCallback,
    ProtocolVersion,
    ReadTransfer,
    Transfer,
    WriteTransfer,
)
from pw_transfer.chunk import Chunk

try:
    from pw_transfer import transfer_pb2
except ImportError:
    # For the bazel build, which puts generated protos in a different location.
    from pigweed.pw_transfer import transfer_pb2  # type: ignore

_LOG = logging.getLogger(__package__)

_TransferDict = dict[int, Transfer]


class _TransferStream:
    """Wraps one bidirectional transfer RPC stream shared by many transfers.

    Received messages are forwarded to ``chunk_handler`` and RPC errors to
    ``error_handler``. When the stream fails with FAILED_PRECONDITION it is
    re-opened automatically, at most ``max_reopen_attempts`` consecutive times
    without an intervening received chunk.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: BidirectionalStreamingCall | None = None
        # Consecutive automatic re-opens since the last received chunk.
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True if an RPC call is currently held for this stream."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Invokes the RPC if no call is held, or unconditionally if forced."""
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels the held RPC call, if any."""
        if self._call is not None:
            self._call.cancel()
            self._call = None

    def send(self, chunk: Chunk) -> None:
        """Sends a chunk on the open stream; the stream must be open."""
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: Chunk) -> None:
        # Traffic is flowing again, so the stream is healthy.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        if (
            rpc_status is Status.FAILED_PRECONDITION
            and self._reopen_attempts < self._max_reopen_attempts
        ):
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream automatically, but only a bounded number of times in a
            # row so a persistently failing stream is eventually shut down.
            self._reopen_attempts += 1
            self.open(force=True)
        else:
            # Other errors (or exhausted re-open attempts) are unrecoverable;
            # clear the stream. A later open() starts a fresh call.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None
            self._reopen_attempts = 0

        self._error_handler(rpc_status)


class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with the client for an active TransferService
    over an RPC channel. Only one instance of this class should exist for a
    configured RPC TransferService -- the Manager supports multiple
    simultaneous transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled.
    """

    def __init__(
        self,
        rpc_transfer_service,
        *,
        default_response_timeout_s: float = 2.0,
        initial_response_timeout_s: float = 4.0,
        max_retries: int = 3,
        max_lifetime_retries: int = 1500,
        max_chunk_size_bytes: int = 1024,
        default_protocol_version=ProtocolVersion.VERSION_TWO,
    ):
        """Initializes a Manager on top of a TransferService.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry a single packet after a timeout
          max_lifetime_retries: Cumulative maximum number of times to retry over
              the course of the transfer before giving up.
          max_chunk_size_bytes: In a read transfer, the maximum size of data the
              server should send within a single packet.
          default_protocol_version: Version of the pw_transfer protocol to use.
              Defaults to the latest, but can be set to legacy for projects
              which use legacy devices.
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries
        self.max_lifetime_retries = max_lifetime_retries
        self._max_chunk_size_bytes = max_chunk_size_bytes
        self._default_protocol_version = default_protocol_version

        # Ongoing transfers in the service by resource ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}
        self._next_session_id = ctypes.c_uint32(1)

        self._loop = asyncio.new_event_loop()
        # Set the event loop for the current thread.
        asyncio.set_event_loop(self._loop)

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread. The transfer thread replaces these
        # with fresh instances before running the loop.
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        # Set by the transfer thread once it has replaced the queues and quit
        # event above. __init__ waits for it so that no caller can capture a
        # stale (discarded) queue or event.
        self._loop_ready = threading.Event()

        # Strong references to in-flight chunk-handling tasks: the event loop
        # only keeps weak references to tasks. Touched only on the loop thread.
        self._chunk_tasks: set[asyncio.Task] = set()

        self._thread = threading.Thread(
            target=self._start_event_loop_thread, daemon=True
        )

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type.
        self._read_stream = _TransferStream(
            self._service.Read,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._read_chunk_queue.put_nowait, chunk
            ),
            self._on_read_error,
        )
        self._write_stream = _TransferStream(
            self._service.Write,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._write_chunk_queue.put_nowait, chunk
            ),
            self._on_write_error,
        )

        self._thread.start()
        # Thread.start() returns before the thread has run any of its target,
        # so wait until the transfer thread's channels are in place. Otherwise
        # a transfer started immediately could be posted to a discarded queue
        # and never begin.
        self._loop_ready.wait()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit. __init__ may have raised before the thread was
        # created, in which case there is nothing to stop.
        thread = getattr(self, '_thread', None)
        if thread is not None and thread.is_alive():
            self._loop.call_soon_threadsafe(self._quit_event.set)
            thread.join()

    def read(
        self,
        resource_id: int,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> bytes:
        """Receives ("downloads") data from the server.

        Args:
          resource_id: ID of the resource from which to read.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (VERSION_TWO unless overridden).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start reading from. Must be
              supported by the transfer handler. All transfers support starting
              from 0, the default. Returned bytes will not have any padding
              related to this initial offset. No seeking is done in the transfer
              operation on the client side.

        Returns:
          The data received from the server.

        Raises:
          ValueError: a read of this resource is already in progress, or a
              nonzero offset was requested with the legacy protocol
          Error: the transfer failed to complete
        """

        if resource_id in self._read_transfers:
            raise ValueError(
                f'Read transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if protocol_version == ProtocolVersion.LEGACY and initial_offset != 0:
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by their resource ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = ReadTransfer(
            session_id,
            resource_id,
            self._read_stream.send,
            self._end_read_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            max_chunk_size=self._max_chunk_size_bytes,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_read_transfer(transfer)

        # Block the calling thread until the transfer thread finishes it.
        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

        return transfer.data

    def write(
        self,
        resource_id: int,
        data: bytes | str,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> None:
        """Transmits ("uploads") data to the server.

        Args:
          resource_id: ID of the resource to which to write.
          data: Data to send to the server. A str is UTF-8 encoded first.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (VERSION_TWO unless overridden).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start writing to. Must be supported
              by the transfer handler. All transfers support starting from 0,
              the default. data arg should start with the data you want to see
              starting at this initial offset on the server. No seeking is done
              in the transfer operation on the client side.

        Raises:
          ValueError: a write to this resource is already in progress, or a
              nonzero offset was requested with a protocol other than
              VERSION_TWO
          Error: the transfer failed to complete
        """

        if isinstance(data, str):
            data = data.encode()

        if resource_id in self._write_transfers:
            raise ValueError(
                f'Write transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if (
            protocol_version != ProtocolVersion.VERSION_TWO
            and initial_offset != 0
        ):
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by their resource ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = WriteTransfer(
            session_id,
            resource_id,
            data,
            self._write_stream.send,
            self._end_write_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_write_transfer(transfer)

        # Block the calling thread until the transfer thread finishes it.
        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

    def assign_session_id(self) -> int:
        """Returns the next session ID.

        IDs are kept within the unsigned 32-bit range; on wraparound the
        counter skips 0 and restarts at 1.
        """
        new_id = self._next_session_id.value

        self._next_session_id = ctypes.c_uint32(self._next_session_id.value + 1)
        if self._next_session_id.value == 0:
            self._next_session_id = ctypes.c_uint32(1)

        return new_id

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        # Recreate the async communication channels in the context of the
        # running event loop.
        self._new_transfer_queue = asyncio.Queue()
        self._read_chunk_queue = asyncio.Queue()
        self._write_chunk_queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        # The channels are now final; release Manager.__init__. Callbacks
        # posted with call_soon_threadsafe before run_forever() starts are
        # queued and run once the loop is running.
        self._loop_ready.set()

        self._loop.create_task(self._transfer_event_loop())
        self._loop.run_forever()

    async def _transfer_event_loop(self):
        """Main async event loop."""
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exit_thread in done:
                break

            if new_transfer in done:
                await new_transfer.result().begin()
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get()
                )

            if read_chunk in done:
                self._spawn_chunk_handler(
                    self._read_transfers, read_chunk.result()
                )
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get()
                )

            if write_chunk in done:
                self._spawn_chunk_handler(
                    self._write_transfers, write_chunk.result()
                )
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get()
                )

        self._loop.stop()

    def _spawn_chunk_handler(
        self, transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Runs _handle_chunk as a task, holding a reference until it ends."""
        task = self._loop.create_task(self._handle_chunk(transfers, message))
        self._chunk_tasks.add(task)
        task.add_done_callback(self._chunk_tasks.discard)

    @staticmethod
    async def _handle_chunk(
        transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to an active transfer: legacy chunks are
        matched by resource ID, newer chunks by session ID. Chunks for unknown
        transfers are logged and dropped. Any completion handling is left to
        the transfer's own handle_chunk().
        """

        chunk = Chunk.from_message(message)

        # Find a transfer for the chunk in the list of active transfers.
        try:
            if chunk.protocol_version is ProtocolVersion.LEGACY:
                transfer = next(
                    t
                    for t in transfers.values()
                    if t.resource_id == chunk.session_id
                )
            else:
                transfer = next(
                    t for t in transfers.values() if t.id == chunk.id()
                )
        except StopIteration:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.id(),
            )
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream."""

        # skip_callback=True keeps _end_read_transfer from mutating the dict
        # while it is being iterated; it is cleared wholesale afterwards.
        for transfer in self._read_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._read_transfers.clear()

        _LOG.error('Read stream shut down: %s', status)

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream."""

        # skip_callback=True keeps _end_write_transfer from mutating the dict
        # while it is being iterated; it is cleared wholesale afterwards.
        for transfer in self._write_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._write_transfers.clear()

        _LOG.error('Write stream shut down: %s', status)

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.resource_id] = transfer
        self._read_stream.open()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(
            self._new_transfer_queue.put_nowait, transfer
        )

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        # Only drop the entry if it still belongs to this transfer; it may
        # already have been cleared (e.g. by _on_read_error) or replaced by a
        # newer transfer of the same resource.
        if self._read_transfers.get(transfer.resource_id) is transfer:
            del self._read_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Read transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.resource_id] = transfer
        self._write_stream.open()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(
            self._new_transfer_queue.put_nowait, transfer
        )

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        # Only drop the entry if it still belongs to this transfer; it may
        # already have been cleared (e.g. by _on_write_error) or replaced by a
        # newer transfer of the same resource.
        if self._write_transfers.get(transfer.resource_id) is transfer:
            del self._write_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Write transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )


class Error(Exception):
    """Exception raised when a transfer fails.

    Stores the ID of the failed transfer resource and the error that occurred.
    """

    def __init__(self, resource_id: int, status: Status):
        super().__init__(f'Transfer {resource_id} failed with status {status}')
        self.resource_id = resource_id
        self.status = status