# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Client for the pw_transfer service, which transmits data over pw_rpc."""

import asyncio
import logging
import threading
from typing import Any, Dict, Optional, Union

from pw_rpc.callback_client import BidirectionalStreamingCall
from pw_status import Status

from pw_transfer.transfer import (ProgressCallback, ReadTransfer, Transfer,
                                  WriteTransfer)
from pw_transfer.transfer_pb2 import Chunk

_LOG = logging.getLogger(__package__)

_TransferDict = Dict[int, Transfer]


class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with an active TransferService client over an
    RPC channel. Only one instance of this class should exist for a configured
    RPC TransferService -- the Manager supports multiple simultaneous
    transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled.
    """
    def __init__(self,
                 rpc_transfer_service,
                 *,
                 default_response_timeout_s: float = 2.0,
                 initial_response_timeout_s: float = 4.0,
                 max_retries: int = 3):
        """Initializes a Manager on top of a TransferService.

        Returns only once the dedicated event loop thread has created the
        queues used to hand work to it.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry after a timeout
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries

        # Ongoing transfers in the service by ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type.
        self._read_stream: Optional[BidirectionalStreamingCall] = None
        self._write_stream: Optional[BidirectionalStreamingCall] = None

        self._loop = asyncio.new_event_loop()

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread. They are recreated on that thread
        # in _start_event_loop_thread.
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        # Set by the event loop thread once the queues above have been
        # recreated in its context.
        self._loop_ready = threading.Event()

        self._thread = threading.Thread(target=self._start_event_loop_thread,
                                        daemon=True)

        self._thread.start()

        # Callers bind e.g. `self._new_transfer_queue.put_nowait` on their own
        # thread. If that happened before the loop thread replaced the queues,
        # the item would land in a discarded queue and the transfer would
        # never start, so wait for the replacement to finish.
        self._loop_ready.wait()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit.
        # __init__ may have failed before the thread attribute was assigned.
        thread = getattr(self, '_thread', None)
        if thread is None or not thread.is_alive():
            return

        self._loop.call_soon_threadsafe(self._quit_event.set)

        # A thread cannot join itself; this happens if the last reference to
        # the Manager is dropped on the event loop thread.
        if threading.current_thread() is not thread:
            thread.join()

    def read(self,
             transfer_id: int,
             progress_callback: Optional[ProgressCallback] = None) -> bytes:
        """Receives ("downloads") data from the server.

        Blocks until the transfer completes.

        Args:
          transfer_id: ID of the read transfer
          progress_callback: Optional callback periodically invoked throughout
            the transfer with the transfer state. Can be used to provide user-
            facing status updates such as progress bars.

        Returns:
          The data received from the server.

        Raises:
          ValueError: a read transfer with this ID is already in progress
          Error: the transfer failed to complete
        """

        if transfer_id in self._read_transfers:
            raise ValueError(f'Read transfer {transfer_id} already exists')

        transfer = ReadTransfer(transfer_id,
                                self._send_read_chunk,
                                self._end_read_transfer,
                                self._default_response_timeout_s,
                                self._initial_response_timeout_s,
                                self.max_retries,
                                progress_callback=progress_callback)
        self._start_read_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.id, transfer.status)

        return transfer.data

    def write(self,
              transfer_id: int,
              data: Union[bytes, str],
              progress_callback: Optional[ProgressCallback] = None) -> None:
        """Transmits ("uploads") data to the server.

        Blocks until the transfer completes.

        Args:
          transfer_id: ID of the write transfer
          data: Data to send to the server. A str is encoded with str.encode()
            (UTF-8) before sending.
          progress_callback: Optional callback periodically invoked throughout
            the transfer with the transfer state. Can be used to provide user-
            facing status updates such as progress bars.

        Raises:
          ValueError: a write transfer with this ID is already in progress
          Error: the transfer failed to complete
        """

        if isinstance(data, str):
            data = data.encode()

        if transfer_id in self._write_transfers:
            raise ValueError(f'Write transfer {transfer_id} already exists')

        transfer = WriteTransfer(transfer_id,
                                 data,
                                 self._send_write_chunk,
                                 self._end_write_transfer,
                                 self._default_response_timeout_s,
                                 self._initial_response_timeout_s,
                                 self.max_retries,
                                 progress_callback=progress_callback)
        self._start_write_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.id, transfer.status)

    def _send_read_chunk(self, chunk: Chunk) -> None:
        """Sends a chunk on the shared read stream, which must be open."""
        assert self._read_stream is not None
        self._read_stream.send(chunk)

    def _send_write_chunk(self, chunk: Chunk) -> None:
        """Sends a chunk on the shared write stream, which must be open."""
        assert self._write_stream is not None
        self._write_stream.send(chunk)

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        # Recreate the async communication channels in the context of the
        # running event loop.
        self._new_transfer_queue = asyncio.Queue()
        self._read_chunk_queue = asyncio.Queue()
        self._write_chunk_queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        # Unblock __init__ now that the final queues exist.
        self._loop_ready.set()

        self._loop.create_task(self._transfer_event_loop())
        self._loop.run_forever()

    async def _transfer_event_loop(self):
        """Main async event loop.

        Waits on the quit event and the three queues at once and dispatches
        whichever completes, until the quit event is set. Then it stops the
        event loop.
        """
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED)

            if exit_thread in done:
                break

            if new_transfer in done:
                await new_transfer.result().begin()
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get())

            if read_chunk in done:
                self._loop.create_task(
                    self._handle_chunk(self._read_transfers,
                                       read_chunk.result()))
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get())

            if write_chunk in done:
                self._loop.create_task(
                    self._handle_chunk(self._write_transfers,
                                       write_chunk.result()))
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get())

        # Cancel the waits that are still outstanding and let the
        # cancellations finish, so no task is left pending (and later
        # destroyed while pending) once the loop stops.
        outstanding = (exit_thread, new_transfer, read_chunk, write_chunk)
        for task in outstanding:
            task.cancel()
        await asyncio.gather(*outstanding, return_exceptions=True)

        self._loop.stop()

    @staticmethod
    async def _handle_chunk(transfers: _TransferDict, chunk: Chunk) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to an active transfer based on its ID. Chunks
        for unknown transfer IDs are logged and dropped.
        """

        try:
            transfer = transfers[chunk.transfer_id]
        except KeyError:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.transfer_id)
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    def _open_read_stream(self) -> None:
        """Opens the read stream; received chunks go to the loop thread."""
        self._read_stream = self._service.Read.invoke(
            lambda _, chunk: self._loop.call_soon_threadsafe(
                self._read_chunk_queue.put_nowait, chunk),
            on_error=lambda _, status: self._on_read_error(status))

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream."""

        if status is Status.FAILED_PRECONDITION:
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. This could occur if the
            # server resets during an active transfer. Re-open the stream to
            # allow pending transfers to continue.
            self._open_read_stream()
        else:
            # Other errors are unrecoverable. Clear the stream and cancel any
            # pending transfers with an INTERNAL status as this is a system
            # error.
            self._read_stream = None

            for transfer in self._read_transfers.values():
                transfer.finish(Status.INTERNAL, skip_callback=True)
            self._read_transfers.clear()

            _LOG.error('Read stream shut down: %s', status)

    def _open_write_stream(self) -> None:
        """Opens the write stream; received chunks go to the loop thread."""
        self._write_stream = self._service.Write.invoke(
            lambda _, chunk: self._loop.call_soon_threadsafe(
                self._write_chunk_queue.put_nowait, chunk),
            on_error=lambda _, status: self._on_write_error(status))

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream."""

        if status is Status.FAILED_PRECONDITION:
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. This could occur if the
            # server resets during an active transfer. Re-open the stream to
            # allow pending transfers to continue.
            self._open_write_stream()
        else:
            # Other errors are unrecoverable. Clear the stream and cancel any
            # pending transfers with an INTERNAL status as this is a system
            # error.
            self._write_stream = None

            for transfer in self._write_transfers.values():
                transfer.finish(Status.INTERNAL, skip_callback=True)
            self._write_transfers.clear()

            _LOG.error('Write stream shut down: %s', status)

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.id] = transfer

        if not self._read_stream:
            self._open_read_stream()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(self._new_transfer_queue.put_nowait,
                                        transfer)

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        del self._read_transfers[transfer.id]

        if not transfer.status.ok():
            _LOG.error('Read transfer %d terminated with status %s',
                       transfer.id, transfer.status)

        # TODO(frolv): This doesn't seem to work. Investigate why.
        # If no more transfers are using the read stream, close it.
        # if not self._read_transfers and self._read_stream:
        #     self._read_stream.cancel()
        #     self._read_stream = None

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.id] = transfer

        if not self._write_stream:
            self._open_write_stream()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(self._new_transfer_queue.put_nowait,
                                        transfer)

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        del self._write_transfers[transfer.id]

        if not transfer.status.ok():
            _LOG.error('Write transfer %d terminated with status %s',
                       transfer.id, transfer.status)

        # TODO(frolv): This doesn't seem to work. Investigate why.
        # If no more transfers are using the write stream, close it.
        # if not self._write_transfers and self._write_stream:
        #     self._write_stream.cancel()
        #     self._write_stream = None


class Error(Exception):
    """Exception raised when a transfer fails.

    Stores the ID of the failed transfer and the error that occurred.
    """
    def __init__(self, transfer_id: int, status: Status):
        super().__init__(f'Transfer {transfer_id} failed with status {status}')
        self.transfer_id = transfer_id
        self.status = status