1# Copyright 2021-2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Imports 17# ----------------------------------------------------------------------------- 18from __future__ import annotations 19import contextlib 20import struct 21import asyncio 22import logging 23from typing import ContextManager 24 25from .. import hci 26from ..colors import color 27from ..snoop import Snooper 28 29 30# ----------------------------------------------------------------------------- 31# Logging 32# ----------------------------------------------------------------------------- 33logger = logging.getLogger(__name__) 34 35# ----------------------------------------------------------------------------- 36# Information needed to parse HCI packets with a generic parser: 37# For each packet type, the info represents: 38# (length-size, length-offset, unpack-type) 39HCI_PACKET_INFO = { 40 hci.HCI_COMMAND_PACKET: (1, 2, 'B'), 41 hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), 42 hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), 43 hci.HCI_EVENT_PACKET: (1, 1, 'B'), 44} 45 46 47# ----------------------------------------------------------------------------- 48class PacketPump: 49 ''' 50 Pump HCI packets from a reader to a sink 51 ''' 52 53 def __init__(self, reader, sink): 54 self.reader = reader 55 self.sink = sink 56 57 async def run(self): 58 while True: 59 try: 60 # Get a packet from the source 61 packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet()) 62 63 # Deliver the packet to the sink 64 self.sink.on_packet(packet) 65 except Exception as error: 66 logger.warning(f'!!! {error}') 67 68 69# ----------------------------------------------------------------------------- 70class PacketParser: 71 ''' 72 In-line parser that accepts data and emits 'on_packet' when a full packet has been 73 parsed 74 ''' 75 76 # pylint: disable=attribute-defined-outside-init 77 78 NEED_TYPE = 0 79 NEED_LENGTH = 1 80 NEED_BODY = 2 81 82 def __init__(self, sink=None): 83 self.sink = sink 84 self.extended_packet_info = {} 85 self.reset() 86 87 def reset(self): 88 self.state = PacketParser.NEED_TYPE 89 self.bytes_needed = 1 90 self.packet = bytearray() 91 self.packet_info = None 92 93 def feed_data(self, data): 94 data_offset = 0 95 data_left = len(data) 96 while data_left and self.bytes_needed: 97 consumed = min(self.bytes_needed, data_left) 98 self.packet.extend(data[data_offset : data_offset + consumed]) 99 data_offset += consumed 100 data_left -= consumed 101 self.bytes_needed -= consumed 102 103 if self.bytes_needed == 0: 104 if self.state == PacketParser.NEED_TYPE: 105 packet_type = self.packet[0] 106 self.packet_info = HCI_PACKET_INFO.get( 107 packet_type 108 ) or self.extended_packet_info.get(packet_type) 109 if self.packet_info is None: 110 raise ValueError(f'invalid packet type {packet_type}') 111 self.state = PacketParser.NEED_LENGTH 112 self.bytes_needed = self.packet_info[0] + self.packet_info[1] 113 elif self.state == PacketParser.NEED_LENGTH: 114 body_length = struct.unpack_from( 115 self.packet_info[2], self.packet, 1 + self.packet_info[1] 116 )[0] 117 self.bytes_needed = body_length 118 self.state = PacketParser.NEED_BODY 119 120 # Emit a packet if one is complete 121 if self.state == PacketParser.NEED_BODY and not self.bytes_needed: 122 if self.sink: 123 try: 124 self.sink.on_packet(bytes(self.packet)) 125 except Exception as error: 126 logger.warning( 127 color(f'!!! Exception in on_packet: {error}', 'red') 128 ) 129 self.reset() 130 131 def set_packet_sink(self, sink): 132 self.sink = sink 133 134 135# ----------------------------------------------------------------------------- 136class PacketReader: 137 ''' 138 Reader that reads HCI packets from a sync source 139 ''' 140 141 def __init__(self, source): 142 self.source = source 143 144 def next_packet(self): 145 # Get the packet type 146 packet_type = self.source.read(1) 147 if len(packet_type) != 1: 148 return None 149 150 # Get the packet info based on its type 151 packet_info = HCI_PACKET_INFO.get(packet_type[0]) 152 if packet_info is None: 153 raise ValueError(f'invalid packet type {packet_type} found') 154 155 # Read the header (that includes the length) 156 header_size = packet_info[0] + packet_info[1] 157 header = self.source.read(header_size) 158 if len(header) != header_size: 159 raise ValueError('packet too short') 160 161 # Read the body 162 body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0] 163 body = self.source.read(body_length) 164 if len(body) != body_length: 165 raise ValueError('packet too short') 166 167 return packet_type + header + body 168 169 170# ----------------------------------------------------------------------------- 171class AsyncPacketReader: 172 ''' 173 Reader that reads HCI packets from an async source 174 ''' 175 176 def __init__(self, source): 177 self.source = source 178 179 async def next_packet(self): 180 # Get the packet type 181 packet_type = await self.source.readexactly(1) 182 183 # Get the packet info based on its type 184 packet_info = HCI_PACKET_INFO.get(packet_type[0]) 185 if packet_info is None: 186 raise ValueError(f'invalid packet type {packet_type} found') 187 188 # Read the header (that includes the length) 189 header_size = packet_info[0] + packet_info[1] 190 header = await self.source.readexactly(header_size) 191 192 # Read the body 193 body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0] 194 body = await self.source.readexactly(body_length) 195 196 return packet_type + header + body 197 198 199# ----------------------------------------------------------------------------- 200class AsyncPipeSink: 201 ''' 202 Sink that forwards packets asynchronously to another sink 203 ''' 204 205 def __init__(self, sink): 206 self.sink = sink 207 self.loop = asyncio.get_running_loop() 208 209 def on_packet(self, packet): 210 self.loop.call_soon(self.sink.on_packet, packet) 211 212 213# ----------------------------------------------------------------------------- 214class ParserSource: 215 """ 216 Base class designed to be subclassed by transport-specific source classes 217 """ 218 219 def __init__(self): 220 self.parser = PacketParser() 221 self.terminated = asyncio.get_running_loop().create_future() 222 223 def set_packet_sink(self, sink): 224 self.parser.set_packet_sink(sink) 225 226 async def wait_for_termination(self): 227 return await self.terminated 228 229 def close(self): 230 pass 231 232 233# ----------------------------------------------------------------------------- 234class StreamPacketSource(asyncio.Protocol, ParserSource): 235 def data_received(self, data): 236 self.parser.feed_data(data) 237 238 239# ----------------------------------------------------------------------------- 240class StreamPacketSink: 241 def __init__(self, transport): 242 self.transport = transport 243 244 def on_packet(self, packet): 245 self.transport.write(packet) 246 247 def close(self): 248 self.transport.close() 249 250 251# ----------------------------------------------------------------------------- 252class Transport: 253 """ 254 Base class for all transports. 255 256 A Transport represents a source and a sink together. 257 An instance must be closed by calling close() when no longer used. Instances 258 implement the ContextManager protocol so that they may be used in a `async with` 259 statement. 260 An instance is iterable. The iterator yields, in order, its source and sink, so 261 that it may be used with a convenient call syntax like: 262 263 async with create_transport() as (source, sink): 264 ... 265 """ 266 267 def __init__(self, source, sink): 268 self.source = source 269 self.sink = sink 270 271 async def __aenter__(self): 272 return self 273 274 async def __aexit__(self, *args): 275 await self.close() 276 277 def __iter__(self): 278 return iter((self.source, self.sink)) 279 280 async def close(self) -> None: 281 self.source.close() 282 self.sink.close() 283 284 285# ----------------------------------------------------------------------------- 286class PumpedPacketSource(ParserSource): 287 def __init__(self, receive): 288 super().__init__() 289 self.receive_function = receive 290 self.pump_task = None 291 292 def start(self): 293 async def pump_packets(): 294 while True: 295 try: 296 packet = await self.receive_function() 297 self.parser.feed_data(packet) 298 except asyncio.exceptions.CancelledError: 299 logger.debug('source pump task done') 300 break 301 except Exception as error: 302 logger.warning(f'exception while waiting for packet: {error}') 303 self.terminated.set_result(error) 304 break 305 306 self.pump_task = asyncio.create_task(pump_packets()) 307 308 def close(self): 309 if self.pump_task: 310 self.pump_task.cancel() 311 312 313# ----------------------------------------------------------------------------- 314class PumpedPacketSink: 315 def __init__(self, send): 316 self.send_function = send 317 self.packet_queue = asyncio.Queue() 318 self.pump_task = None 319 320 def on_packet(self, packet): 321 self.packet_queue.put_nowait(packet) 322 323 def start(self): 324 async def pump_packets(): 325 while True: 326 try: 327 packet = await self.packet_queue.get() 328 await self.send_function(packet) 329 except asyncio.exceptions.CancelledError: 330 logger.debug('sink pump task done') 331 break 332 except Exception as error: 333 logger.warning(f'exception while sending packet: {error}') 334 break 335 336 self.pump_task = asyncio.create_task(pump_packets()) 337 338 def close(self): 339 if self.pump_task: 340 self.pump_task.cancel() 341 342 343# ----------------------------------------------------------------------------- 344class PumpedTransport(Transport): 345 def __init__(self, source, sink, close_function): 346 super().__init__(source, sink) 347 self.close_function = close_function 348 349 def start(self): 350 self.source.start() 351 self.sink.start() 352 353 async def close(self): 354 await super().close() 355 await self.close_function() 356 357 358# ----------------------------------------------------------------------------- 359class SnoopingTransport(Transport): 360 """Transport wrapper that snoops on packets to/from a wrapped transport.""" 361 362 @staticmethod 363 def create_with( 364 transport: Transport, snooper: ContextManager[Snooper] 365 ) -> SnoopingTransport: 366 """ 367 Create an instance given a snooper that works as as context manager. 368 369 The returned instance will exit the snooper context when it is closed. 370 """ 371 with contextlib.ExitStack() as exit_stack: 372 return SnoopingTransport( 373 transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close 374 ) 375 raise RuntimeError('unexpected code path') # Satisfy the type checker 376 377 class Source: 378 def __init__(self, source, snooper): 379 self.source = source 380 self.snooper = snooper 381 self.sink = None 382 383 def set_packet_sink(self, sink): 384 self.sink = sink 385 self.source.set_packet_sink(self) 386 387 def on_packet(self, packet): 388 self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) 389 if self.sink: 390 self.sink.on_packet(packet) 391 392 class Sink: 393 def __init__(self, sink, snooper): 394 self.sink = sink 395 self.snooper = snooper 396 397 def on_packet(self, packet): 398 self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) 399 if self.sink: 400 self.sink.on_packet(packet) 401 402 def __init__(self, transport, snooper, close_snooper=None): 403 super().__init__( 404 self.Source(transport.source, snooper), self.Sink(transport.sink, snooper) 405 ) 406 self.transport = transport 407 self.close_snooper = close_snooper 408 409 async def close(self): 410 await self.transport.close() 411 if self.close_snooper: 412 self.close_snooper() 413