• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Client for the pw_transfer service, which transmits data over pw_rpc."""
15
16import asyncio
17import ctypes
18import logging
19import threading
20from typing import Any, Callable
21
22from pw_rpc.callback_client import BidirectionalStreamingCall
23from pw_status import Status
24
25from pw_transfer.transfer import (
26    ProgressCallback,
27    ProtocolVersion,
28    ReadTransfer,
29    Transfer,
30    WriteTransfer,
31)
32from pw_transfer.chunk import Chunk
33
34try:
35    from pw_transfer import transfer_pb2
36except ImportError:
37    # For the bazel build, which puts generated protos in a different location.
38    from pigweed.pw_transfer import transfer_pb2  # type: ignore
39
# Logger for this module, named after its package.
_LOG = logging.getLogger(__package__)

# Active transfers keyed by resource ID (the shape of Manager._read_transfers
# and Manager._write_transfers).
_TransferDict = dict[int, Transfer]
43
44
class _TransferStream:
    """A reusable pw_rpc stream shared by all transfers of one type.

    Wraps one transfer RPC method. Every message received on the stream is
    forwarded to ``chunk_handler``, and every stream error to
    ``error_handler``.

    If the server reports FAILED_PRECONDITION (the stream packet was not
    recognized because the stream is not open on its side), the stream is
    re-opened automatically. This happens at most ``max_reopen_attempts`` times
    in a row; the counter resets whenever a chunk is received.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[transfer_pb2.Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        """Creates a stream wrapper; no RPC call is started until open().

        Args:
          method: pw_rpc method client whose ``invoke`` starts the stream.
          chunk_handler: Called with each raw chunk message received.
          error_handler: Called with the status of every stream error,
              including errors after which the stream is re-opened.
          max_reopen_attempts: Maximum number of consecutive automatic
              re-opens after FAILED_PRECONDITION before the stream is shut
              down instead.
        """
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: BidirectionalStreamingCall | None = None
        # Automatic re-opens performed since the last received chunk.
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True if an RPC call is currently active on this stream."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Starts the RPC call if none is active, or unconditionally if forced.

        Args:
          force: If True, start a new call even if one is already active.
        """
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels the active call, if any, and restores the re-open budget."""
        if self._call is not None:
            self._call.cancel()
            self._call = None
        self._reopen_attempts = 0

    def send(self, chunk: Chunk) -> None:
        """Sends ``chunk`` on the active call; open() must have been called."""
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: transfer_pb2.Chunk) -> None:
        # Receiving data proves the stream works, so restore the full budget
        # of automatic re-opens.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        if (
            rpc_status is Status.FAILED_PRECONDITION
            and self._reopen_attempts < self._max_reopen_attempts
        ):
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream automatically. The attempts are bounded so a server that
            # keeps rejecting the stream cannot cause an endless re-open cycle.
            self._reopen_attempts += 1
            self.open(force=True)
        else:
            # Other errors (or exhausted re-open attempts) are unrecoverable;
            # clear the stream so the next open() starts a fresh call.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None
            self._reopen_attempts = 0

        self._error_handler(rpc_status)
95
96
class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with an active TransferService client over an
    RPC channel. Only one instance of this class should exist for a configured
    RPC TransferService -- the Manager supports multiple simultaneous
    transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled. The constructor returns once that
    thread has finished setting up its communication channels.
    """

    def __init__(
        self,
        rpc_transfer_service,
        *,
        default_response_timeout_s: float = 2.0,
        initial_response_timeout_s: float = 4.0,
        max_retries: int = 3,
        max_lifetime_retries: int = 1500,
        max_chunk_size_bytes: int = 1024,
        default_protocol_version=ProtocolVersion.VERSION_TWO,
    ):
        """Initializes a Manager on top of a TransferService.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry a single packet after a timeout
          max_lifetime_retries: Cumulative maximum number of times to retry over
              the course of the transfer before giving up.
          max_chunk_size_bytes: In a read transfer, the maximum size of data the
              server should send within a single packet.
          default_protocol_version: Version of the pw_transfer protocol to use.
              Defaults to the latest, but can be set to legacy for projects
              which use legacy devices.
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries
        self.max_lifetime_retries = max_lifetime_retries
        self._max_chunk_size_bytes = max_chunk_size_bytes
        self._default_protocol_version = default_protocol_version

        # Ongoing transfers in the service by resource ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}
        self._next_session_id = ctypes.c_uint32(1)

        self._loop = asyncio.new_event_loop()
        # Set the event loop for the current thread.
        # NOTE(review): this replaces the *calling* thread's current event loop
        # as a side effect; kept because existing callers may rely on it.
        asyncio.set_event_loop(self._loop)

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread. These are placeholders: the
        # transfer thread replaces them (see _start_event_loop_thread) before it
        # starts processing events.
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        # Set by the transfer thread once it has replaced the placeholders
        # above. Waiting on it before returning guarantees no caller enqueues
        # work onto (or sets) an object the event loop never looks at.
        self._thread_ready = threading.Event()
        # Strong reference to the main event loop task; asyncio itself only
        # keeps weak references to tasks.
        self._event_loop_task: asyncio.Task | None = None

        self._thread = threading.Thread(
            target=self._start_event_loop_thread, daemon=True
        )

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type.
        self._read_stream = _TransferStream(
            self._service.Read,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._read_chunk_queue.put_nowait, chunk
            ),
            self._on_read_error,
        )
        self._write_stream = _TransferStream(
            self._service.Write,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._write_chunk_queue.put_nowait, chunk
            ),
            self._on_write_error,
        )

        self._thread.start()
        self._thread_ready.wait()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit. __init__ may have failed before the thread was
        # created, and a thread cannot join itself if the last reference is
        # dropped on the event loop thread.
        thread = getattr(self, '_thread', None)
        if thread is None or not thread.is_alive():
            return

        self._loop.call_soon_threadsafe(self._quit_event.set)
        if thread is not threading.current_thread():
            thread.join()

    def read(
        self,
        resource_id: int,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> bytes:
        """Receives ("downloads") data from the server.

        Blocks the calling thread until the transfer finishes.

        Args:
          resource_id: ID of the resource from which to read.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start reading from. Must be
              supported by the transfer handler. All transfers support starting
              from 0, the default. Returned bytes will not have any padding
              related to this initial offset. No seeking is done in the transfer
              operation on the client side.

        Returns:
          The data received from the server.

        Raises:
          ValueError: a read of this resource is already in progress, or a
              nonzero initial_offset was requested with the legacy protocol.
          Error: the transfer failed to complete
        """

        if resource_id in self._read_transfers:
            raise ValueError(
                f'Read transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if protocol_version == ProtocolVersion.LEGACY and initial_offset != 0:
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by resource ID; newer protocol
        # versions use a distinct session ID per transfer.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = ReadTransfer(
            session_id,
            resource_id,
            self._read_stream.send,
            self._end_read_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            max_chunk_size=self._max_chunk_size_bytes,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_read_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

        return transfer.data

    def write(
        self,
        resource_id: int,
        data: bytes | str,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> None:
        """Transmits ("uploads") data to the server.

        Blocks the calling thread until the transfer finishes.

        Args:
          resource_id: ID of the resource to which to write.
          data: Data to send to the server. A str is encoded with str.encode()
              (UTF-8) before sending.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start writing to. Must be supported
              by the transfer handler. All transfers support starting from 0,
              the default. data arg should start with the data you want to see
              starting at this initial offset on the server. No seeking is done
              in the transfer operation on the client side.

        Raises:
          ValueError: a write to this resource is already in progress, or a
              nonzero initial_offset was requested with a protocol other than
              VERSION_TWO.
          Error: the transfer failed to complete
        """

        if isinstance(data, str):
            data = data.encode()

        if resource_id in self._write_transfers:
            raise ValueError(
                f'Write transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if (
            protocol_version != ProtocolVersion.VERSION_TWO
            and initial_offset != 0
        ):
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by resource ID; newer protocol
        # versions use a distinct session ID per transfer.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = WriteTransfer(
            session_id,
            resource_id,
            data,
            self._write_stream.send,
            self._end_write_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_write_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

    def assign_session_id(self) -> int:
        """Returns a new session ID for a non-legacy transfer.

        IDs increase by one and wrap around within the unsigned 32-bit range,
        never yielding 0.
        """
        new_id = self._next_session_id.value

        self._next_session_id = ctypes.c_uint32(self._next_session_id.value + 1)
        if self._next_session_id.value == 0:
            self._next_session_id = ctypes.c_uint32(1)

        return new_id

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        try:
            # Recreate the async communication channels in the context of the
            # running event loop.
            self._new_transfer_queue = asyncio.Queue()
            self._read_chunk_queue = asyncio.Queue()
            self._write_chunk_queue = asyncio.Queue()
            self._quit_event = asyncio.Event()

            self._event_loop_task = self._loop.create_task(
                self._transfer_event_loop()
            )
        finally:
            # Always release __init__, even if setup failed, so construction
            # can never block forever.
            self._thread_ready.set()

        self._loop.run_forever()

    async def _transfer_event_loop(self):
        """Main async event loop.

        Waits on the quit event, the new-transfer queue and both chunk queues
        at once, handling whichever completes, until the quit event is set.
        """
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        # Strong references to in-flight chunk handlers. asyncio only keeps
        # weak references to tasks, so unreferenced tasks may be garbage
        # collected before they finish.
        chunk_tasks: set[asyncio.Task] = set()

        def spawn_chunk_handler(transfers: _TransferDict, message) -> None:
            task = self._loop.create_task(
                self._handle_chunk(transfers, message)
            )
            chunk_tasks.add(task)
            task.add_done_callback(chunk_tasks.discard)

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exit_thread in done:
                break

            if new_transfer in done:
                await new_transfer.result().begin()
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get()
                )

            if read_chunk in done:
                spawn_chunk_handler(self._read_transfers, read_chunk.result())
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get()
                )

            if write_chunk in done:
                spawn_chunk_handler(self._write_transfers, write_chunk.result())
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get()
                )

        # Cancel the outstanding waits and let the cancellations complete, so
        # no task is destroyed while still pending once the loop stops.
        pending = [
            task
            for task in (exit_thread, new_transfer, read_chunk, write_chunk)
            if not task.done()
        ]
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        self._loop.stop()

    @staticmethod
    async def _handle_chunk(
        transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to the matching active transfer: legacy chunks
        are matched on resource ID, newer ones on the transfer's ID. Chunks
        that match no active transfer are logged and dropped.
        """

        chunk = Chunk.from_message(message)

        # Find a transfer for the chunk in the list of active transfers.
        try:
            if chunk.protocol_version is ProtocolVersion.LEGACY:
                transfer = next(
                    t
                    for t in transfers.values()
                    if t.resource_id == chunk.session_id
                )
            else:
                transfer = next(
                    t for t in transfers.values() if t.id == chunk.id()
                )
        except StopIteration:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.id(),
            )
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream.

        Fails every active read transfer with INTERNAL and forgets them.
        """
        # Detach the transfers before finishing them. finish() may wake a
        # caller blocked in read(), which can immediately register a new
        # transfer; clearing afterwards would silently drop that new transfer.
        # Iterating a snapshot also avoids "dict changed size" errors, since
        # this callback comes from the RPC client rather than the event loop
        # thread.
        transfers = list(self._read_transfers.values())
        self._read_transfers.clear()
        for transfer in transfers:
            transfer.finish(Status.INTERNAL, skip_callback=True)

        _LOG.error('Read stream shut down: %s', status)

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream.

        Fails every active write transfer with INTERNAL and forgets them.
        """
        # Detach first, for the same reasons as in _on_read_error.
        transfers = list(self._write_transfers.values())
        self._write_transfers.clear()
        for transfer in transfers:
            transfer.finish(Status.INTERNAL, skip_callback=True)

        _LOG.error('Write stream shut down: %s', status)

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.resource_id] = transfer
        self._read_stream.open()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(
            self._new_transfer_queue.put_nowait, transfer
        )

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        # The entry may already be gone (cleared by a stream error) or may now
        # belong to a newer transfer of the same resource; only remove our own.
        if self._read_transfers.get(transfer.resource_id) is transfer:
            del self._read_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Read transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.resource_id] = transfer
        self._write_stream.open()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(
            self._new_transfer_queue.put_nowait, transfer
        )

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        # The entry may already be gone (cleared by a stream error) or may now
        # belong to a newer transfer of the same resource; only remove our own.
        if self._write_transfers.get(transfer.resource_id) is transfer:
            del self._write_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Write transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )
521
522
class Error(Exception):
    """Raised by Manager.read() and Manager.write() when a transfer fails.

    Attributes:
      resource_id: ID of the resource whose transfer did not complete.
      status: The non-OK status the transfer finished with.
    """

    def __init__(self, resource_id: int, status: Status):
        self.resource_id = resource_id
        self.status = status
        message = f'Transfer {resource_id} failed with status {status}'
        super().__init__(message)
533