• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import struct
19import asyncio
20import logging
21from colors import color
22
23from .. import hci
24
25
26# -----------------------------------------------------------------------------
27# Logging
28# -----------------------------------------------------------------------------
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
30
31# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
# Counting from the first byte AFTER the packet-type byte, the body length is
# a `length-size`-byte field located at `length-offset`, decoded with the
# `struct` format `unpack-type`. The header to read before the body is
# therefore `length-offset + length-size` bytes long.
# HCI multi-byte fields are little-endian, so the 2-byte format carries an
# explicit '<' prefix: a bare 'H' would use the host's native byte order and
# mis-decode lengths on big-endian machines.
HCI_PACKET_INFO = {
    hci.HCI_COMMAND_PACKET:          (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET:         (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET:            (1, 1, 'B')
}
41
42
43# -----------------------------------------------------------------------------
class PacketPump:
    '''
    Pump HCI packets from a reader to a sink
    '''

    def __init__(self, reader, sink):
        self.reader = reader
        self.sink   = sink

    async def run(self):
        '''
        Read packets from `reader.next_packet()`, convert them with
        hci.HCI_Packet.from_bytes() and deliver them to `sink.on_packet()`.

        Errors while reading, parsing or delivering a single packet are logged
        and the pump continues with the next packet. When the reader raises
        asyncio.IncompleteReadError (what StreamReader.readexactly raises at end
        of stream), the pump returns: at EOF readexactly fails again right away
        without awaiting anything, so retrying would spin forever and starve
        the event loop.
        '''
        while True:
            try:
                # Get a packet from the source
                packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())

                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except asyncio.IncompleteReadError as error:
                logger.debug(f'packet pump stopped, end of stream: {error}')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')
63
64
65# -----------------------------------------------------------------------------
class PacketParser:
    '''
    In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
    '''
    # Parser states
    NEED_TYPE   = 0  # waiting for the 1-byte packet type
    NEED_LENGTH = 1  # waiting for the header bytes up to and including the length field
    NEED_BODY   = 2  # waiting for the packet body

    def __init__(self, sink = None):
        self.sink = sink
        # Extra packet types, using the same (length-size, length-offset,
        # unpack-type) tuples as HCI_PACKET_INFO; consulted only for packet
        # types that HCI_PACKET_INFO does not know.
        self.extended_packet_info = {}
        self.reset()

    def reset(self):
        '''Discard any partially accumulated packet and wait for a new packet-type byte.'''
        self.state        = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet       = bytearray()
        self.packet_info  = None

    def feed_data(self, data):
        '''
        Accumulate `data` and pass every completed packet, as bytes, to
        `self.sink.on_packet()`.

        `data` may hold several packets and/or fragments of packets; an
        incomplete packet is kept until later calls supply the missing bytes.
        Exceptions raised by the sink are logged, not propagated.

        Raises:
            ValueError: if a packet-type byte is not a known packet type. The
                parser is reset before raising, so the next call starts again
                at a packet-type byte; the rest of this `data` is dropped.
        '''
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset:data_offset + consumed])
            data_offset       += consumed
            data_left         -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info  = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Reset first: otherwise bytes_needed stays 0 and every
                        # later feed_data() call would silently do nothing.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    # Header = everything up to and including the length field.
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    # +1 skips the packet-type byte at the start of self.packet.
                    body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete (also covers empty bodies)
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.warning(color(f'!!! Exception in on_packet: {error}', 'red'))
                    self.reset()

    def set_packet_sink(self, sink):
        '''Set the object whose on_packet() receives completed packets.'''
        self.sink = sink
119
120
121# -----------------------------------------------------------------------------
class PacketReader:
    '''
    Blocking HCI packet reader.

    `source` must offer a file-like `read(n)`. Each next_packet() call returns
    one whole packet: packet-type byte + header + body.
    '''

    def __init__(self, source):
        self.source = source

    def _read_exact(self, count):
        # A single read() call; anything shorter than requested is treated as
        # a truncated packet.
        chunk = self.source.read(count)
        if len(chunk) != count:
            raise ValueError('packet too short')
        return chunk

    def next_packet(self):
        '''
        Read the next packet from the source.

        Returns:
            The concatenated packet, or None when the source yields no byte
            where a new packet-type byte is expected.

        Raises:
            ValueError: for an unknown packet type or a truncated header/body.
        '''
        type_byte = self.source.read(1)
        if len(type_byte) != 1:
            return None

        info = HCI_PACKET_INFO.get(type_byte[0])
        if info is None:
            raise ValueError(f'invalid packet type {type_byte} found')
        length_size, length_offset, length_format = info[:3]

        # The header runs up to and including the length field.
        header = self._read_exact(length_offset + length_size)
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = self._read_exact(body_length)

        return type_byte + header + body
154
155
156# -----------------------------------------------------------------------------
class AsyncPacketReader:
    '''
    Asynchronous HCI packet reader.

    `source` must offer an awaitable `readexactly(n)` (e.g. an asyncio
    StreamReader). Errors raised by readexactly() propagate unchanged.
    '''

    def __init__(self, source):
        self.source = source

    async def next_packet(self):
        '''
        Read and return the next packet: packet-type byte + header + body.

        Raises:
            ValueError: if the packet-type byte is not in HCI_PACKET_INFO.
        '''
        read = self.source.readexactly
        type_byte = await read(1)

        info = HCI_PACKET_INFO.get(type_byte[0])
        if info is None:
            raise ValueError(f'invalid packet type {type_byte} found')

        # The header runs up to and including the length field.
        length_size, length_offset = info[0], info[1]
        header = await read(length_size + length_offset)

        body_length = struct.unpack_from(info[2], header, length_offset)[0]
        body = await read(body_length)

        return type_byte + header + body
183
184
185# -----------------------------------------------------------------------------
class AsyncPipeSink:
    '''
    Sink that defers delivery: each packet is passed to the wrapped sink's
    on_packet() from a later iteration of the event loop that was running
    when this object was constructed.
    '''
    def __init__(self, sink):
        self.sink = sink
        # Captured once, so instances must be created from inside a running
        # event loop (get_running_loop() raises RuntimeError otherwise).
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet):
        '''Schedule self.sink.on_packet(packet) with loop.call_soon() and return immediately.'''
        deliver = self.sink.on_packet
        self.loop.call_soon(deliver, packet)
196
197
198# -----------------------------------------------------------------------------
class ParserSource:
    """
    Shared base for transport-specific packet sources.

    Holds a PacketParser that subclasses feed with raw bytes, plus a
    `terminated` future awaited by wait_for_termination(). This base class
    never resolves that future itself.
    """

    def __init__(self):
        self.parser = PacketParser()
        running_loop = asyncio.get_running_loop()
        self.terminated = running_loop.create_future()

    def set_packet_sink(self, sink):
        """Route parsed packets to `sink` by handing it to the parser."""
        self.parser.set_packet_sink(sink)

    async def wait_for_termination(self):
        """Wait for `terminated` to be resolved and return its result."""
        outcome = await self.terminated
        return outcome

    def close(self):
        """Do nothing by default; subclasses override this to release resources."""
        return None
216
217
218# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    '''
    asyncio Protocol that feeds every received chunk of bytes to the packet
    parser.
    '''
    def data_received(self, data):
        self.parser.feed_data(data)

    def connection_lost(self, exc):
        '''
        Resolve `terminated` when the connection goes away, so that
        wait_for_termination() returns instead of waiting forever.

        The result is `exc`, which is None when the connection was closed
        normally. This matches PumpedPacketSource, which uses the error that
        stopped it as the result. The done() check avoids InvalidStateError
        if the future was already resolved or cancelled.
        '''
        if not self.terminated.done():
            self.terminated.set_result(exc)
222
223
224# -----------------------------------------------------------------------------
class StreamPacketSink:
    '''
    Sink that writes each packet, unchanged, to a transport object providing
    write() and close().
    '''
    def __init__(self, transport):
        self.transport = transport

    def on_packet(self, packet):
        '''Write the packet bytes to the transport.'''
        write = self.transport.write
        write(packet)

    def close(self):
        '''Close the underlying transport.'''
        transport = self.transport
        transport.close()
234
235
236# -----------------------------------------------------------------------------
class Transport:
    '''
    A (source, sink) pair.

    Usable as an async context manager that closes both on exit, and
    unpackable as `source, sink = transport`.
    '''
    def __init__(self, source, sink):
        self.source = source
        self.sink   = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self):
        '''
        Close the source, then the sink.

        The sink is closed even if closing the source raises, so one failing
        endpoint does not leak the other; the source's exception still
        propagates afterwards.
        '''
        try:
            self.source.close()
        finally:
            self.sink.close()
254
255
256# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    '''
    Source that repeatedly awaits `receive()` for chunks of bytes and feeds
    them to the packet parser from a background task.
    '''
    def __init__(self, receive):
        super().__init__()
        self.receive_function = receive
        self.pump_task        = None

    def start(self):
        '''
        Start the background task that pumps received data into the parser.

        The task ends when it is cancelled (see close()), or on the first
        exception from `receive` or from the parser. In the latter case the
        exception becomes the result of `terminated`.
        '''
        async def pump_packets():
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.CancelledError:
                    logger.debug('source pump task done')
                    break
                except Exception as error:
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning(f'exception while waiting for packet: {error}')
                    # `terminated` may already be cancelled, e.g. by a
                    # timed-out awaiter of wait_for_termination(); calling
                    # set_result on it would raise InvalidStateError here.
                    if not self.terminated.done():
                        self.terminated.set_result(error)
                    break

        self.pump_task = asyncio.get_running_loop().create_task(pump_packets())

    def close(self):
        '''Cancel the pump task, if it was started.'''
        if self.pump_task:
            self.pump_task.cancel()
282
283
284# -----------------------------------------------------------------------------
class PumpedPacketSink:
    '''
    Sink that queues packets and sends them in order from a background task,
    by awaiting `send(packet)` for each one.
    '''
    def __init__(self, send):
        self.send_function = send
        self.packet_queue  = asyncio.Queue()
        self.pump_task     = None

    def on_packet(self, packet):
        '''Queue `packet` for sending; never blocks because the queue is unbounded.'''
        self.packet_queue.put_nowait(packet)

    def start(self):
        '''
        Start the background task that drains the queue through `send`.

        The task ends when it is cancelled (see close()), or after the first
        exception raised by `send`. Packets queued after that are never sent.
        '''
        async def pump_packets():
            while True:
                try:
                    packet = await self.packet_queue.get()
                    await self.send_function(packet)
                except asyncio.CancelledError:
                    logger.debug('sink pump task done')
                    break
                except Exception as error:
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning(f'exception while sending packet: {error}')
                    break

        self.pump_task = asyncio.get_running_loop().create_task(pump_packets())

    def close(self):
        '''Cancel the pump task, if it was started.'''
        if self.pump_task:
            self.pump_task.cancel()
312
313
314# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    '''
    Transport whose source and sink expose start() (e.g. PumpedPacketSource /
    PumpedPacketSink), plus an extra async `close_function` awaited on close.
    '''
    def __init__(self, source, sink, close_function):
        super().__init__(source, sink)
        self.close_function = close_function

    def start(self):
        '''Call start() on the source, then on the sink.'''
        self.source.start()
        self.sink.start()

    async def close(self):
        '''
        Close the source and sink, then await `close_function()`.

        `close_function()` is awaited even if closing the source or sink
        raises, so the underlying resource is always released; the earlier
        exception still propagates afterwards.
        '''
        try:
            await super().close()
        finally:
            await self.close_function()
327