• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
18import logging
19import asyncio
20from functools import partial
21
22from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
23from bumble.colors import color
24from bumble.hci import (
25    Address,
26    HCI_SUCCESS,
27    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
28    HCI_CONNECTION_TIMEOUT_ERROR,
29    HCI_PAGE_TIMEOUT_ERROR,
30    HCI_Connection_Complete_Event,
31)
32
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module-level logger, named after this module so it can be configured
# through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
37
38
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
    '''
    Parse a comma-separated list of ``key=value`` pairs into a dict.

    Entries without an ``=`` are ignored. Only the first ``=`` of an entry
    separates the key from the value, so values may themselves contain ``=``.
    Keys and values are returned as strings, without any conversion.

    Args:
        params_str: the string to parse, e.g. ``'reason=19,foo=bar'``.
            ``None`` or an empty string yields an empty dict.

    Returns:
        A dict mapping each key to its string value. When a key appears more
        than once, the last occurrence wins.
    '''
    result = {}
    if not params_str:
        # Message handlers may pass None when a message carries no payload.
        return result

    for param_str in params_str.split(','):
        # partition() splits on the first '=' only; split('=') would raise
        # ValueError on values that contain '='.
        key, separator, value = param_str.partition('=')
        if separator:
            result[key] = value
    return result
49
50
# -----------------------------------------------------------------------------
# TODO: add support for more Link Layer (LL) exchanges
# (see Bluetooth Core Specification, Vol 6, Part B, Section 2.4 - Data Channel PDU)
# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other in one process.

    LE peers are looked up by their ``random_address`` and Classic (BR/EDR)
    peers by their ``public_address``. Deferred notifications are scheduled
    with ``call_soon`` on the running asyncio event loop, so methods that
    defer work must be called from within that loop.
    '''

    def __init__(self):
        # All controllers currently attached to this link.
        self.controllers = set()
        # (central_address, le_create_connection_command) between connect()
        # and on_connection_complete(), else None.
        self.pending_connection = None
        # (initiator_controller, responder_controller) between
        # classic_connect() and classic_accept_connection(), else None.
        self.pending_classic_connection = None

    ############################################################
    # Common utils
    ############################################################

    def add_controller(self, controller):
        '''Attach ``controller`` to the link.'''
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        '''Detach ``controller``; raises KeyError if it is not attached.'''
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return the first controller whose random (LE) address matches, or None.'''
        for controller in self.controllers:
            if controller.random_address == address:
                return controller
        return None

    def find_classic_controller(self, address):
        '''Return the first controller whose public address matches, or None.'''
        for controller in self.controllers:
            if controller.public_address == address:
                return controller
        return None

    def get_pending_connection(self):
        '''Return the pending LE connection tuple, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the pending Classic connection tuple, or None.'''
        return self.pending_classic_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        '''Nothing to do: peers are looked up by address on every call.'''

    def send_advertising_data(self, sender_address, data):
        '''Deliver advertising data to every controller except the sender.'''
        for controller in self.controllers:
            if controller.random_address != sender_address:
                controller.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL data to the first controller matching ``destination_address``.

        The peer is looked up by random address for LE and by public address
        for BR/EDR, and the receiver is given the sender's address for that
        transport. Data for an unsupported transport, or with no matching
        peer, is dropped.
        '''
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Without this branch the locals below would be unbound.
            logger.warning(f'!!! ACL data for unsupported transport {transport}')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        '''Complete the LE connection recorded by connect() (run via call_soon).'''
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        '''Record an LE connection request and complete it on the next loop turn.'''
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        '''Notify both ends of an LE disconnection (run via call_soon).'''
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Schedule an LE disconnection for the next loop turn.'''
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        args = [central_address, peripheral_address, disconnect_command]
        asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        '''Forward the encryption parameters to both ends, when present.'''
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        '''
        Start a Classic connection: page the responder, or fail the initiator
        with HCI_PAGE_TIMEOUT_ERROR when no controller has that address.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        '''
        Complete a Classic connection accepted by ``responder_controller``.

        The responder is notified immediately; the initiator is notified on
        the next loop turn. When the responder does not take the peripheral
        role, the initiator is first told about its (opposite) role.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # Report the failure against the peer address, as classic_connect()
            # does, and drop the now-stale pending connection.
            self.pending_classic_connection = None
            responder_controller.on_classic_connection_complete(
                initiator_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return

        def notify_initiator():
            if responder_role != BT_PERIPHERAL_ROLE:
                # NOTE(review): assumes roles are encoded as 0/1 so that
                # `not` yields the opposite role.
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not responder_role)
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        asyncio.get_running_loop().call_soon(notify_initiator)
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        '''
        Tear down a Classic connection.

        The initiator is notified on the next loop turn; the responder, if it
        is still attached, is notified immediately.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_disconnected, responder_address, reason
        )
        if responder_controller is None:
            logger.warning(f'[Classic] {responder_address} not found on the link')
            return
        responder_controller.on_classic_disconnected(
            initiator_controller.public_address, reason
        )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        '''
        Switch roles on a Classic connection: the responder takes the opposite
        of ``initiator_new_role``. Does nothing when the responder is unknown.
        '''
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_role_change,
            responder_address,
            initiator_new_role,
        )
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not initiator_new_role)
        )
273
274
# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay.

    Must be created while an asyncio event loop is running: the constructor
    starts background tasks that connect to the relay and run queued work.
    A RemoteLink serves at most one controller.

    Frames received from the relay look like ``<keyword>[:<payload>]`` and are
    dispatched to ``on_<keyword>_received``. ``message`` frames carry
    ``<sender>/<keyword>[:<payload>]`` and are dispatched to
    ``on_<keyword>_message_received``. Peer messages are sent as
    ``@<target> <message>``. Advertisements use the target ``*``, presumably
    a broadcast; confirm against the relay.
    '''

    def __init__(self, uri):
        self.controller = None
        self.uri = uri
        # Coroutines queued by execute(), awaited one at a time by
        # run_executor_loop().
        self.execution_queue = asyncio.Queue()
        # Resolved with the websocket once run_connection() has connected.
        self.websocket = asyncio.get_running_loop().create_future()
        # Future for the in-flight send_rpc_command(), if any.
        self.rpc_result = None
        # LE Create Connection command awaiting a 'connected' ack, if any.
        self.pending_connection = None
        # This class implements no Classic handlers, so this stays None. It is
        # set so that get_pending_classic_connection() does not raise.
        self.pending_classic_connection = None
        self.central_connections = set()  # Addresses that we have connected to
        self.peripheral_connections = set()  # Addresses that have connected to us

        # Connect and run asynchronously. Keep references to the tasks: the
        # event loop holds only weak references to them.
        self.connection_task = asyncio.create_task(self.run_connection())
        self.executor_task = asyncio.create_task(self.run_executor_loop())

    def add_controller(self, controller):
        '''Attach the (single) controller served by this link.'''
        if self.controller:
            raise ValueError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        '''Detach ``controller``; raises ValueError if it is not the attached one.'''
        if self.controller != controller:
            raise ValueError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        '''Return the LE Create Connection command awaiting an ack, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the pending Classic connection (always None for this link).'''
        return self.pending_classic_connection

    async def wait_until_connected(self):
        '''Wait until the relay websocket is connected.'''
        await self.websocket

    def execute(self, async_function):
        '''Queue ``async_function()`` to be awaited by the executor loop.'''
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        '''Await queued coroutines in order, logging (not raising) their errors.'''
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        '''
        Connect to the relay, resolve ``self.websocket``, then dispatch every
        received frame to its ``on_<keyword>_received`` handler.
        '''
        import websockets  # lazy import

        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        # pylint: disable-next=no-member
        websocket = await websockets.connect(self.uri)
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler is None:
                continue
            try:
                await handler(payload[0] if payload else None)
            except Exception:
                # One failing handler must not end the receive loop; otherwise
                # every later relay frame would be silently ignored.
                logger.exception(f'exception in {handler_name}')

    def close(self):
        '''Close the relay websocket, if it was ever connected.'''
        if self.websocket.done():
            logger.debug('closing websocket')
            websocket = self.websocket.result()
            asyncio.create_task(websocket.close())

    async def on_result_received(self, result):
        '''Resolve the in-flight RPC, ignoring results nobody is waiting for.'''
        if self.rpc_result and not self.rpc_result.done():
            self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        '''Drop every connection with a peer that left the relay.'''
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        '''Treat an unreachable peer like one that left the relay.'''
        await self.on_left_received(target)

    async def on_message_received(self, message):
        '''Dispatch ``<sender>/<keyword>[:<payload>]`` to its message handler.'''
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        '''Forward hex-encoded advertising data from ``sender`` to the controller.'''
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        '''
        Forward hex-encoded ACL data from ``sender`` to the controller.

        Only LE is relayed, so the data is tagged with BT_LE_TRANSPORT; the
        controller's on_link_acl_data takes (sender, transport, data).
        '''
        try:
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        '''Accept an incoming LE connection from central ``sender``.'''
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        '''Complete our pending LE connection once the peripheral acks it.'''
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # Clear the pending request before notifying, as LocalLink does, so
        # that connect() accepts new requests afterwards.
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(f'connected to peripheral {le_create_connection_command.peer_address}')
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        '''Handle a disconnection initiated by central ``sender``.'''
        # Notify the controller; the message may carry no parameters at all.
        params = parse_parameters(message) if message else {}
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        if sender in self.peripheral_connections:
            self.peripheral_connections.remove(sender)

    async def on_encrypted_message_received(self, sender, _):
        '''Report the link with ``sender`` as encrypted (placeholder keys).'''
        # TODO parse params to get real args
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))

    async def send_rpc_command(self, command):
        '''Send an RPC command to the relay and wait for its result.'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()

        try:
            # Send the command, then wait for the result
            await websocket.send(command)
            rpc_result = await self.rpc_result
        finally:
            # Always reset, so a failed send does not trip the assertion
            # above on the next command.
            self.rpc_result = None
        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        '''Send ``message`` to ``target`` through the relay.'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        '''Tell the relay the controller's current random address.'''
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        '''Queue a relay notification of the controller's new address.'''
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        '''Send hex-encoded advertising data to the ``*`` target.'''
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        '''Queue advertising data for the relay.'''
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        '''Send hex-encoded ACL data to ``peer_address``.'''
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        '''Queue ACL data for ``peer_address`` (the transport is ignored).'''
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        '''Ask ``peer_address`` to accept an LE connection.'''
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        '''Start an LE connection unless one is already pending.'''
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        '''Report a successful disconnection to the controller.'''
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Tell the peripheral we disconnect, then report completion locally.'''
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        # Forget the connection, so that a later 'left' for this peer does not
        # report a second disconnection. Keys are stored as strings, as
        # received from the relay.
        self.central_connections.discard(str(peripheral_address))
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        '''Report encryption locally and tell the peripheral about it.'''
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )
531