# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import asyncio
from functools import partial

from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.colors import color
from bumble.hci import (
    Address,
    HCI_SUCCESS,
    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
    HCI_CONNECTION_TIMEOUT_ERROR,
    HCI_PAGE_TIMEOUT_ERROR,
    HCI_Connection_Complete_Event,
)

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
    '''
    Parse a comma-separated list of `key=value` items into a dict.

    Items without an `=` are ignored. Only the first `=` of an item separates
    the key from the value, so values may themselves contain `=`.
    A `None` or empty input yields an empty dict.
    '''
    result = {}
    if not params_str:
        # Messages without a payload are delivered to handlers as None
        return result

    for param_str in params_str.split(','):
        key, separator, value = param_str.partition('=')
        if separator:
            result[key] = value

    return result


# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other
    '''

    def __init__(self):
        self.controllers = set()
        # (central_address, le_create_connection_command) while an LE connection
        # is being established, else None
        self.pending_connection = None
        # (initiator_controller, responder_controller) while a classic
        # connection request is outstanding, else None
        self.pending_classic_connection = None

    ############################################################
    # Common utils
    ############################################################

    def add_controller(self, controller):
        '''Attach a controller to this link.'''
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        '''Detach a controller from this link (KeyError if it isn't attached).'''
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return the first controller whose random address matches, or None.'''
        for controller in self.controllers:
            if controller.random_address == address:
                return controller
        return None

    def find_classic_controller(self, address):
        '''Return the first controller whose public address matches, or None.'''
        for controller in self.controllers:
            if controller.public_address == address:
                return controller
        return None

    def get_pending_connection(self):
        '''Return the pending LE connection, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the pending classic connection, or None.'''
        return self.pending_classic_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        '''Nothing to do: controllers are looked up by their current address.'''

    def send_advertising_data(self, sender_address, data):
        '''Deliver advertising data to every controller but the sender.'''
        # Send the advertising data to all controllers, except the sender
        for controller in self.controllers:
            if controller.random_address != sender_address:
                controller.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL data to the controller that owns `destination_address`.

        The destination is matched against random addresses for the LE transport
        and against public addresses for the BR/EDR transport. Data for an
        unknown destination or an unsupported transport is dropped.
        '''
        # Send the data to the first controller with a matching address
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Without this branch the names below would be unbound
            logger.warning(f'unsupported transport: {transport}')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        '''
        Complete the pending LE connection (scheduled by `connect`).

        The central is notified with HCI_SUCCESS when a peripheral with the
        requested peer address is attached, and with
        HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR otherwise.
        '''
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        '''Record an LE connection request and complete it on the next loop turn.'''
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        '''
        Complete an LE disconnection (scheduled by `disconnect`).

        The peripheral, if still attached, is told that the central disconnected;
        the initiating controller is then notified with HCI_SUCCESS.
        '''
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Schedule an LE disconnection to complete on the next loop turn.'''
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        args = [central_address, peripheral_address, disconnect_command]
        asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        '''Notify both ends of a connection (when attached) that it is encrypted.'''
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        '''
        Start a classic connection from `initiator_controller`.

        If no controller owns `responder_address`, the initiator is notified
        with HCI_PAGE_TIMEOUT_ERROR. Otherwise the request is recorded as
        pending and forwarded to the responder.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} '
            f'connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        '''
        Accept a classic connection on behalf of `responder_controller`.

        The responder is notified synchronously (role change, then connection
        complete); the initiator is notified from a separate task. The pending
        classic connection is cleared on success.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} '
            f'accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # NOTE(review): this reports the responder's own address rather than
            # `initiator_address` - confirm which one the controller expects.
            responder_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return

        async def task():
            # The initiator only sees a role change when the responder did not
            # take the peripheral role
            if responder_role != BT_PERIPHERAL_ROLE:
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not (responder_role))
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        asyncio.create_task(task())
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        '''
        Disconnect a classic connection.

        The initiator is notified from a separate task. The responder is
        notified synchronously, if it is still attached to the link.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} '
            f'disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        async def task():
            initiator_controller.on_classic_disconnected(responder_address, reason)

        asyncio.create_task(task())

        if responder_controller is None:
            # The peer may already have been removed from the link
            logger.warning('!!! Responder controller not found')
            return

        responder_controller.on_classic_disconnected(
            initiator_controller.public_address, reason
        )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        '''
        Switch roles on a classic connection.

        The responder is given the opposite of `initiator_new_role`. Nothing
        happens if no controller owns `responder_address`.
        '''
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        async def task():
            initiator_controller.on_classic_role_change(
                responder_address, initiator_new_role
            )

        asyncio.create_task(task())
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not (initiator_new_role))
        )


# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay
    '''

    def __init__(self, uri):
        '''
        Create the link and start connecting to the relay at `uri`.

        Must be called while an asyncio event loop is running, since it creates
        a future and spawns the connection and executor tasks.
        '''
        self.controller = None
        self.uri = uri
        self.execution_queue = asyncio.Queue()
        # Resolved with the websocket once the relay connection is established
        self.websocket = asyncio.get_running_loop().create_future()
        # Future for the RPC currently in flight, if any
        self.rpc_result = None
        self.pending_connection = None
        # Must exist for get_pending_classic_connection() to be usable
        self.pending_classic_connection = None
        self.central_connections = set()  # List of addresses that we have connected to
        self.peripheral_connections = (
            set()
        )  # List of addresses that have connected to us

        # Connect and run asynchronously
        asyncio.create_task(self.run_connection())
        asyncio.create_task(self.run_executor_loop())

    def add_controller(self, controller):
        '''Attach the (single) controller; ValueError if one is already set.'''
        if self.controller:
            raise ValueError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        '''Detach the controller; ValueError if it isn't the attached one.'''
        if self.controller != controller:
            raise ValueError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        '''Return the pending LE create-connection command, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the pending classic connection, or None.'''
        return self.pending_classic_connection

    async def wait_until_connected(self):
        '''Wait until the websocket connection to the relay is established.'''
        await self.websocket

    def execute(self, async_function):
        '''Queue `async_function()` to be awaited, in order, by the executor loop.'''
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        '''Await queued items one at a time, logging (not raising) failures.'''
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                # Keep the loop alive: one failed handler must not stop the rest
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        '''
        Connect to the relay, then dispatch incoming messages forever.

        A message `keyword:payload` is passed to `on_<keyword>_received(payload)`
        when such a method exists (payload is None if there is no `:`).
        '''
        import websockets  # lazy import

        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        # pylint: disable-next=no-member
        websocket = await websockets.connect(self.uri)
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(payload[0] if payload else None)

    def close(self):
        '''Close the websocket, if the connection was established.'''
        if self.websocket.done():
            logger.debug('closing websocket')
            websocket = self.websocket.result()
            asyncio.create_task(websocket.close())

    async def on_result_received(self, result):
        '''Resolve the RPC in flight, if any, with `result`.'''
        # A future can only be resolved once; ignore duplicate results
        if self.rpc_result is not None and not self.rpc_result.done():
            self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        '''Tear down any connection with a peer that has left the relay.'''
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            # Wrap in Address, as for every other notification to the controller
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        '''Treat an unreachable peer like one that has left.'''
        await self.on_left_received(target)

    async def on_message_received(self, message):
        '''
        Dispatch a peer message of the form `sender/keyword:payload` to
        `on_<keyword>_message_received(sender, payload)` when such a method
        exists (payload is None if there is no `:`).
        '''
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        '''Forward hex-encoded advertising data to the controller.'''
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        '''Forward hex-encoded ACL data to the controller.'''
        try:
            # The relay protocol carries no transport information, and this link
            # only establishes LE connections, so the data is delivered as LE.
            # TODO: handle different transport
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        '''Accept an incoming connection from a central.'''
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        '''Complete the pending outgoing connection once the peer has accepted.'''
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # The request is no longer pending. Clear it before notifying the
        # controller (as LocalLink does) so that new connections can be started.
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(
            f'connected to peripheral {le_create_connection_command.peer_address}'
        )
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        '''
        Handle a disconnection initiated by a central.

        The reason is read from the `reason=` parameter of the message and
        defaults to HCI_CONNECTION_TIMEOUT_ERROR.
        '''
        # Notify the controller
        params = parse_parameters(message)
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        if sender in self.peripheral_connections:
            self.peripheral_connections.remove(sender)

    async def on_encrypted_message_received(self, sender, _):
        '''Notify the controller that the link with `sender` is encrypted.'''
        # TODO parse params to get real args
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))

    async def send_rpc_command(self, command):
        '''
        Send an RPC command to the relay and wait for its result.

        Only one RPC may be in flight at a time.
        '''
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()

        try:
            # Send the command
            await websocket.send(command)

            # Wait for the result
            rpc_result = await self.rpc_result
        finally:
            # Always release the slot, or a failure here would make every
            # subsequent RPC trip the assertion above
            self.rpc_result = None

        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        '''Send `message` to `target` through the relay.'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        '''Tell the relay about the controller's current random address.'''
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        '''Queue a notification of the controller's new address to the relay.'''
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        '''Broadcast advertising data, hex-encoded, through the relay.'''
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        '''Queue advertising data for broadcast (the sender address is unused).'''
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        '''Send ACL data, hex-encoded, to `peer_address` through the relay.'''
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        '''Queue ACL data for `peer_address` (sender and transport are unused).'''
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        '''Send a connection request to `peer_address` through the relay.'''
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        '''
        Start an LE connection to the command's peer address.

        The request is ignored, with a warning, if one is already pending.
        '''
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        '''Notify the controller that its disconnection request has completed.'''
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''
        Queue a disconnect message for the peer and schedule the local
        completion notification on the next loop turn.
        '''
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        '''
        Notify the local controller (on the next loop turn) and the peer
        (through the relay) that the connection is encrypted.
        '''
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )