• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Client for the pw_transfer service, which transmits data over pw_rpc."""
15
16import asyncio
17import ctypes
18import logging
19import threading
20from typing import Any, Callable
21
22from pw_rpc.callback_client import BidirectionalStreamingCall
23from pw_status import Status
24
25from pw_transfer.transfer import (
26    ProgressCallback,
27    ProtocolVersion,
28    ReadTransfer,
29    Transfer,
30    WriteTransfer,
31)
32from pw_transfer.chunk import Chunk
33from pw_transfer import transfer_pb2
34
# Active transfers of a single type (read or write), keyed by resource ID.
_TransferDict = dict[int, Transfer]

_LOG = logging.getLogger(__package__)
38
39
class _TransferStream:
    """Wraps a bidirectional streaming transfer RPC shared by many transfers.

    Messages received on the stream are forwarded to ``chunk_handler`` and
    stream errors to ``error_handler``. A FAILED_PRECONDITION error causes the
    stream to be re-opened automatically, at most ``max_reopen_attempts``
    times in a row; receiving a message resets that count.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        """Creates the wrapper; the RPC is not invoked until open().

        Args:
          method: RPC method whose invoke() starts the stream (e.g. the
              transfer service's Read or Write method).
          chunk_handler: Called with each message received on the stream.
          error_handler: Called with the status of every stream error, after
              the stream has been re-opened or cleared.
          max_reopen_attempts: Maximum number of consecutive automatic
              re-opens after FAILED_PRECONDITION before the stream is treated
              as shut down.
        """
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: BidirectionalStreamingCall | None = None
        # Consecutive FAILED_PRECONDITION re-opens with no chunk received.
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True if the stream currently has an active call."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Starts the RPC call if none is active (or always, if force)."""
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels the active call, if any, and resets the reopen budget."""
        if self._call is not None:
            self._call.cancel()
            self._call = None
        self._reopen_attempts = 0

    def send(self, chunk: Chunk) -> None:
        """Sends a chunk on the stream, which must already be open."""
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: Chunk) -> None:
        # Receiving data shows the stream is working again.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        if (
            rpc_status is Status.FAILED_PRECONDITION
            and self._reopen_attempts < self._max_reopen_attempts
        ):
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream automatically, but only a bounded number of consecutive
            # times so repeated rejections cannot re-open it indefinitely.
            self._reopen_attempts += 1
            self.open(force=True)
        else:
            # Other errors (or an exhausted reopen budget) are unrecoverable;
            # clear the stream so the next open() starts a fresh call.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None
            self._reopen_attempts = 0

        self._error_handler(rpc_status)
90
91
class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with an active pw_rpc client for the
    TransferService. Only one instance of this class should exist for a
    configured RPC TransferService -- the Manager supports multiple
    simultaneous transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled. read() and write() block the
    calling thread until their transfer completes.
    """

    def __init__(
        self,
        rpc_transfer_service,
        *,
        default_response_timeout_s: float = 2.0,
        initial_response_timeout_s: float = 4.0,
        max_retries: int = 3,
        max_lifetime_retries: int = 1500,
        max_chunk_size_bytes: int = 1024,
        default_protocol_version=ProtocolVersion.VERSION_TWO,
    ):
        """Initializes a Manager on top of a TransferService.

        Note that this also installs the Manager's event loop as the current
        asyncio event loop of the calling thread, and starts the transfer
        thread.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry a single packet after a timeout
          max_lifetime_retries: Cumulative maximum number of times to retry over
              the course of the transfer before giving up.
          max_chunk_size_bytes: In a read transfer, the maximum size of data the
              server should send within a single packet.
          default_protocol_version: Version of the pw_transfer protocol to use.
              Defaults to the latest, but can be set to legacy for projects
              which use legacy devices.
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries
        self.max_lifetime_retries = max_lifetime_retries
        self._max_chunk_size_bytes = max_chunk_size_bytes
        self._default_protocol_version = default_protocol_version

        # Ongoing transfers in the service by resource ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}
        # Next session ID to hand out (see assign_session_id).
        self._next_session_id = ctypes.c_uint32(1)

        self._loop = asyncio.new_event_loop()
        # Set the event loop for the current thread.
        asyncio.set_event_loop(self._loop)

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread. The transfer thread replaces these
        # objects when it starts (see _start_event_loop_thread), so code that
        # runs outside that thread must look them up from inside the loop.
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._thread = threading.Thread(
            target=self._start_event_loop_thread, daemon=True
        )

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type. Received chunks are handed to
        # the transfer thread through its chunk queues.
        self._read_stream = _TransferStream(
            self._service.Read,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._read_chunk_queue.put_nowait, chunk
            ),
            self._on_read_error,
        )
        self._write_stream = _TransferStream(
            self._service.Write,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._write_chunk_queue.put_nowait, chunk
            ),
            self._on_write_error,
        )

        self._thread.start()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit. The quit event is looked up inside the loop
        # thread because that thread replaces self._quit_event on startup.
        if self._thread.is_alive():
            self._loop.call_soon_threadsafe(self._signal_quit)
            self._thread.join()

    def read(
        self,
        resource_id: int,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
        max_window_size: int = 32768,
    ) -> bytes:
        """Receives ("downloads") data from the server.

        Blocks until the transfer completes.

        Args:
          resource_id: ID of the resource from which to read.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (VERSION_TWO unless overridden).
          chunk_timeout_s: Timeout for any individual chunk. Defaults to the
              manager's default_response_timeout_s.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s. Defaults to the manager's
              initial_response_timeout_s.
          initial_offset: Initial offset to start reading from. Must be
              supported by the transfer handler. All transfers support starting
              from 0, the default. Returned bytes will not have any padding
              related to this initial offset. No seeking is done in the transfer
              operation on the client side.
          max_window_size: Maximum window size, passed to the ReadTransfer as
              max_window_size_bytes.

        Returns:
          The data read from the resource.

        Raises:
          ValueError: a read of resource_id is already in progress, or a
              nonzero initial_offset was requested with the legacy protocol.
          Error: the transfer failed to complete
        """
        session_id, protocol_version, chunk_timeout_s, initial_timeout_s = (
            self._prepare_transfer(
                'Read',
                self._read_transfers,
                resource_id,
                protocol_version,
                chunk_timeout_s,
                initial_timeout_s,
                initial_offset,
            )
        )

        transfer = ReadTransfer(
            session_id,
            resource_id,
            self._read_stream.send,
            self._end_read_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            max_window_size_bytes=max_window_size,
            max_chunk_size=self._max_chunk_size_bytes,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_read_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

        return transfer.data

    def write(
        self,
        resource_id: int,
        data: bytes | str,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> None:
        """Transmits ("uploads") data to the server.

        Blocks until the transfer completes.

        Args:
          resource_id: ID of the resource to which to write.
          data: Data to send to the server. A str is encoded with str.encode()
              (UTF-8) before sending.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (VERSION_TWO unless overridden).
          chunk_timeout_s: Timeout for any individual chunk. Defaults to the
              manager's default_response_timeout_s.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s. Defaults to the manager's
              initial_response_timeout_s.
          initial_offset: Initial offset to start writing to. Must be supported
              by the transfer handler. All transfers support starting from 0,
              the default. data arg should start with the data you want to see
              starting at this initial offset on the server. No seeking is done
              in the transfer operation on the client side.

        Raises:
          ValueError: a write to resource_id is already in progress, or a
              nonzero initial_offset was requested with the legacy protocol.
          Error: the transfer failed to complete
        """
        if isinstance(data, str):
            data = data.encode()

        session_id, protocol_version, chunk_timeout_s, initial_timeout_s = (
            self._prepare_transfer(
                'Write',
                self._write_transfers,
                resource_id,
                protocol_version,
                chunk_timeout_s,
                initial_timeout_s,
                initial_offset,
            )
        )

        transfer = WriteTransfer(
            session_id,
            resource_id,
            data,
            self._write_stream.send,
            self._end_write_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_write_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

    def assign_session_id(self) -> int:
        """Returns a new session ID for a non-legacy transfer.

        IDs count up from 1, wrapping within the unsigned 32-bit range and
        skipping 0.
        """
        new_id = self._next_session_id.value

        # c_uint32 truncates to 32 bits, so the counter wraps from 2**32 - 1.
        self._next_session_id = ctypes.c_uint32(self._next_session_id.value + 1)
        if self._next_session_id.value == 0:
            self._next_session_id = ctypes.c_uint32(1)

        return new_id

    def _prepare_transfer(
        self,
        kind: str,
        transfers: _TransferDict,
        resource_id: int,
        protocol_version: ProtocolVersion | None,
        chunk_timeout_s: float | None,
        initial_timeout_s: float | None,
        initial_offset: int,
    ) -> tuple[int, ProtocolVersion, float, float]:
        """Validates a new transfer's arguments and fills in defaults.

        Shared by read() and write(); kind ('Read' or 'Write') is used only in
        the error message.

        Returns:
          (session_id, protocol_version, chunk_timeout_s, initial_timeout_s)

        Raises:
          ValueError: a transfer of this kind for resource_id is already
              active, or a nonzero initial_offset was requested with the
              legacy protocol.
        """
        if resource_id in transfers:
            raise ValueError(
                f'{kind} transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if protocol_version == ProtocolVersion.LEGACY and initial_offset != 0:
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by their resource ID; newer protocol
        # versions use a client-assigned session ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        return session_id, protocol_version, chunk_timeout_s, initial_timeout_s

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        # Recreate the async communication channels in the context of the
        # running event loop.
        self._new_transfer_queue = asyncio.Queue()
        self._read_chunk_queue = asyncio.Queue()
        self._write_chunk_queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._loop.create_task(self._transfer_event_loop())
        self._loop.run_forever()

    async def _transfer_event_loop(self):
        """Main async event loop.

        New transfers are begun one at a time, in order. Each incoming chunk
        is dispatched to a separate task. The loop exits once the quit event
        is set.
        """
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exit_thread in done:
                break

            if new_transfer in done:
                await new_transfer.result().begin()
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get()
                )

            if read_chunk in done:
                self._loop.create_task(
                    self._handle_chunk(
                        self._read_transfers, read_chunk.result()
                    )
                )
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get()
                )

            if write_chunk in done:
                self._loop.create_task(
                    self._handle_chunk(
                        self._write_transfers, write_chunk.result()
                    )
                )
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get()
                )

        self._loop.stop()

    @staticmethod
    async def _handle_chunk(
        transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to an active transfer based on its ID. If the
        transfer indicates that it is complete, the provided completion callback
        is invoked.
        """

        chunk = Chunk.from_message(message)

        # Find a transfer for the chunk in the list of active transfers.
        # Legacy chunks carry the resource ID in their session_id field.
        try:
            if chunk.protocol_version is ProtocolVersion.LEGACY:
                transfer = next(
                    t
                    for t in transfers.values()
                    if t.resource_id == chunk.session_id
                )
            else:
                transfer = next(
                    t for t in transfers.values() if t.id == chunk.id()
                )
        except StopIteration:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.id(),
            )
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream.

        Finishes every active read transfer with INTERNAL (skip_callback=True)
        and forgets them.
        """

        for transfer in self._read_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._read_transfers.clear()

        _LOG.error('Read stream shut down: %s', status)

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream.

        Finishes every active write transfer with INTERNAL (skip_callback=True)
        and forgets them.
        """

        for transfer in self._write_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._write_transfers.clear()

        _LOG.error('Write stream shut down: %s', status)

    def _queue_new_transfer(self, transfer: Transfer) -> None:
        """Runs in the event loop thread; queues a transfer to be begun.

        self._new_transfer_queue is read here, inside the loop thread, rather
        than by the caller: _start_event_loop_thread replaces the queue when
        the thread starts, and a reference captured earlier would point at a
        queue the event loop never reads.
        """
        self._new_transfer_queue.put_nowait(transfer)

    def _signal_quit(self) -> None:
        """Runs in the event loop thread; asks the transfer loop to exit."""
        self._quit_event.set()

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.resource_id] = transfer
        self._read_stream.open()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(self._queue_new_transfer, transfer)

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        del self._read_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Read transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.resource_id] = transfer
        self._write_stream.open()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(self._queue_new_transfer, transfer)

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        del self._write_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Write transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )
526
527
class Error(Exception):
    """Exception raised when a transfer fails.

    Stores the ID of the failed transfer resource and the error that occurred.

    Attributes:
      resource_id: ID of the resource whose transfer failed.
      status: Final status of the failed transfer.
    """

    def __init__(self, resource_id: int, status: Status):
        super().__init__(f'Transfer {resource_id} failed with status {status}')
        self.resource_id = resource_id
        self.status = status

    def __reduce__(self):
        # BaseException reconstructs itself as cls(*self.args), but args holds
        # only the formatted message, which does not fit this __init__. Rebuild
        # from the original arguments (plus any extra instance state) so the
        # error survives pickling and copy.copy/copy.deepcopy.
        return (type(self), (self.resource_id, self.status), self.__dict__)
538