# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import asyncio
from functools import partial

from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.colors import color
from bumble.hci import (
    Address,
    HCI_SUCCESS,
    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
    HCI_CONNECTION_TIMEOUT_ERROR,
    HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
    HCI_PAGE_TIMEOUT_ERROR,
    HCI_Connection_Complete_Event,
)
from bumble import controller

from typing import Optional, Set

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
    '''
    Parse a comma-separated list of ``key=value`` items into a dict of strings.

    Items without an ``=`` are ignored. Only the first ``=`` in an item
    separates the key from the value, so values may themselves contain ``=``.
    A ``None`` or empty input yields an empty dict (relay messages may arrive
    without any payload).

    Example: ``'reason=19,x=a=b'`` -> ``{'reason': '19', 'x': 'a=b'}``
    '''
    result = {}
    if not params_str:
        return result
    for param_str in params_str.split(','):
        # partition splits on the first '=' only; separator is '' when absent
        key, separator, value = param_str.partition('=')
        if separator:
            result[key] = value
    return result


# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other.

    All attached controllers share this object and the running asyncio event
    loop. LE peers are looked up by random address, BR/EDR (classic) peers by
    public address. Many cross-controller notifications are scheduled with
    ``loop.call_soon`` so they run on a later event-loop iteration instead of
    synchronously inside the calling controller.
    '''

    controllers: Set[controller.Controller]

    def __init__(self):
        self.controllers = set()
        # (central_address, le_create_connection_command) while an LE
        # connection attempt is in flight, else None.
        self.pending_connection = None
        # (initiator_controller, responder_controller) while a classic
        # connection request awaits acceptance, else None.
        self.pending_classic_connection = None

    ############################################################
    # Common utils
    ############################################################

    def add_controller(self, controller):
        '''Attach ``controller`` to the link.'''
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        '''Detach ``controller``; raises KeyError if it was never attached.'''
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return the controller whose random (LE) address is ``address``, or None.'''
        for candidate in self.controllers:
            if candidate.random_address == address:
                return candidate
        return None

    def find_classic_controller(
        self, address: Address
    ) -> Optional[controller.Controller]:
        '''Return the controller whose public (BR/EDR) address is ``address``, or None.'''
        for candidate in self.controllers:
            if candidate.public_address == address:
                return candidate
        return None

    def get_pending_connection(self):
        '''Return the in-flight LE connection tuple, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the in-flight classic connection pair, or None.'''
        return self.pending_classic_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        '''No-op: controller addresses are re-read on every lookup.'''
        pass

    def send_advertising_data(self, sender_address, data):
        '''Broadcast advertising ``data`` to every controller except the sender.'''
        for receiver in self.controllers:
            if receiver.random_address != sender_address:
                receiver.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL ``data`` to the first controller matching
        ``destination_address`` on ``transport``.

        LE destinations are matched by random address and BR/EDR destinations
        by public address; the sender's address of the same kind is passed as
        the source. Data for an unknown destination is dropped silently; data
        for an unsupported transport is dropped with a warning.
        '''
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Without this branch the locals below would be unbound.
            logger.warning(f'!!! unsupported transport {transport}, dropping ACL data')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        '''
        Resolve the pending LE connection.

        Reports HCI_SUCCESS to the central and notifies the peripheral when a
        controller owns the peer address; otherwise reports
        HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR to the central.
        '''
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        '''
        Record an LE connection attempt and resolve it on a later loop
        iteration. A new call overwrites any attempt that is still pending.
        '''
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        '''Notify the peripheral (if found) of the reason, then report success to the central.'''
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Schedule on_disconnection_complete on a later loop iteration.'''
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        args = [central_address, peripheral_address, disconnect_command]
        asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        '''Tell both ends (whichever are found) that the link is now encrypted.'''
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    def create_cis(
        self,
        central_controller: controller.Controller,
        peripheral_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        '''Forward a CIS request to the peripheral on a later loop iteration, if it exists.'''
        logger.debug(
            f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
        )
        if peripheral_controller := self.find_controller(peripheral_address):
            asyncio.get_running_loop().call_soon(
                peripheral_controller.on_link_cis_request,
                central_controller.random_address,
                cig_id,
                cis_id,
            )

    def accept_cis(
        self,
        peripheral_controller: controller.Controller,
        central_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        '''Report the CIS as established to both sides when the central exists.'''
        logger.debug(
            f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
        )
        if central_controller := self.find_controller(central_address):
            asyncio.get_running_loop().call_soon(
                central_controller.on_link_cis_established, cig_id, cis_id
            )
            asyncio.get_running_loop().call_soon(
                peripheral_controller.on_link_cis_established, cig_id, cis_id
            )

    def disconnect_cis(
        self,
        initiator_controller: controller.Controller,
        peer_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        '''Report the CIS as disconnected to both sides when the peer exists.'''
        logger.debug(
            f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
        )
        if peer_controller := self.find_controller(peer_address):
            asyncio.get_running_loop().call_soon(
                initiator_controller.on_link_cis_disconnected, cig_id, cis_id
            )
            asyncio.get_running_loop().call_soon(
                peer_controller.on_link_cis_disconnected, cig_id, cis_id
            )

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        '''
        Start a classic ACL connection.

        Fails immediately with HCI_PAGE_TIMEOUT_ERROR if no controller owns
        ``responder_address``; otherwise records the pair as pending and
        delivers a connection request to the responder.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        '''
        Complete a classic ACL connection accepted by ``responder_controller``.

        The responder is notified synchronously; the initiator is notified on a
        later loop iteration.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # The failure is reported against the peer (the initiator), not the
            # responder's own address.
            responder_controller.on_classic_connection_complete(
                initiator_address, HCI_PAGE_TIMEOUT_ERROR
            )
            self.pending_classic_connection = None
            return

        def notify_initiator():
            # The initiator only sees a role change when the responder did not
            # stay peripheral; int(not role) flips it (assumes 0/1 role values).
            if responder_role != BT_PERIPHERAL_ROLE:
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not responder_role)
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        # call_soon keeps a strong reference to the callback, unlike an
        # unreferenced create_task.
        asyncio.get_running_loop().call_soon(notify_initiator)
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        '''
        Tear down a classic connection with ``reason``.

        The initiator is notified on a later loop iteration. The responder is
        notified immediately, or skipped with a warning if it is no longer on
        the link.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_disconnected, responder_address, reason
        )
        if responder_controller is None:
            logger.warning(f'!!! [Classic] responder {responder_address} not found')
            return
        responder_controller.on_classic_disconnected(
            initiator_controller.public_address, reason
        )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        '''
        Apply a role switch to both ends; ignored if the responder is unknown.

        The responder gets the opposite role via int(not role), which assumes
        0/1 role values.
        '''
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_role_change,
            responder_address,
            initiator_new_role,
        )
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not initiator_new_role)
        )

    def classic_sco_connect(
        self,
        initiator_controller: controller.Controller,
        responder_address: Address,
        link_type: int,
    ):
        '''Deliver an SCO connection request of ``link_type`` to the responder.'''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        # Initiator controller should handle it.
        assert responder_controller

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            link_type,
        )

    def classic_accept_sco_connection(
        self,
        responder_controller: controller.Controller,
        initiator_address: Address,
        link_type: int,
    ):
        '''
        Complete an SCO connection accepted by ``responder_controller``.

        The responder is notified synchronously; the initiator is notified on a
        later loop iteration.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # Report the failure against the peer (the initiator).
            responder_controller.on_classic_sco_connection_complete(
                initiator_address,
                HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
                link_type,
            )
            return

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_sco_connection_complete,
            responder_controller.public_address,
            HCI_SUCCESS,
            link_type,
        )
        responder_controller.on_classic_sco_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS, link_type
        )


# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay.

    A single controller may be attached. Outgoing traffic is queued with
    ``execute`` and sent in order by a background executor loop. Incoming relay
    messages are dispatched by keyword to ``on_<keyword>_received`` and
    ``on_<keyword>_message_received`` handlers. Peers are identified by the
    address strings used on the relay. ACL data is always delivered as LE
    traffic.
    '''

    def __init__(self, uri):
        self.controller = None
        self.uri = uri
        self.execution_queue = asyncio.Queue()
        # Resolves to the connected websocket once run_connection has connected.
        self.websocket = asyncio.get_running_loop().create_future()
        # Future for the in-flight RPC (see send_rpc_command), else None.
        self.rpc_result = None
        # LE create-connection command awaiting a 'connected' ack, else None.
        self.pending_connection = None
        # Returned by get_pending_classic_connection; nothing in this class sets it.
        self.pending_classic_connection = None
        self.central_connections = set()  # Addresses that we have connected to
        self.peripheral_connections = set()  # Addresses that have connected to us
        # Strong references to background tasks. The event loop only keeps weak
        # references, so an unreferenced task may be garbage collected.
        self._background_tasks = set()

        # Connect and run asynchronously
        self._spawn(self.run_connection())
        self._spawn(self.run_executor_loop())

    def _spawn(self, coroutine):
        '''Run ``coroutine`` as a task and keep it referenced until it finishes.'''
        task = asyncio.create_task(coroutine)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def add_controller(self, controller):
        '''Attach the single controller; raises ValueError if one is already set.'''
        if self.controller:
            raise ValueError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        '''Detach ``controller``; raises ValueError if it is not the attached one.'''
        if self.controller != controller:
            raise ValueError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        '''Return the LE create-connection command awaiting an ack, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the pending classic connection, or None.'''
        return self.pending_classic_connection

    async def wait_until_connected(self):
        '''Wait until the relay websocket is connected.'''
        await self.websocket

    def execute(self, async_function):
        '''Queue ``async_function()`` for in-order execution by the executor loop.'''
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        '''Await queued coroutines one at a time, logging and continuing on errors.'''
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        '''
        Connect to the relay, resolve ``self.websocket``, then dispatch each
        received ``keyword[:payload]`` message to ``on_<keyword>_received``.
        Messages with no matching handler are ignored.
        '''
        import websockets  # lazy import

        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        # pylint: disable-next=no-member
        websocket = await websockets.connect(self.uri)
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(payload[0] if payload else None)

    def close(self):
        '''Close the websocket if the connection was established.'''
        if self.websocket.done():
            logger.debug('closing websocket')
            websocket = self.websocket.result()
            self._spawn(websocket.close())

    async def on_result_received(self, result):
        '''Resolve the in-flight RPC with ``result``; unexpected results are ignored.'''
        # set_result on an already-done future raises InvalidStateError, which
        # would terminate the run_connection receive loop.
        if self.rpc_result is not None and not self.rpc_result.done():
            self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        '''Tear down any connection with ``address`` after it left the relay.'''
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            # Pass an Address, as on_disconnect_message_received does.
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        '''Treat an unreachable target like one that left the relay.'''
        await self.on_left_received(target)

    async def on_message_received(self, message):
        '''Dispatch ``sender/keyword[:payload]`` to ``on_<keyword>_message_received``.'''
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        '''Forward hex-encoded advertising data from ``sender`` to the controller.'''
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        '''Forward hex-encoded ACL data from ``sender`` to the controller as LE traffic.'''
        try:
            # on_link_acl_data takes (source, transport, data), as LocalLink
            # calls it. Only LE is relayed.
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        '''Accept an incoming LE connection from central ``sender``.'''
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        '''Complete the pending outgoing LE connection once the peer acknowledges it.'''
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # Take and clear the pending command so a later connect() is not
        # rejected as 'connection already pending'.
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(f'connected to peripheral {le_create_connection_command.peer_address}')
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        '''Handle a disconnect from central ``sender``; the payload may carry ``reason=N``.'''
        # Notify the controller. The payload is None when the message has no ':'.
        params = parse_parameters(message) if message else {}
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        if sender in self.peripheral_connections:
            self.peripheral_connections.remove(sender)

    async def on_encrypted_message_received(self, sender, _):
        '''Mark the link with ``sender`` as encrypted, using placeholder key material.'''
        # TODO parse params to get real args
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))

    async def send_rpc_command(self, command):
        '''Send an RPC ``command`` to the relay and wait for its result.'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()

        try:
            # Send the command
            await websocket.send(command)

            # Wait for the result
            rpc_result = await self.rpc_result
        finally:
            # Always reset, so a failed send does not wedge every later RPC on
            # the assert above.
            self.rpc_result = None
        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        '''Send ``message`` to relay peer ``target`` (``'*'`` for everyone).'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        '''Tell the relay the controller's current random address.'''
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        '''Queue a relay address update after the controller's address changed.'''
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        '''Broadcast hex-encoded advertising ``data`` to all relay peers.'''
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        '''Queue a broadcast of advertising ``data``.'''
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        '''Send hex-encoded ACL ``data`` to ``peer_address``.'''
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        '''Queue ACL ``data`` for ``peer_address`` (the transport is ignored).'''
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        '''Send a 'connect' request to ``peer_address``.'''
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        '''Start an LE connection; ignored with a warning if one is already pending.'''
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        '''Report a successful disconnection to the attached controller.'''
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Queue a disconnect message to the peer and report completion locally.'''
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        '''Mark the link encrypted locally and tell the peer over the relay.'''
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )