• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import contextlib
20import struct
21import asyncio
22import logging
23import io
24from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
25
26from bumble import hci
27from bumble.colors import color
28from bumble.snoop import Snooper
29
30
31# -----------------------------------------------------------------------------
32# Logging
33# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser.
# For each packet type, the info is a tuple:
#   (length-size, length-offset, unpack-type)
# where:
#   length-size:   size, in bytes, of the length field in the packet header
#   length-offset: offset of the length field within the header, counted from
#                  the first byte after the packet-type indicator byte
#   unpack-type:   `struct` format used to decode the length field
# The full header (after the packet-type byte) is length-offset + length-size
# bytes long.
# HCI multi-byte fields are little-endian, so the formats carry an explicit '<'
# prefix instead of relying on the host's native byte order.
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
    hci.HCI_COMMAND_PACKET: (1, 2, '<B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, '<B'),
    hci.HCI_EVENT_PACKET: (1, 1, '<B'),
    hci.HCI_ISO_DATA_PACKET: (2, 2, '<H'),
}
47
48
49# -----------------------------------------------------------------------------
50# Errors
51# -----------------------------------------------------------------------------
class TransportLostError(Exception):
    """Raised to signal that the transport was lost (disconnected)."""
56
57
58# -----------------------------------------------------------------------------
59# Typing Protocols
60# -----------------------------------------------------------------------------
class TransportSink(Protocol):
    """Structural type for any object that accepts packets via `on_packet`."""

    def on_packet(self, packet: bytes) -> None:
        """Accept one packet."""
63
64
class TransportSource(Protocol):
    """Structural type for a packet producer that feeds a registered sink."""

    # Future completed when the source terminates.
    terminated: asyncio.Future[None]

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Register `sink` as the receiver of this source's packets."""
69
70
71# -----------------------------------------------------------------------------
class PacketPump:
    """
    Pump HCI packets from a reader to a sink.
    """

    def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
        self.reader = reader
        self.sink = sink

    async def run(self) -> None:
        """
        Read packets from the reader and deliver them to the sink.

        Runs until the reader reaches end-of-input or a connection error, or
        until the task running it is cancelled.

        A failure affecting a single packet is logged and skipped. This covers
        an error raised while reading it (for example an invalid packet type)
        and an exception raised by the sink.

        End-of-input and connection errors end the pump instead. A stream in
        that state fails every later read immediately, so retrying would only
        spin and flood the log.
        """
        while True:
            try:
                packet = await self.reader.next_packet()
            except EOFError:
                # End of input (asyncio.IncompleteReadError is an EOFError
                # subclass): no more packets can be read.
                return
            except ConnectionError as error:
                logger.warning(f'!!! {error}')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')
                continue

            try:
                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except Exception as error:
                logger.warning(f'!!! {error}')
88
89
90# -----------------------------------------------------------------------------
class PacketParser:
    """
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed.

    Data may be fed in chunks of any size: a partially received packet is kept
    between calls to `feed_data`.
    """

    # pylint: disable=attribute-defined-outside-init

    # Parser states: waiting for the packet-type byte, for the rest of the header
    # (which holds the length field), or for the packet body.
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    sink: Optional[TransportSink]
    # Extra packet types, described like HCI_PACKET_INFO entries. A type present
    # in HCI_PACKET_INFO takes precedence over an entry here.
    extended_packet_info: Dict[int, Tuple[int, int, str]]
    packet_info: Optional[Tuple[int, int, str]] = None

    def __init__(self, sink: Optional[TransportSink] = None) -> None:
        self.sink = sink
        self.extended_packet_info = {}
        self.reset()

    def reset(self) -> None:
        """Discard any partially parsed packet and wait for a new packet type."""
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data: bytes) -> None:
        """
        Consume `data`, passing every packet it completes to the sink.

        Exceptions raised by the sink are logged and do not stop parsing.

        Raises:
            ValueError: if a packet-type byte is not a known packet type. The
                parser is reset before raising, so it can accept new data
                afterwards. The rest of `data` after the invalid byte is not
                processed.
        """
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Reset first: otherwise bytes_needed would stay at 0 and
                        # every later call would silently drop its data.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    # Header = fields before the length field + the length field.
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    assert self.packet_info is not None
                    # +1 skips the packet-type byte at the start of self.packet.
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete (a zero-length body completes
                # as soon as the header has been read).
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.exception(
                                color(f'!!! Exception in on_packet: {error}', 'red')
                            )
                    self.reset()

    def set_packet_sink(self, sink: TransportSink) -> None:
        self.sink = sink
159
160
161# -----------------------------------------------------------------------------
class PacketReader:
    """
    Blocking HCI packet reader.

    Pulls one complete packet at a time out of a synchronous binary stream.
    """

    def __init__(self, source: io.BufferedReader) -> None:
        self.source = source
        # Set once the packet-type read comes back empty.
        self.at_end = False

    def _read_or_fail(self, size: int) -> bytes:
        # Read exactly `size` bytes, treating a short read as a truncated packet.
        chunk = self.source.read(size)
        if len(chunk) != size:
            raise ValueError('packet too short')
        return chunk

    def next_packet(self) -> Optional[bytes]:
        """
        Return the next packet (type byte + header + body), or None at end of input.

        Raises ValueError for an unknown packet type or a truncated packet.
        """
        type_byte = self.source.read(1)
        if len(type_byte) != 1:
            self.at_end = True
            return None

        kind = type_byte[0]
        info = HCI_PACKET_INFO.get(kind)
        if info is None:
            raise ValueError(f'invalid packet type {kind} found')

        length_size, length_offset, length_format = info
        header = self._read_or_fail(length_size + length_offset)
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = self._read_or_fail(body_length)
        return type_byte + header + body
196
197
198# -----------------------------------------------------------------------------
class AsyncPacketReader:
    """
    Asynchronous HCI packet reader.

    Pulls one complete packet at a time out of an `asyncio.StreamReader`. A stream
    that ends early surfaces as the `asyncio.IncompleteReadError` raised by
    `readexactly`.
    """

    def __init__(self, source: asyncio.StreamReader) -> None:
        self.source = source

    async def next_packet(self) -> bytes:
        """Return the next packet: type byte + header + body."""
        type_byte = await self.source.readexactly(1)

        kind = type_byte[0]
        info = HCI_PACKET_INFO.get(kind)
        if info is None:
            raise ValueError(f'invalid packet type {kind} found')

        length_size, length_offset, length_format = info
        header = await self.source.readexactly(length_size + length_offset)
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = await self.source.readexactly(body_length)
        return type_byte + header + body
225
226
227# -----------------------------------------------------------------------------
class AsyncPipeSink:
    """
    Sink that hands each packet to another sink on a later event-loop iteration
    instead of delivering it synchronously.
    """

    def __init__(self, sink: TransportSink) -> None:
        self.sink = sink
        # Requires a running event loop at construction time.
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet: bytes) -> None:
        deliver = self.sink.on_packet
        self.loop.call_soon(deliver, packet)
239
240
241# -----------------------------------------------------------------------------
class ParserSource:
    """
    Base class designed to be subclassed by transport-specific source classes.

    Subclasses feed incoming data to `self.parser`, which passes complete packets
    to the sink registered with `set_packet_sink`. `terminated` is completed when
    the transport is lost.
    """

    terminated: asyncio.Future[None]
    parser: PacketParser

    def __init__(self) -> None:
        self.parser = PacketParser()
        # Requires a running event loop at construction time.
        self.terminated = asyncio.get_running_loop().create_future()

    def set_packet_sink(self, sink: TransportSink) -> None:
        self.parser.set_packet_sink(sink)

    def on_transport_lost(self) -> None:
        """
        Mark the source as terminated and notify the sink, if it supports it.

        Safe to call more than once, or after `terminated` was already completed
        elsewhere (for instance by a subclass). The future is only completed if
        it is not done yet, which avoids asyncio.InvalidStateError. The sink is
        still notified in every case.
        """
        if not self.terminated.done():
            self.terminated.set_result(None)
        sink = self.parser.sink
        if sink and hasattr(sink, 'on_transport_lost'):
            sink.on_transport_lost()

    async def wait_for_termination(self) -> None:
        """
        Convenience method for backward compatibility. Prefer using the `terminated`
        attribute instead.
        """
        return await self.terminated

    def close(self) -> None:
        pass
272
273
274# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    """
    `asyncio.Protocol` adapter that feeds every received chunk of bytes to the
    packet parser.
    """

    def data_received(self, data: bytes) -> None:
        parser = self.parser
        parser.feed_data(data)
278
279
280# -----------------------------------------------------------------------------
class StreamPacketSink:
    """Sink that writes each packet to an asyncio write transport."""

    def __init__(self, transport: asyncio.WriteTransport) -> None:
        self.transport = transport

    def on_packet(self, packet: bytes) -> None:
        """Write `packet` to the transport unchanged."""
        transport = self.transport
        transport.write(packet)

    def close(self) -> None:
        """Close the underlying transport."""
        transport = self.transport
        transport.close()
290
291
292# -----------------------------------------------------------------------------
class Transport:
    """
    Base class for all transports.

    A Transport represents a source and a sink together.
    An instance must be closed by calling close() when no longer used. Instances
    implement the asynchronous context manager protocol, so they may be used in an
    `async with` statement, which closes them on exit.
    An instance is iterable. The iterator yields, in order, its source and sink, so
    that it may be used with a convenient call syntax like:

    async with create_transport() as (source, sink):
        ...
    """

    def __init__(self, source: TransportSource, sink: TransportSink) -> None:
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self) -> None:
        """
        Close the source and the sink, for whichever of them has a `close` method.

        The sink is closed even if closing the source raises; in that case the
        source's exception is propagated afterwards.
        """
        try:
            if hasattr(self.source, 'close'):
                self.source.close()
        finally:
            if hasattr(self.sink, 'close'):
                self.sink.close()
326
327
328# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    """
    Source whose background task repeatedly awaits `receive()` and feeds each
    result to the packet parser.
    """

    pump_task: Optional[asyncio.Task[None]]

    def __init__(self, receive) -> None:
        """
        Args:
            receive: async callable, awaited with no arguments, that returns the
                next chunk of data to feed to the parser.
        """
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self) -> None:
        """Start the background task that pumps received data into the parser."""

        async def pump_packets() -> None:
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.CancelledError:
                    logger.debug('source pump task done')
                    # `terminated` may already be done (e.g. on_transport_lost ran
                    # first); completing it again would raise InvalidStateError.
                    if not self.terminated.done():
                        self.terminated.set_result(None)
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    if not self.terminated.done():
                        self.terminated.set_exception(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self) -> None:
        """Cancel the pump task, if it was started."""
        if self.pump_task:
            self.pump_task.cancel()
357
358
359# -----------------------------------------------------------------------------
class PumpedPacketSink:
    """
    Sink that queues packets and sends them in order from a background task,
    using the `send` async callable.
    """

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet: bytes) -> None:
        """Queue `packet` for sending (the queue is unbounded)."""
        queue = self.packet_queue
        queue.put_nowait(packet)

    def start(self):
        """Start the background task that drains the queue."""

        async def drain_queue():
            while True:
                try:
                    outgoing = await self.packet_queue.get()
                    await self.send_function(outgoing)
                except asyncio.CancelledError:
                    logger.debug('sink pump task done')
                    break
                except Exception as error:
                    # A failed send stops the pump; later packets stay queued.
                    logger.warning(f'exception while sending packet: {error}')
                    break

        self.pump_task = asyncio.create_task(drain_queue())

    def close(self):
        """Cancel the pump task, if it was started."""
        task = self.pump_task
        if task:
            task.cancel()
387
388
389# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    """Transport whose source and sink are each driven by a background pump task."""

    source: PumpedPacketSource
    sink: PumpedPacketSink

    def __init__(
        self,
        source: PumpedPacketSource,
        sink: PumpedPacketSink,
    ) -> None:
        super().__init__(source, sink)

    def start(self) -> None:
        """Start the source pump, then the sink pump."""
        for endpoint in (self.source, self.sink):
            endpoint.start()
404
405
406# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
    """Transport wrapper that snoops on packets to/from a wrapped transport."""

    @staticmethod
    def create_with(
        transport: Transport, snooper: ContextManager[Snooper]
    ) -> SnoopingTransport:
        """
        Create an instance given a snooper that works as a context manager.

        The returned instance will exit the snooper context when it is closed.
        If creating the instance fails, the snooper context is exited right away.
        """
        with contextlib.ExitStack() as exit_stack:
            entered_snooper = exit_stack.enter_context(snooper)
            # Construct before detaching the exit callbacks: if construction
            # raises, the ExitStack still owns the snooper and exits it.
            snooping_transport = SnoopingTransport(transport, entered_snooper)
            snooping_transport.close_snooper = exit_stack.pop_all().close
            return snooping_transport
        raise RuntimeError('unexpected code path')  # Satisfy the type checker

    class Source:
        """Wraps a source; snoops controller-to-host packets before forwarding."""

        sink: TransportSink

        @property
        def metadata(self) -> dict[str, Any]:
            return getattr(self.source, 'metadata', {})

        def __init__(self, source: TransportSource, snooper: Snooper):
            self.source = source
            self.snooper = snooper
            self.terminated = source.terminated

        def set_packet_sink(self, sink: TransportSink) -> None:
            self.sink = sink
            # Interpose this object between the wrapped source and `sink`.
            self.source.set_packet_sink(self)

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
            if self.sink:
                self.sink.on_packet(packet)

        def on_transport_lost(self) -> None:
            # Sources such as ParserSource notify their sink of a lost transport
            # only if the sink has this method; forward the notification so the
            # snooping layer does not swallow it.
            sink = getattr(self, 'sink', None)
            if sink and hasattr(sink, 'on_transport_lost'):
                sink.on_transport_lost()

    class Sink:
        """Wraps a sink; snoops host-to-controller packets before forwarding."""

        def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
            self.sink = sink
            self.snooper = snooper

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
            if self.sink:
                self.sink.on_packet(packet)

    def __init__(
        self,
        transport: Transport,
        snooper: Snooper,
        close_snooper=None,
    ) -> None:
        """
        Args:
            transport: the transport to wrap.
            snooper: receives a copy of every packet, in both directions.
            close_snooper: optional callable invoked when this transport is closed.
        """
        super().__init__(
            self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
        )
        self.transport = transport
        self.close_snooper = close_snooper

    async def close(self):
        """Close the wrapped transport, then the snooper (even if the first fails)."""
        try:
            await self.transport.close()
        finally:
            if self.close_snooper:
                self.close_snooper()
472