# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import contextlib
import struct
import asyncio
import logging
import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict

from bumble import hci
from bumble.colors import color
from bumble.snoop import Snooper


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
# where `length-offset` is counted from the byte that follows the packet type
# byte, and `unpack-type` is the `struct` format used to decode the length field.
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B'),
    hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
}


# -----------------------------------------------------------------------------
# Errors
# -----------------------------------------------------------------------------
class TransportLostError(Exception):
    """
    The Transport has been lost/disconnected.
    """


# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
    """Structural type for anything that can receive a complete packet."""

    def on_packet(self, packet: bytes) -> None: ...


class TransportSource(Protocol):
    """Structural type for anything that delivers packets to a sink."""

    # Future exposed by the source so that users can observe its termination.
    terminated: asyncio.Future[None]

    def set_packet_sink(self, sink: TransportSink) -> None: ...


# -----------------------------------------------------------------------------
class PacketPump:
    """
    Pump HCI packets from a reader to a sink.
    """

    def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
        self.reader = reader
        self.sink = sink

    async def run(self) -> None:
        """
        Read packets from the reader and deliver them to the sink.

        Errors raised while reading or delivering a packet are logged and the
        pump continues with the next packet. The pump returns when the reader
        reports the end of the stream.
        """
        while True:
            try:
                # Deliver the packet to the sink
                self.sink.on_packet(await self.reader.next_packet())
            except asyncio.IncompleteReadError:
                # Once the stream is at EOF, `readexactly` raises immediately
                # on every call without yielding to the event loop; retrying
                # would spin forever, so stop here.
                logger.debug('end of stream, packet pump done')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')


# -----------------------------------------------------------------------------
class PacketParser:
    """
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed.
    """

    # pylint: disable=attribute-defined-outside-init

    # Parser states
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    sink: Optional[TransportSink]
    # Packet info for packet types not listed in HCI_PACKET_INFO
    extended_packet_info: Dict[int, Tuple[int, int, str]]
    # Info for the packet currently being parsed (None until the type is known)
    packet_info: Optional[Tuple[int, int, str]] = None

    def __init__(self, sink: Optional[TransportSink] = None) -> None:
        self.sink = sink
        self.extended_packet_info = {}
        self.reset()

    def reset(self) -> None:
        """Discard any partially parsed packet and wait for a new packet type."""
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data: bytes) -> None:
        """
        Consume `data`, emitting each complete packet to the sink (if any).

        `data` may contain a partial packet, exactly one packet, or several
        packets; partial state is kept across calls. Exceptions raised by the
        sink's `on_packet` are logged and not propagated.

        Raises:
            ValueError: if a packet type byte is found neither in
                HCI_PACKET_INFO nor in `extended_packet_info`. The parser is
                reset before raising and the rest of `data` is not consumed.
        """
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Reset first: otherwise `bytes_needed` stays at 0 and
                        # every later call to feed_data would silently drop
                        # its input.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    # The header spans everything up to and including the
                    # length field.
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    assert self.packet_info is not None
                    # +1 skips the packet type byte at the start of the buffer
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete
                # (this also covers packets with an empty body)
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.exception(
                                color(f'!!! Exception in on_packet: {error}', 'red')
                            )
                    self.reset()

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Set the sink that receives parsed packets."""
        self.sink = sink


# -----------------------------------------------------------------------------
class PacketReader:
    """
    Reader that reads HCI packets from a sync source.
    """

    def __init__(self, source: io.BufferedReader) -> None:
        self.source = source
        # Set to True once the source has no more packets to read
        self.at_end = False

    def next_packet(self) -> Optional[bytes]:
        """
        Read the next packet from the source.

        Returns:
            The full packet (type byte, header and body), or None if the
            source has no more data, in which case `at_end` is set to True.

        Raises:
            ValueError: if the packet type is not in HCI_PACKET_INFO, or if
                the source ends in the middle of a packet.
        """
        # Get the packet type
        packet_type = self.source.read(1)
        if len(packet_type) != 1:
            self.at_end = True
            return None

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type[0]} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = self.source.read(header_size)
        if len(header) != header_size:
            raise ValueError('packet too short')

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = self.source.read(body_length)
        if len(body) != body_length:
            raise ValueError('packet too short')

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPacketReader:
    """
    Reader that reads HCI packets from an async source.
    """

    def __init__(self, source: asyncio.StreamReader) -> None:
        self.source = source

    async def next_packet(self) -> bytes:
        """
        Read the next packet from the source.

        Returns:
            The full packet (type byte, header and body).

        Raises:
            ValueError: if the packet type is not in HCI_PACKET_INFO.
            asyncio.IncompleteReadError: if the stream ends before a full
                packet has been read (raised by `readexactly`).
        """
        # Get the packet type
        packet_type = await self.source.readexactly(1)

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type[0]} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = await self.source.readexactly(header_size)

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = await self.source.readexactly(body_length)

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPipeSink:
    """
    Sink that forwards packets asynchronously to another sink.
    """

    def __init__(self, sink: TransportSink) -> None:
        self.sink = sink
        # Must be instantiated while an event loop is running
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet: bytes) -> None:
        # Defer the delivery to a later iteration of the event loop
        self.loop.call_soon(self.sink.on_packet, packet)


# -----------------------------------------------------------------------------
class ParserSource:
    """
    Base class designed to be subclassed by transport-specific source classes
    """

    terminated: asyncio.Future[None]
    parser: PacketParser

    def __init__(self) -> None:
        self.parser = PacketParser()
        # Must be instantiated while an event loop is running
        self.terminated = asyncio.get_running_loop().create_future()

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Set the sink that receives the packets parsed by this source."""
        self.parser.set_packet_sink(sink)

    def on_transport_lost(self) -> None:
        """
        Mark the source as terminated and notify the sink, if it is interested.

        The sink is notified only if it has an `on_transport_lost` attribute.
        """
        # The future may already have been completed (for example by a
        # subclass); completing it twice would raise InvalidStateError and
        # prevent the sink from being notified.
        if not self.terminated.done():
            self.terminated.set_result(None)
        if self.parser.sink:
            if hasattr(self.parser.sink, 'on_transport_lost'):
                self.parser.sink.on_transport_lost()

    async def wait_for_termination(self) -> None:
        """
        Convenience method for backward compatibility. Prefer using the `terminated`
        attribute instead.
        """
        return await self.terminated

    def close(self) -> None:
        """Release resources held by the source. Does nothing by default."""
        pass


# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    """Source that parses the data received through an asyncio Protocol."""

    def data_received(self, data: bytes) -> None:
        self.parser.feed_data(data)


# -----------------------------------------------------------------------------
class StreamPacketSink:
    """Sink that writes packets to an asyncio write transport."""

    def __init__(self, transport: asyncio.WriteTransport) -> None:
        self.transport = transport

    def on_packet(self, packet: bytes) -> None:
        self.transport.write(packet)

    def close(self) -> None:
        self.transport.close()


# -----------------------------------------------------------------------------
class Transport:
    """
    Base class for all transports.

    A Transport represents a source and a sink together.
    An instance must be closed by calling close() when no longer used. Instances
    implement the ContextManager protocol so that they may be used in a `async with`
    statement.
    An instance is iterable. The iterator yields, in order, its source and sink, so
    that it may be used with a convenient call syntax like:

    async with create_transport() as (source, sink):
        ...
    """

    def __init__(self, source: TransportSource, sink: TransportSink) -> None:
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self) -> None:
        """Close the source and the sink, for those that have a `close` method."""
        if hasattr(self.source, 'close'):
            self.source.close()
        if hasattr(self.sink, 'close'):
            self.sink.close()


# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    """
    Source that obtains its data by repeatedly awaiting a `receive` function.
    """

    pump_task: Optional[asyncio.Task[None]]

    def __init__(self, receive) -> None:
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self) -> None:
        """
        Start a task that feeds the parser with the data returned by `receive`.

        The task stops when it is cancelled or when an exception is raised, in
        which case `terminated` is completed with a result or with the
        exception, respectively.
        """

        async def pump_packets() -> None:
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.CancelledError:
                    logger.debug('source pump task done')
                    # `terminated` may already be done (e.g. after
                    # on_transport_lost); don't complete it twice.
                    if not self.terminated.done():
                        self.terminated.set_result(None)
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    if not self.terminated.done():
                        self.terminated.set_exception(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self) -> None:
        """Cancel the pump task, if it was started."""
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedPacketSink:
    """
    Sink that queues packets and sends them by awaiting a `send` function.
    """

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet: bytes) -> None:
        # Packets are only queued here; they are sent by the pump task
        self.packet_queue.put_nowait(packet)

    def start(self):
        """
        Start a task that sends the queued packets, one at a time.

        The task stops when it is cancelled or when sending raises.
        """

        async def pump_packets():
            while True:
                try:
                    packet = await self.packet_queue.get()
                    await self.send_function(packet)
                except asyncio.CancelledError:
                    logger.debug('sink pump task done')
                    break
                except Exception as error:
                    logger.warning(f'exception while sending packet: {error}')
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        """Cancel the pump task, if it was started."""
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    """Transport made of a pumped source and a pumped sink."""

    source: PumpedPacketSource
    sink: PumpedPacketSink

    def __init__(
        self,
        source: PumpedPacketSource,
        sink: PumpedPacketSink,
    ) -> None:
        super().__init__(source, sink)

    def start(self) -> None:
        """Start the pump tasks of both the source and the sink."""
        self.source.start()
        self.sink.start()


# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
    """Transport wrapper that snoops on packets to/from a wrapped transport."""

    @staticmethod
    def create_with(
        transport: Transport, snooper: ContextManager[Snooper]
    ) -> SnoopingTransport:
        """
        Create an instance given a snooper that works as a context manager.

        The returned instance will exit the snooper context when it is closed.
        """
        with contextlib.ExitStack() as exit_stack:
            # `pop_all()` moves the snooper's exit callback to a new stack, so
            # leaving this `with` block does not exit the snooper context; the
            # new stack's `close` is what the returned instance calls later.
            return SnoopingTransport(
                transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
            )
        raise RuntimeError('unexpected code path')  # Satisfy the type checker

    class Source:
        """Source wrapper that snoops on packets before forwarding them."""

        sink: TransportSink

        @property
        def metadata(self) -> dict[str, Any]:
            # Metadata of the wrapped source, or an empty dict if it has none
            return getattr(self.source, 'metadata', {})

        def __init__(self, source: TransportSource, snooper: Snooper):
            self.source = source
            self.snooper = snooper
            # Share the termination future of the wrapped source
            self.terminated = source.terminated

        def set_packet_sink(self, sink: TransportSink) -> None:
            self.sink = sink
            # Insert this object between the wrapped source and the sink
            self.source.set_packet_sink(self)

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
            if self.sink:
                self.sink.on_packet(packet)

    class Sink:
        """Sink wrapper that snoops on packets before forwarding them."""

        def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
            self.sink = sink
            self.snooper = snooper

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
            if self.sink:
                self.sink.on_packet(packet)

    def __init__(
        self,
        transport: Transport,
        snooper: Snooper,
        close_snooper=None,
    ) -> None:
        """
        Args:
            transport: the transport to wrap.
            snooper: the snooper that is given each packet.
            close_snooper: optional callable, called with no arguments when
                this transport is closed.
        """
        super().__init__(
            self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
        )
        self.transport = transport
        self.close_snooper = close_snooper

    async def close(self):
        """Close the wrapped transport, then call `close_snooper` if set."""
        try:
            await self.transport.close()
        finally:
            # Release the snooper even if closing the wrapped transport fails
            if self.close_snooper:
                self.close_snooper()