# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import asyncio
import logging
from colors import color

from .. import hci


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
#   length-size:   number of bytes in the length field
#   length-offset: offset of the length field, counted from the first byte
#                  that follows the 1-byte packet type
#   unpack-type:   struct format used to decode the length field.
# HCI multi-byte fields are little-endian, so the 2-byte format carries an
# explicit '<' instead of relying on the host byte order.
HCI_PACKET_INFO = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B')
}


# -----------------------------------------------------------------------------
class PacketPump:
    '''
    Pump HCI packets from a reader to a sink

    The reader must provide an awaitable `next_packet()` returning the raw bytes
    of one packet (for example AsyncPacketReader). Each packet is decoded with
    hci.HCI_Packet.from_bytes and then passed to `sink.on_packet()`.
    '''

    def __init__(self, reader, sink):
        self.reader = reader
        self.sink = sink

    async def run(self):
        '''
        Read, decode and deliver packets until the reader reaches end of stream.

        Other errors (decoding failures, sink errors) are logged and pumping
        continues with the next packet.
        '''
        while True:
            try:
                # Get a packet from the source
                packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())

                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except asyncio.IncompleteReadError:
                # Raised by readexactly() when the stream ends: every later read
                # would fail immediately too, so stop instead of spinning forever.
                logger.debug('packet pump: end of stream')
                break
            except Exception as error:
                logger.warning(f'!!! {error}')


# -----------------------------------------------------------------------------
class PacketParser:
    '''
    In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
    '''
    # Parser states
    NEED_TYPE = 0    # waiting for the 1-byte packet type
    NEED_LENGTH = 1  # waiting for the header bytes, up to and including the length field
    NEED_BODY = 2    # waiting for the packet body

    def __init__(self, sink=None):
        self.sink = sink
        # Additional packet types: type -> (length-size, length-offset, unpack-type).
        # Consulted only when a type is not found in HCI_PACKET_INFO.
        self.extended_packet_info = {}
        self.reset()

    def reset(self):
        '''
        Discard any partially parsed packet and wait for a new packet-type byte
        '''
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data):
        '''
        Consume a chunk of bytes, which may hold partial and/or several packets.

        Each complete packet (type byte + header + body) is passed as `bytes` to
        `sink.on_packet()`. Exceptions raised by the sink are logged, not propagated.

        Raises:
            ValueError: if a packet-type byte is found in neither HCI_PACKET_INFO
                nor extended_packet_info. The rest of `data` is dropped, and the
                parser is reset so that it can accept new data afterwards.
        '''
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset:data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = (
                        HCI_PACKET_INFO.get(packet_type)
                        or self.extended_packet_info.get(packet_type)
                    )
                    if self.packet_info is None:
                        # Reset before raising: otherwise bytes_needed would stay 0
                        # and every later feed_data() call would silently do nothing.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    # self.packet starts with the type byte, hence the extra 1
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete (this also covers an empty body)
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink is not None:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.warning(color(f'!!! Exception in on_packet: {error}', 'red'))
                    self.reset()

    def set_packet_sink(self, sink):
        '''
        Set the object whose on_packet() receives complete packets
        '''
        self.sink = sink


# -----------------------------------------------------------------------------
class PacketReader:
    '''
    Reader that reads HCI packets from a sync source
    '''

    def __init__(self, source):
        # Object with a read(n) method returning up to n bytes
        self.source = source

    def next_packet(self):
        '''
        Read one complete packet and return it as type byte + header + body.

        Returns None if no packet-type byte could be read. Raises ValueError on
        an unknown packet type or when the source ends in the middle of a packet.
        '''
        # Get the packet type
        packet_type = self.source.read(1)
        if len(packet_type) != 1:
            return None

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = self.source.read(header_size)
        if len(header) != header_size:
            raise ValueError('packet too short')

        # Read the body (the header excludes the type byte, so the offset is used as-is)
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = self.source.read(body_length)
        if len(body) != body_length:
            raise ValueError('packet too short')

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPacketReader:
    '''
    Reader that reads HCI packets from an async source
    '''

    def __init__(self, source):
        # Object with an awaitable readexactly(n) method
        # (presumably an asyncio.StreamReader - confirm with callers)
        self.source = source

    async def next_packet(self):
        '''
        Read one complete packet and return it as type byte + header + body.

        Raises ValueError on an unknown packet type. Any exception raised by
        source.readexactly() (e.g. asyncio.IncompleteReadError at end of stream
        for a StreamReader) propagates to the caller.
        '''
        # Get the packet type
        packet_type = await self.source.readexactly(1)

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = await self.source.readexactly(header_size)

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = await self.source.readexactly(body_length)

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPipeSink:
    '''
    Sink that forwards packets asynchronously to another sink
    '''
    def __init__(self, sink):
        self.sink = sink
        # Must be constructed from code running inside an event loop
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet):
        # Deliver on a later loop iteration instead of synchronously
        self.loop.call_soon(self.sink.on_packet, packet)


# -----------------------------------------------------------------------------
class ParserSource:
    """
    Base class designed to be subclassed by transport-specific source classes
    """

    def __init__(self):
        self.parser = PacketParser()
        # Completed by subclasses (see PumpedPacketSource) when the source stops
        self.terminated = asyncio.get_running_loop().create_future()

    def set_packet_sink(self, sink):
        '''
        Set the object that receives the packets parsed from this source
        '''
        self.parser.set_packet_sink(sink)

    async def wait_for_termination(self):
        '''
        Wait until the source terminates and return the termination result
        '''
        return await self.terminated

    def close(self):
        pass


# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    '''
    asyncio Protocol that feeds every received chunk of bytes to the packet parser
    '''
    def data_received(self, data):
        self.parser.feed_data(data)


# -----------------------------------------------------------------------------
class StreamPacketSink:
    '''
    Sink that writes each packet to an asyncio-style transport
    '''
    def __init__(self, transport):
        self.transport = transport

    def on_packet(self, packet):
        self.transport.write(packet)

    def close(self):
        self.transport.close()


# -----------------------------------------------------------------------------
class Transport:
    '''
    A (source, sink) pair, usable as an async context manager and unpackable
    as `source, sink = transport`
    '''
    def __init__(self, source, sink):
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self):
        '''
        Close the source, then the sink
        '''
        self.source.close()
        self.sink.close()


# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    '''
    Source that pulls data by repeatedly awaiting a `receive` coroutine function
    '''
    def __init__(self, receive):
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self):
        '''
        Start a task that awaits receive() in a loop and feeds the data to the parser.

        The task stops when cancelled (see close()) or when receiving/parsing
        raises. In the latter case the exception becomes the result of
        `terminated`, unless that future is already done.
        '''
        async def pump_packets():
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.CancelledError:
                    logger.debug('source pump task done')
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    # The future may already be done, e.g. cancelled because a task
                    # awaiting wait_for_termination() was itself cancelled.
                    if not self.terminated.done():
                        self.terminated.set_result(error)
                    break

        self.pump_task = asyncio.get_running_loop().create_task(pump_packets())

    def close(self):
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedPacketSink:
    '''
    Sink that queues packets and sends them one at a time with a `send`
    coroutine function
    '''
    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet):
        self.packet_queue.put_nowait(packet)

    def start(self):
        '''
        Start a task that sends queued packets in order.

        The task stops when cancelled (see close()) or after the first send error,
        which is logged. Packets queued after that point are never sent.
        '''
        async def pump_packets():
            while True:
                try:
                    packet = await self.packet_queue.get()
                    await self.send_function(packet)
                except asyncio.CancelledError:
                    logger.debug('sink pump task done')
                    break
                except Exception as error:
                    logger.warning(f'exception while sending packet: {error}')
                    break

        self.pump_task = asyncio.get_running_loop().create_task(pump_packets())

    def close(self):
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    '''
    Transport made of a pumped source and a pumped sink, plus an extra async
    close function called after both have been closed
    '''
    def __init__(self, source, sink, close_function):
        super().__init__(source, sink)
        self.close_function = close_function

    def start(self):
        '''
        Start the source and sink pump tasks
        '''
        self.source.start()
        self.sink.start()

    async def close(self):
        await super().close()
        await self.close_function()