# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import asyncio
from functools import partial

from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.colors import color
from bumble.hci import (
    Address,
    HCI_SUCCESS,
    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
    HCI_CONNECTION_TIMEOUT_ERROR,
    HCI_PAGE_TIMEOUT_ERROR,
    HCI_Connection_Complete_Event,
)

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
    '''
    Parse a comma-separated list of ``key=value`` pairs into a dict.

    Entries without an ``=`` are ignored. Values are returned as strings.
    Only the first ``=`` of each entry separates key from value, so values
    may themselves contain ``=``.

    Args:
        params_str: the string to parse, e.g. ``'reason=19,foo=bar'``.
            ``None`` or an empty string yields an empty dict.

    Returns:
        A dict mapping each key to its (string) value; later duplicates win.
    '''
    result = {}
    if not params_str:
        # Messages may arrive without any payload at all
        return result
    for param_str in params_str.split(','):
        if '=' in param_str:
            # maxsplit=1: a value containing '=' must not raise ValueError
            key, value = param_str.split('=', 1)
            result[key] = value
    return result


# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other

    All controllers registered on this bus are called directly (in-process).
    LE peers are looked up by their random address, Classic (BR/EDR) peers by
    their public address.
    '''

    def __init__(self):
        self.controllers = set()
        # (central_address, le_create_connection_command) while an LE
        # connection is in progress, else None
        self.pending_connection = None
        # (initiator_controller, responder_controller) while a Classic
        # connection request awaits acceptance, else None
        self.pending_classic_connection = None

    ############################################################
    # Common utils
    ############################################################

    def add_controller(self, controller):
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return a controller whose random (LE) address matches, or None.'''
        for controller in self.controllers:
            if controller.random_address == address:
                return controller
        return None

    def find_classic_controller(self, address):
        '''Return a controller whose public (BR/EDR) address matches, or None.'''
        for controller in self.controllers:
            if controller.public_address == address:
                return controller
        return None

    def get_pending_connection(self):
        return self.pending_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        pass

    def send_advertising_data(self, sender_address, data):
        # Send the advertising data to all controllers, except the sender
        for controller in self.controllers:
            if controller.random_address != sender_address:
                controller.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL data to the controller matching destination_address.

        The address space used for the lookup (random vs public) depends on
        the transport. Unknown transports are logged and dropped.
        '''
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Without this branch the names above would be unbound below
            logger.warning(f'!!! unsupported transport: {transport}')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        # Complete asynchronously, on the next iteration of the event loop
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        args = [central_address, peripheral_address, disconnect_command]
        asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        '''
        Complete a Classic connection that responder_controller accepted.

        The responder is notified synchronously; the initiator is notified
        from a separately scheduled task.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # The first argument is the *peer* address (as in classic_connect),
            # i.e. the initiator, not the responder's own address.
            responder_controller.on_classic_connection_complete(
                initiator_address, HCI_PAGE_TIMEOUT_ERROR
            )
            self.pending_classic_connection = None
            return

        async def task():
            if responder_role != BT_PERIPHERAL_ROLE:
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not (responder_role))
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        asyncio.create_task(task())
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        '''
        Disconnect a Classic link; both ends are notified with `reason`.

        The initiator is always notified (from a scheduled task); the responder
        only if a controller with that public address is still on the bus.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        async def task():
            initiator_controller.on_classic_disconnected(responder_address, reason)

        asyncio.create_task(task())
        if responder_controller is not None:
            responder_controller.on_classic_disconnected(
                initiator_controller.public_address, reason
            )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        async def task():
            initiator_controller.on_classic_role_change(
                responder_address, initiator_new_role
            )

        asyncio.create_task(task())
        # The responder takes the opposite role
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not (initiator_new_role))
        )


# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay

    Incoming relay frames have the form ``keyword[:payload]`` and are dispatched
    to ``on_<keyword>_received``. ``message`` frames carry
    ``sender/keyword[:payload]`` and are dispatched to
    ``on_<keyword>_message_received``. Outgoing peer messages are sent as
    ``@<target> <message>``.
    '''

    def __init__(self, uri):
        self.controller = None
        self.uri = uri
        self.execution_queue = asyncio.Queue()
        # Resolved with the websocket once connected
        self.websocket = asyncio.get_running_loop().create_future()
        self.rpc_result = None
        self.pending_connection = None
        # Returned by get_pending_classic_connection(); must exist up front
        self.pending_classic_connection = None
        self.central_connections = set()  # Addresses that we have connected to
        self.peripheral_connections = set()  # Addresses that have connected to us

        # Connect and run asynchronously. Keep references to the tasks, since
        # the event loop only holds weak references to them.
        self.connection_task = asyncio.create_task(self.run_connection())
        self.executor_task = asyncio.create_task(self.run_executor_loop())

    def add_controller(self, controller):
        if self.controller:
            raise ValueError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        if self.controller != controller:
            raise ValueError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        return self.pending_connection

    def get_pending_classic_connection(self):
        return self.pending_classic_connection

    async def wait_until_connected(self):
        await self.websocket

    def execute(self, async_function):
        '''Queue async_function() to be awaited, in order, by the executor loop.'''
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        import websockets  # lazy import

        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        # pylint: disable-next=no-member
        websocket = await websockets.connect(self.uri)
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler:
                try:
                    await handler(payload[0] if payload else None)
                except Exception:
                    # One malformed/unexpected message must not terminate the
                    # receive loop (and with it, the relay connection)
                    logger.exception(f'exception in {handler_name}')

    def close(self):
        if self.websocket.done():
            logger.debug('closing websocket')
            websocket = self.websocket.result()
            asyncio.create_task(websocket.close())

    async def on_result_received(self, result):
        # Ignore results nobody is waiting for (or duplicates), which would
        # otherwise raise InvalidStateError
        if self.rpc_result is not None and not self.rpc_result.done():
            self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            # Pass an Address, consistent with on_disconnect_message_received
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        await self.on_left_received(target)

    async def on_message_received(self, message):
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        try:
            # Same (source, transport, data) signature LocalLink uses; the
            # relay currently only carries LE traffic (see send_acl_data)
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # The connection is no longer pending; otherwise every later connect()
        # would be rejected as "already pending"
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(f'connected to peripheral {le_create_connection_command.peer_address}')
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        # Notify the controller (the payload may be absent)
        params = parse_parameters(message or '')
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        if sender in self.peripheral_connections:
            self.peripheral_connections.remove(sender)

    async def on_encrypted_message_received(self, sender, message):
        # The peer sends 'encrypted:ltk=<hex>' (see on_connection_encrypted)
        # TODO: rand and ediv are not transmitted yet
        params = parse_parameters(message or '')
        ltk = bytes.fromhex(params['ltk']) if 'ltk' in params else bytes(16)
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, ltk)

    async def send_rpc_command(self, command):
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()

        try:
            # Send the command and wait for the result
            await websocket.send(command)
            rpc_result = await self.rpc_result
        finally:
            # Always reset, so a failed send doesn't wedge later commands
            self.rpc_result = None
        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )