• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import logging
19import asyncio
20from functools import partial
21
22from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
23from bumble.colors import color
24from bumble.hci import (
25    Address,
26    HCI_SUCCESS,
27    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
28    HCI_CONNECTION_TIMEOUT_ERROR,
29    HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
30    HCI_PAGE_TIMEOUT_ERROR,
31    HCI_Connection_Complete_Event,
32)
33from bumble import controller
34
35from typing import Optional, Set
36
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module-level logger, named after this module so it can be filtered by the
# standard logging configuration.
logger: logging.Logger = logging.getLogger(name=__name__)
41
42
43# -----------------------------------------------------------------------------
44# Utils
45# -----------------------------------------------------------------------------
def parse_parameters(params_str):
    '''
    Parse a comma-separated list of `key=value` items into a dict of strings.

    Items without an `=` are ignored. Only the first `=` of an item separates
    the key from the value, so values may themselves contain `=`. A `None` or
    empty input yields an empty dict.

    Example: 'reason=19,foo=bar' -> {'reason': '19', 'foo': 'bar'}
    '''
    result = {}
    # Message payloads may be absent (None), e.g. a bare 'disconnect' message.
    if not params_str:
        return result
    for param_str in params_str.split(','):
        key, separator, value = param_str.partition('=')
        if separator:
            result[key] = value
    return result
53
54
55# -----------------------------------------------------------------------------
56# TODO: add more support for various LL exchanges
57# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
58# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other.

    Controllers attached to the same LocalLink are called directly; deferred
    notifications are scheduled on the running asyncio event loop. Controllers
    are looked up by random address for LE and by public address for Classic.
    '''

    controllers: Set[controller.Controller]

    def __init__(self):
        self.controllers = set()
        # (central_address, le_create_connection_command) while an LE
        # connection is being set up, None otherwise.
        self.pending_connection = None
        # (initiator_controller, responder_controller) while a Classic
        # connection request awaits acceptance, None otherwise.
        self.pending_classic_connection = None
        # Strong references to fire-and-forget tasks: asyncio only keeps weak
        # references to tasks, so an unreferenced task may be garbage-collected
        # before it completes.
        self._background_tasks = set()

    ############################################################
    # Common utils
    ############################################################

    def _run_in_background(self, coroutine):
        '''Run `coroutine` as a task, keeping it referenced until it is done.'''
        task = asyncio.create_task(coroutine)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def add_controller(self, controller):
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return the controller whose random (LE) address is `address`, or None.'''
        for candidate in self.controllers:
            if candidate.random_address == address:
                return candidate
        return None

    def find_classic_controller(
        self, address: Address
    ) -> Optional[controller.Controller]:
        '''Return the controller whose public address is `address`, or None.'''
        for candidate in self.controllers:
            if candidate.public_address == address:
                return candidate
        return None

    def get_pending_connection(self):
        return self.pending_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        pass

    def send_advertising_data(self, sender_address, data):
        # Send the advertising data to all controllers, except the sender
        for receiver in self.controllers:
            if receiver.random_address != sender_address:
                receiver.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL data to the first controller matching `destination_address`.

        LE traffic is routed by random address, BR/EDR traffic by public
        address. Data for an unknown transport is dropped with a warning.
        '''
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Previously this fell through with unbound locals
            # (UnboundLocalError).
            logger.warning(f'!!! Unsupported transport {transport}, ACL data dropped')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        # Complete the connection on the next loop iteration, not re-entrantly.
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        args = [central_address, peripheral_address, disconnect_command]
        asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    def create_cis(
        self,
        central_controller: controller.Controller,
        peripheral_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        logger.debug(
            f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
        )
        if peripheral_controller := self.find_controller(peripheral_address):
            asyncio.get_running_loop().call_soon(
                peripheral_controller.on_link_cis_request,
                central_controller.random_address,
                cig_id,
                cis_id,
            )

    def accept_cis(
        self,
        peripheral_controller: controller.Controller,
        central_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        logger.debug(
            f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
        )
        if central_controller := self.find_controller(central_address):
            asyncio.get_running_loop().call_soon(
                central_controller.on_link_cis_established, cig_id, cis_id
            )
            asyncio.get_running_loop().call_soon(
                peripheral_controller.on_link_cis_established, cig_id, cis_id
            )

    def disconnect_cis(
        self,
        initiator_controller: controller.Controller,
        peer_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        logger.debug(
            f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
        )
        if peer_controller := self.find_controller(peer_address):
            asyncio.get_running_loop().call_soon(
                initiator_controller.on_link_cis_disconnected, cig_id, cis_id
            )
            asyncio.get_running_loop().call_soon(
                peer_controller.on_link_cis_disconnected, cig_id, cis_id
            )

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        '''
        Complete a Classic ACL connection accepted by `responder_controller`.

        The initiator is notified from a background task; the responder is
        notified synchronously. If the initiator is gone, the responder gets
        a failed completion that names the initiator as the peer.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # Report the *peer* address, as the success path does; the request
            # is no longer pending either way.
            self.pending_classic_connection = None
            responder_controller.on_classic_connection_complete(
                initiator_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return

        async def task():
            if responder_role != BT_PERIPHERAL_ROLE:
                # The initiator takes the opposite role of the responder.
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not (responder_role))
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        self._run_in_background(task())
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        '''
        Tear down a Classic connection on both ends.

        The initiator is always notified; the responder only if it is still
        attached to this link.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        async def task():
            initiator_controller.on_classic_disconnected(responder_address, reason)

        self._run_in_background(task())
        if responder_controller is None:
            # Previously an AttributeError on None.
            logger.warning(f'[Classic] responder {responder_address} not found')
            return
        responder_controller.on_classic_disconnected(
            initiator_controller.public_address, reason
        )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        async def task():
            initiator_controller.on_classic_role_change(
                responder_address, initiator_new_role
            )

        self._run_in_background(task())
        # The responder takes the opposite role of the initiator.
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not (initiator_new_role))
        )

    def classic_sco_connect(
        self,
        initiator_controller: controller.Controller,
        responder_address: Address,
        link_type: int,
    ):
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        # Initiator controller should handle it.
        assert responder_controller

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            link_type,
        )

    def classic_accept_sco_connection(
        self,
        responder_controller: controller.Controller,
        initiator_address: Address,
        link_type: int,
    ):
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # Report the *peer* (initiator) address, as the success path does.
            responder_controller.on_classic_sco_connection_complete(
                initiator_address,
                HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
                link_type,
            )
            return

        async def task():
            initiator_controller.on_classic_sco_connection_complete(
                responder_controller.public_address, HCI_SUCCESS, link_type
            )

        self._run_in_background(task())
        responder_controller.on_classic_sco_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS, link_type
        )
381
382
383# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay.

    Outgoing operations are queued on `execution_queue` and run one at a time
    by `run_executor_loop`. Incoming relay messages are dispatched by keyword
    to `on_<keyword>_received` / `on_<keyword>_message_received` handlers.
    '''

    def __init__(self, uri):
        self.controller = None
        self.uri = uri
        self.execution_queue = asyncio.Queue()
        # Resolved with the websocket once the relay connection is established.
        self.websocket = asyncio.get_running_loop().create_future()
        # Future for the in-flight RPC command result, None when idle.
        self.rpc_result = None
        # LE Create Connection command awaiting a 'connected' reply, or None.
        self.pending_connection = None
        # Classic connections are not relayed yet; initialized so that
        # get_pending_classic_connection() does not raise AttributeError.
        self.pending_classic_connection = None
        self.central_connections = set()  # Addresses that we have connected to
        self.peripheral_connections = set()  # Addresses that have connected to us

        # Connect and run asynchronously. Keep references to the tasks: asyncio
        # only holds weak references, so unreferenced tasks may be collected.
        self._connection_task = asyncio.create_task(self.run_connection())
        self._executor_task = asyncio.create_task(self.run_executor_loop())

    def add_controller(self, controller):
        if self.controller:
            raise ValueError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        if self.controller != controller:
            raise ValueError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        return self.pending_connection

    def get_pending_classic_connection(self):
        return self.pending_classic_connection

    async def wait_until_connected(self):
        await self.websocket

    def execute(self, async_function):
        # Queue the coroutine; it runs in order in run_executor_loop.
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        import websockets  # lazy import

        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        # pylint: disable-next=no-member
        websocket = await websockets.connect(self.uri)
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            # Relay messages look like '<keyword>' or '<keyword>:<payload>'.
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(payload[0] if payload else None)

    def close(self):
        if self.websocket.done():
            logger.debug('closing websocket')
            websocket = self.websocket.result()
            asyncio.create_task(websocket.close())

    async def on_result_received(self, result):
        if self.rpc_result is None:
            return
        if self.rpc_result.done():
            # A second result for the same command would otherwise raise
            # InvalidStateError and kill the receive loop.
            logger.warning(f'unexpected rpc result: {result}')
            return
        self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            # Wrap in Address like every other controller notification.
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        await self.on_left_received(target)

    async def on_message_received(self, message):
        # Targeted messages look like '<sender>/<keyword>[:<payload>]'.
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        try:
            # Same (address, transport, data) signature that LocalLink uses.
            # This relay only carries LE traffic.
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # The connection is no longer pending. Without this, every later
        # connect() would be rejected as 'connection already pending'.
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(
            f'connected to peripheral {le_create_connection_command.peer_address}'
        )
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        # Notify the controller. The payload ('reason=N') may be missing.
        params = parse_parameters(message or '')
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        self.peripheral_connections.discard(sender)

    async def on_encrypted_message_received(self, sender, _):
        # TODO parse params to get real args
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))

    async def send_rpc_command(self, command):
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()

        try:
            # Send the command and wait for the result
            await websocket.send(command)
            rpc_result = await self.rpc_result
        finally:
            # Always clear the slot, so a failed or cancelled command does not
            # trip the assertion above on the next call.
            self.rpc_result = None
        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )
639