• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import contextlib
20import struct
21import asyncio
22import logging
23from typing import ContextManager
24
25from .. import hci
26from ..colors import color
27from ..snoop import Snooper
28
29
30# -----------------------------------------------------------------------------
31# Logging
32# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser.
# For each packet type, the info is a tuple:
#   (length-size, length-offset, unpack-type)
# where:
#   length-size   is the size, in bytes, of the field holding the body length,
#   length-offset is the offset of that field, counted from the first byte after
#                 the packet-type byte (so the header is length-offset +
#                 length-size bytes long),
#   unpack-type   is the `struct` format used to decode that field.
# HCI multi-byte fields are little-endian, so the formats pin the byte order ('<')
# instead of relying on the host's native order.
HCI_PACKET_INFO = {
    hci.HCI_COMMAND_PACKET: (1, 2, '<B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, '<B'),
    hci.HCI_EVENT_PACKET: (1, 1, '<B'),
}
45
46
47# -----------------------------------------------------------------------------
class PacketPump:
    '''
    Pump HCI packets from a reader to a sink

    `reader.next_packet()` is awaited for the raw bytes of one packet, which are
    converted with hci.HCI_Packet.from_bytes() and passed to `sink.on_packet()`.
    '''

    def __init__(self, reader, sink):
        self.reader = reader
        self.sink = sink

    async def run(self):
        '''
        Deliver packets until the reader reports end of stream.

        Errors while reading, parsing or delivering a single packet are logged and
        the loop continues. An EOFError from the reader (asyncio.IncompleteReadError
        derives from it) ends the loop: an exhausted asyncio stream fails every later
        read immediately, so retrying would spin the event loop forever.
        '''
        while True:
            try:
                # Get a packet from the source
                data = await self.reader.next_packet()
            except EOFError as error:
                logger.debug(f'packet source exhausted: {error}')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')
                continue

            try:
                # Parse the packet and deliver it to the sink
                self.sink.on_packet(hci.HCI_Packet.from_bytes(data))
            except Exception as error:
                logger.warning(f'!!! {error}')
67
68
69# -----------------------------------------------------------------------------
class PacketParser:
    '''
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed

    Data may be fed in arbitrary chunks. Each complete packet (type byte, header and
    body) is passed to `sink.on_packet()` as `bytes`. Packet types that are not in
    HCI_PACKET_INFO can be supported by adding entries to `extended_packet_info`.
    '''

    # pylint: disable=attribute-defined-outside-init

    # Parser states
    NEED_TYPE = 0  # waiting for the packet-type byte
    NEED_LENGTH = 1  # waiting for the header, which ends with the length field
    NEED_BODY = 2  # waiting for the packet body

    def __init__(self, sink=None):
        self.sink = sink
        # Extra {packet-type: (length-size, length-offset, unpack-type)} entries,
        # consulted only when HCI_PACKET_INFO has no entry for the type
        self.extended_packet_info = {}
        self.reset()

    def reset(self):
        '''Discard any partial packet and wait for a new packet-type byte.'''
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data):
        '''
        Consume `data`, emitting every packet that it completes.

        Exceptions raised by the sink are logged and otherwise ignored.

        Raises ValueError when a packet-type byte is not recognized. The parser is
        reset before raising so that later calls can parse again; the remainder of
        `data` is not consumed.
        '''
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Without this reset, bytes_needed would stay at 0 and every
                        # later call would silently drop all of its data.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    # The header runs up to and including the length field
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    # The length field sits after the type byte and length-offset
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete (an empty body reaches this point
                # directly from the NEED_LENGTH branch above)
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.warning(
                                color(f'!!! Exception in on_packet: {error}', 'red')
                            )
                    self.reset()

    def set_packet_sink(self, sink):
        self.sink = sink
133
134
135# -----------------------------------------------------------------------------
class PacketReader:
    '''
    Reader that reads HCI packets from a sync source

    `source` only needs a `read(size)` method. Short reads are retried until the
    requested number of bytes has arrived, so raw (unbuffered) streams that may return
    fewer bytes than requested are handled.
    '''

    def __init__(self, source):
        self.source = source

    def _read_full(self, size):
        '''
        Read `size` bytes, calling `source.read()` as many times as needed.

        Fewer than `size` bytes are returned only if a read returns no data (an empty
        result, as at end of stream) before `size` bytes have been collected.
        '''
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.source.read(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def next_packet(self):
        '''
        Return the raw bytes of the next packet (type + header + body), or None if
        the source has no more data.

        Raises ValueError for an unknown packet type or a truncated packet.
        '''
        # Get the packet type
        packet_type = self._read_full(1)
        if len(packet_type) != 1:
            return None

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = self._read_full(header_size)
        if len(header) != header_size:
            raise ValueError('packet too short')

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = self._read_full(body_length)
        if len(body) != body_length:
            raise ValueError('packet too short')

        return packet_type + header + body
168
169
170# -----------------------------------------------------------------------------
class AsyncPacketReader:
    '''
    Pull one HCI packet at a time out of an asynchronous byte stream.

    `source` must offer an awaitable `readexactly(n)` (such as asyncio.StreamReader);
    whatever that method raises on a short read propagates to the caller.
    '''

    def __init__(self, source):
        self.source = source

    async def next_packet(self):
        '''
        Return the raw bytes of the next packet: type byte, header, then body.

        Raises ValueError when the type byte is not a known HCI packet type.
        '''
        packet_type = await self.source.readexactly(1)

        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')
        length_size, length_offset, length_format = packet_info

        # The header runs up to and including the body-length field
        header = await self.source.readexactly(length_offset + length_size)
        body_length = struct.unpack_from(length_format, header, length_offset)[0]

        body = await self.source.readexactly(body_length)
        return packet_type + header + body
197
198
199# -----------------------------------------------------------------------------
class AsyncPipeSink:
    '''
    Sink that hands each packet to another sink on a later turn of the event loop,
    so the caller's `on_packet()` returns before the wrapped sink sees the packet.

    Must be created while an asyncio event loop is running.
    '''

    def __init__(self, sink):
        self.sink = sink
        # The running loop is captured once, at construction time
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet):
        forward = self.sink.on_packet
        self.loop.call_soon(forward, packet)
211
212
213# -----------------------------------------------------------------------------
class ParserSource:
    """
    Common base for transport-specific sources that turn a raw byte stream into HCI
    packets through a PacketParser.

    Subclasses feed bytes into `self.parser`, and may resolve the `self.terminated`
    future when their transport ends. Must be created while an asyncio event loop is
    running.
    """

    def __init__(self):
        self.parser = PacketParser()
        running_loop = asyncio.get_running_loop()
        self.terminated = running_loop.create_future()

    def set_packet_sink(self, sink):
        # Parsed packets go straight from the parser to the sink
        self.parser.set_packet_sink(sink)

    async def wait_for_termination(self):
        outcome = await self.terminated
        return outcome

    def close(self):
        """Nothing to release at this level; subclasses override as needed."""
231
232
233# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    '''
    asyncio Protocol that feeds every received byte into the packet parser
    '''

    def data_received(self, data):
        self.parser.feed_data(data)

    def connection_lost(self, exc):
        # Without this, `terminated` would never resolve for stream transports and
        # wait_for_termination() would block forever. `exc` is None when the
        # connection ended normally (EOF or a local close).
        logger.debug(f'connection lost: {exc}')
        if not self.terminated.done():
            self.terminated.set_result(exc)
237
238
239# -----------------------------------------------------------------------------
class StreamPacketSink:
    '''
    Sink that writes each packet to a transport object (anything providing
    `write()` and `close()`).
    '''

    def __init__(self, transport):
        self.transport = transport

    def on_packet(self, packet):
        write = self.transport.write
        write(packet)

    def close(self):
        closer = self.transport.close
        closer()
249
250
251# -----------------------------------------------------------------------------
class Transport:
    """
    Base class for all transports.

    A Transport represents a source and a sink together.
    An instance must be closed by calling close() when no longer used. Instances
    implement the asynchronous context manager protocol so that they may be used in
    an `async with` statement.
    An instance is iterable. The iterator yields, in order, its source and sink, so
    that it may be used with a convenient call syntax like:

    async with create_transport() as (source, sink):
        ...
    """

    def __init__(self, source, sink):
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self) -> None:
        """
        Close the source, then the sink.

        The sink is closed even if closing the source raises; that exception is then
        propagated.
        """
        try:
            self.source.close()
        finally:
            self.sink.close()
283
284
285# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    '''
    Source whose background task (created by start()) repeatedly awaits the
    `receive` coroutine function and feeds the bytes it returns into the parser.
    '''

    def __init__(self, receive):
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self):
        async def pump_packets():
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.exceptions.CancelledError:
                    logger.debug('source pump task done')
                    # Cancelled (e.g. by close()): resolve `terminated` so that
                    # wait_for_termination() does not wait forever.
                    if not self.terminated.done():
                        self.terminated.set_result(None)
                    break
                except Exception as error:
                    # Receive or parse failure: report the error through
                    # `terminated` and stop pumping.
                    logger.warning(f'exception while waiting for packet: {error}')
                    if not self.terminated.done():
                        self.terminated.set_result(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        # Cancelling a running pump task also resolves `terminated` (see above)
        if self.pump_task:
            self.pump_task.cancel()
311
312
313# -----------------------------------------------------------------------------
class PumpedPacketSink:
    '''
    Sink that queues packets and sends them, in order, from a background task that
    awaits the `send` coroutine function for each one. Sending stops for good after
    the first send error or after close().
    '''

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet):
        # Never blocks: the queue is created without a size limit
        self.packet_queue.put_nowait(packet)

    def start(self):
        async def pump_packets():
            try:
                while True:
                    next_packet = await self.packet_queue.get()
                    await self.send_function(next_packet)
            except asyncio.exceptions.CancelledError:
                logger.debug('sink pump task done')
            except Exception as error:
                logger.warning(f'exception while sending packet: {error}')

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        if self.pump_task is not None:
            self.pump_task.cancel()
341
342
343# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    '''
    Transport made of a pumped source and sink, plus a `close_function` coroutine
    function that close() awaits after the source and sink have been closed.
    '''

    def __init__(self, source, sink, close_function):
        super().__init__(source, sink)
        self.close_function = close_function

    def start(self):
        # Start pumping in both directions, source first
        for endpoint in (self.source, self.sink):
            endpoint.start()

    async def close(self):
        await super().close()
        closer = self.close_function
        await closer()
356
357
358# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
    """Transport wrapper that snoops on packets to/from a wrapped transport."""

    @staticmethod
    def create_with(
        transport: Transport, snooper: ContextManager[Snooper]
    ) -> SnoopingTransport:
        """
        Create an instance given a snooper that works as a context manager.

        The returned instance will exit the snooper context when it is closed. If
        creating the instance fails, the snooper context is exited immediately.
        """
        with contextlib.ExitStack() as exit_stack:
            snooping_transport = SnoopingTransport(
                transport, exit_stack.enter_context(snooper)
            )
            # Take over the snooper's exit callback only once construction has
            # succeeded; until then the `with` block remains responsible for it.
            snooping_transport.close_snooper = exit_stack.pop_all().close
            return snooping_transport
        raise RuntimeError('unexpected code path')  # Satisfy the type checker

    class Source:
        """
        Source wrapper that snoops controller-to-host packets before forwarding
        them to the packet sink.
        """

        def __init__(self, source, snooper):
            self.source = source
            self.snooper = snooper
            self.sink = None

        def set_packet_sink(self, sink):
            # Interpose this wrapper between the wrapped source and `sink`
            self.sink = sink
            self.source.set_packet_sink(self)

        def on_packet(self, packet):
            self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
            if self.sink:
                self.sink.on_packet(packet)

    class Sink:
        """
        Sink wrapper that snoops host-to-controller packets before forwarding them
        to the wrapped sink.
        """

        def __init__(self, sink, snooper):
            self.sink = sink
            self.snooper = snooper

        def on_packet(self, packet):
            self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
            if self.sink:
                self.sink.on_packet(packet)

    def __init__(self, transport, snooper, close_snooper=None):
        super().__init__(
            self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
        )
        self.transport = transport
        # Optional callable invoked by close() after the wrapped transport is closed
        self.close_snooper = close_snooper

    async def close(self):
        """
        Close the wrapped transport, then call `close_snooper` if one was given.

        `close_snooper` is called even if closing the wrapped transport raises.
        """
        try:
            await self.transport.close()
        finally:
            if self.close_snooper:
                self.close_snooper()
413