# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Client for the pw_transfer service, which transmits data over pw_rpc."""

import asyncio
import ctypes
import logging
import threading
from typing import Any, Callable

from pw_rpc.callback_client import BidirectionalStreamingCall
from pw_status import Status

from pw_transfer.transfer import (
    ProgressCallback,
    ProtocolVersion,
    ReadTransfer,
    Transfer,
    WriteTransfer,
)
from pw_transfer.chunk import Chunk
from pw_transfer import transfer_pb2

_LOG = logging.getLogger(__package__)

_TransferDict = dict[int, Transfer]


class _TransferStream:
    """Wraps one bidirectional transfer RPC stream shared by many transfers.

    Incoming stream messages are forwarded to ``chunk_handler`` and RPC errors
    to ``error_handler``. When the stream fails with FAILED_PRECONDITION it is
    reopened automatically, at most ``max_reopen_attempts`` consecutive times
    without an intervening received message.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: BidirectionalStreamingCall | None = None
        # Automatic reopens performed since the last received message.
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True if an RPC call is currently held for this stream."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Invokes the RPC method if no call is held, or always if ``force``."""
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels the held RPC call, if any, and forgets it."""
        if self._call is not None:
            self._call.cancel()
            self._call = None

    def send(self, chunk: Chunk) -> None:
        """Sends a chunk on the stream; the stream must already be open."""
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: Chunk) -> None:
        # Receiving anything means the stream works again; refill the budget.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        if (
            rpc_status is Status.FAILED_PRECONDITION
            and self._reopen_attempts < self._max_reopen_attempts
        ):
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream automatically, but only a bounded number of times in a
            # row so a persistently failing server cannot cause an endless
            # reopen cycle.
            self._reopen_attempts += 1
            self.open(force=True)
        else:
            # Other errors (or an exhausted reopen budget) are unrecoverable;
            # clear the stream. A later open() creates a fresh call, which
            # starts again with a full reopen budget.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None
            self._reopen_attempts = 0

        self._error_handler(rpc_status)


class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with an active TransferService client over an
    RPC channel. Only one instance of this class should exist for a configured
    RPC TransferService -- the Manager supports multiple simultaneous
    transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled.
    """

    def __init__(
        self,
        rpc_transfer_service,
        *,
        default_response_timeout_s: float = 2.0,
        initial_response_timeout_s: float = 4.0,
        max_retries: int = 3,
        max_lifetime_retries: int = 1500,
        max_chunk_size_bytes: int = 1024,
        default_protocol_version=ProtocolVersion.VERSION_TWO,
    ):
        """Initializes a Manager on top of a TransferService.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry a single packet after a timeout
          max_lifetime_retries: Cumulative maximum number of times to retry over
              the course of the transfer before giving up.
          max_chunk_size_bytes: In a read transfer, the maximum size of data the
              server should send within a single packet.
          default_protocol_version: Version of the pw_transfer protocol to use.
              Defaults to the latest, but can be set to legacy for projects
              which use legacy devices.
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries
        self.max_lifetime_retries = max_lifetime_retries
        self._max_chunk_size_bytes = max_chunk_size_bytes
        self._default_protocol_version = default_protocol_version

        # Ongoing transfers in the service by resource ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}
        # Wraps at 2**32 back to 1 (see assign_session_id); 0 is never used.
        self._next_session_id = ctypes.c_uint32(1)

        self._loop = asyncio.new_event_loop()
        # Set the event loop for the current thread.
        asyncio.set_event_loop(self._loop)

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread.
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._thread = threading.Thread(
            target=self._start_event_loop_thread, daemon=True
        )

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type. Received messages arrive on an
        # RPC callback, so they are handed to the event loop thread-safely.
        self._read_stream = _TransferStream(
            self._service.Read,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._read_chunk_queue.put_nowait, chunk
            ),
            self._on_read_error,
        )
        self._write_stream = _TransferStream(
            self._service.Write,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._write_chunk_queue.put_nowait, chunk
            ),
            self._on_write_error,
        )

        self._thread.start()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit. __init__ may have raised before the thread was
        # created, so look the attribute up defensively.
        thread = getattr(self, '_thread', None)
        if thread is not None and thread.is_alive():
            self._loop.call_soon_threadsafe(self._quit_event.set)
            thread.join()

    def read(
        self,
        resource_id: int,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
        max_window_size: int = 32768,
    ) -> bytes:
        """Receives ("downloads") data from the server.

        Blocks the calling thread until the transfer completes.

        Args:
          resource_id: ID of the resource from which to read.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start reading from. Must be
              supported by the transfer handler. All transfers support starting
              from 0, the default. Returned bytes will not have any padding
              related to this initial offset. No seeking is done in the transfer
              operation on the client side.
          max_window_size: Maximum transfer window size in bytes; forwarded to
              the ReadTransfer as max_window_size_bytes.

        Returns:
          The bytes received from the server.

        Raises:
          ValueError: a read of this resource is already in progress, or a
              nonzero initial_offset was requested with the legacy protocol.
          Error: the transfer failed to complete
        """

        if resource_id in self._read_transfers:
            raise ValueError(
                f'Read transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if protocol_version == ProtocolVersion.LEGACY and initial_offset != 0:
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by resource ID alone; newer protocol
        # versions use a distinct, client-assigned session ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = ReadTransfer(
            session_id,
            resource_id,
            self._read_stream.send,
            self._end_read_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            max_window_size_bytes=max_window_size,
            max_chunk_size=self._max_chunk_size_bytes,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_read_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

        return transfer.data

    def write(
        self,
        resource_id: int,
        data: bytes | str,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> None:
        """Transmits ("uploads") data to the server.

        Blocks the calling thread until the transfer completes.

        Args:
          resource_id: ID of the resource to which to write.
          data: Data to send to the server. A str is encoded with the default
              (UTF-8) encoding before sending.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start writing to. Must be supported
              by the transfer handler. All transfers support starting from 0,
              the default. data arg should start with the data you want to see
              starting at this initial offset on the server. No seeking is done
              in the transfer operation on the client side.

        Raises:
          ValueError: a write to this resource is already in progress, or a
              nonzero initial_offset was requested with a protocol version
              other than VERSION_TWO.
          Error: the transfer failed to complete
        """

        if isinstance(data, str):
            data = data.encode()

        if resource_id in self._write_transfers:
            raise ValueError(
                f'Write transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if (
            protocol_version != ProtocolVersion.VERSION_TWO
            and initial_offset != 0
        ):
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by resource ID alone; newer protocol
        # versions use a distinct, client-assigned session ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = WriteTransfer(
            session_id,
            resource_id,
            data,
            self._write_stream.send,
            self._end_write_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_write_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

    def assign_session_id(self) -> int:
        """Returns the next session ID, wrapping within uint32 and skipping 0."""
        new_id = self._next_session_id.value

        # c_uint32 truncates on overflow, so 0xFFFFFFFF + 1 becomes 0; 0 is
        # skipped so it is never handed out.
        self._next_session_id = ctypes.c_uint32(self._next_session_id.value + 1)
        if self._next_session_id.value == 0:
            self._next_session_id = ctypes.c_uint32(1)

        return new_id

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        # Recreate the async communication channels in the context of the
        # running event loop.
        self._new_transfer_queue = asyncio.Queue()
        self._read_chunk_queue = asyncio.Queue()
        self._write_chunk_queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._loop.create_task(self._transfer_event_loop())
        self._loop.run_forever()

    async def _transfer_event_loop(self):
        """Main async event loop.

        Waits on the quit event, the new-transfer queue and both chunk queues,
        dispatching whichever becomes ready, until the quit event is set.
        """
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        # The event loop keeps only weak references to tasks, so fire-and-forget
        # chunk handlers are held here until they finish to prevent them from
        # being garbage-collected mid-execution.
        chunk_tasks: set[asyncio.Task] = set()

        def spawn_chunk_handler(transfers: _TransferDict, message) -> None:
            task = self._loop.create_task(
                self._handle_chunk(transfers, message)
            )
            chunk_tasks.add(task)
            task.add_done_callback(chunk_tasks.discard)

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exit_thread in done:
                break

            if new_transfer in done:
                await new_transfer.result().begin()
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get()
                )

            if read_chunk in done:
                spawn_chunk_handler(self._read_transfers, read_chunk.result())
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get()
                )

            if write_chunk in done:
                spawn_chunk_handler(
                    self._write_transfers, write_chunk.result()
                )
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get()
                )

        self._loop.stop()

    @staticmethod
    async def _handle_chunk(
        transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to an active transfer based on its ID. If the
        transfer indicates that it is complete, the provided completion callback
        is invoked.
        """

        chunk = Chunk.from_message(message)

        # Find a transfer for the chunk in the list of active transfers.
        try:
            if chunk.protocol_version is ProtocolVersion.LEGACY:
                # Legacy chunks carry the resource ID in their session field.
                transfer = next(
                    t
                    for t in transfers.values()
                    if t.resource_id == chunk.session_id
                )
            else:
                transfer = next(
                    t for t in transfers.values() if t.id == chunk.id()
                )
        except StopIteration:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.id(),
            )
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream.

        Terminates every active read transfer with INTERNAL without invoking
        their completion callbacks, then forgets them.
        """

        for transfer in self._read_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._read_transfers.clear()

        _LOG.error('Read stream shut down: %s', status)

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream.

        Terminates every active write transfer with INTERNAL without invoking
        their completion callbacks, then forgets them.
        """

        for transfer in self._write_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._write_transfers.clear()

        _LOG.error('Write stream shut down: %s', status)

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.resource_id] = transfer
        self._read_stream.open()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        # NOTE(review): the transfer is handed to the event loop only after a
        # fixed 1 s delay; the reason is not visible here (possibly to let the
        # freshly opened stream settle) -- confirm this is intentional.
        delay = 1
        self._loop.call_soon_threadsafe(
            self._loop.call_later,
            delay,
            self._new_transfer_queue.put_nowait,
            transfer,
        )

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        del self._read_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Read transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.resource_id] = transfer
        self._write_stream.open()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        # NOTE(review): the transfer is handed to the event loop only after a
        # fixed 1 s delay; the reason is not visible here (possibly to let the
        # freshly opened stream settle) -- confirm this is intentional.
        delay = 1
        self._loop.call_soon_threadsafe(
            self._loop.call_later,
            delay,
            self._new_transfer_queue.put_nowait,
            transfer,
        )

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        del self._write_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Write transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )


class Error(Exception):
    """Exception raised when a transfer fails.

    Stores the ID of the failed transfer resource and the error that occurred.
    """

    def __init__(self, resource_id: int, status: Status):
        super().__init__(f'Transfer {resource_id} failed with status {status}')
        self.resource_id = resource_id
        self.status = status