• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
import asyncio
import os
import logging
import click
from prompt_toolkit.shortcuts import PromptSession

from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.smp import PairingDelegate, PairingConfig
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT, ProtocolError
from bumble.gatt import (
    GATT_DEVICE_NAME_CHARACTERISTIC,
    GATT_GENERIC_ACCESS_SERVICE,
    Service,
    Characteristic,
    CharacteristicValue,
)
from bumble.att import (
    ATT_Error,
    ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
    ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
43
44
45# -----------------------------------------------------------------------------
class Waiter:
    """Completion signal that keeps ``pair()`` running until pairing ends.

    ``pair()`` stores a fresh instance in ``Waiter.instance`` and awaits
    :meth:`wait_until_terminated`; the pairing success/failure handlers call
    :meth:`terminate`.
    """

    # The active waiter; assigned by pair() before the device is powered on.
    instance = None

    def __init__(self):
        # Requires a running event loop (it is created from inside pair()).
        self.done = asyncio.get_running_loop().create_future()

    def terminate(self):
        """Signal completion. Calling this more than once is a no-op.

        More than one pairing event may be delivered before ``pair()`` resumes
        (e.g. several connections while advertising with auto-restart), and
        setting a result on an already-completed future would raise
        ``asyncio.InvalidStateError`` inside the event handler.
        """
        if not self.done.done():
            self.done.set_result(None)

    async def wait_until_terminated(self):
        """Wait until :meth:`terminate` has been called; returns ``None``."""
        return await self.done
58
59# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
    """Pairing delegate that interacts with the user on the terminal.

    Args:
        mode: ``'le'`` or ``'classic'``; forwarded to ``get_peer_name`` when
            looking up the peer's name.
        connection: the connection being paired; wrapped in a ``Peer``.
        capability_string: one of the ``--io`` choices (case-insensitive),
            mapped to a ``PairingDelegate`` I/O capability constant.
        do_prompt: when true, ``accept`` asks the user before accepting.
    """

    def __init__(self, mode, connection, capability_string, do_prompt):
        capabilities = {
            'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
            'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
            'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
            'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
            'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
        }
        super().__init__(capabilities[capability_string.lower()])

        self.mode = mode
        self.peer = Peer(connection)
        self.peer_name = None  # resolved lazily by update_peer_name()
        self.do_prompt = do_prompt

    def print(self, message):
        """Print *message* highlighted in yellow."""
        highlighted = color(message, 'yellow')
        print(highlighted)

    async def prompt(self, message):
        """Ask *message* and return the user's answer, lower-cased and stripped."""
        # Pause briefly so pending log lines are printed before the prompt shows.
        await asyncio.sleep(1)

        answer = await PromptSession(message).prompt_async()
        return answer.lower().strip()

    async def update_peer_name(self):
        """Resolve ``self.peer_name`` once; subsequent calls keep the cached value."""
        if self.peer_name is not None:
            return

        if not self.peer:
            self.peer_name = '[?]'
            return

        name = await get_peer_name(self.peer, self.mode)
        self.peer_name = f'{name or ""} [{self.peer.connection.peer_address}]'

    async def accept(self):
        """Return whether to accept a pairing request (asks only if do_prompt)."""
        if not self.do_prompt:
            # Accept silently
            return True

        await self.update_peer_name()

        self.print('###-----------------------------------')
        self.print(f'### Pairing request from {self.peer_name}')
        self.print('###-----------------------------------')

        # Keep asking until the answer is exactly 'yes' or 'no'.
        answer = None
        while answer not in ('yes', 'no'):
            answer = await self.prompt('>>> Accept? ')
        return answer == 'yes'

    async def compare_numbers(self, number, digits):
        """Ask whether the peer shows *number* (zero-padded to *digits*)."""
        await self.update_peer_name()

        self.print('###-----------------------------------')
        self.print(f'### Pairing with {self.peer_name}')
        self.print('###-----------------------------------')

        question = f'>>> Does the other device display {number:0{digits}}? '
        answer = None
        while answer not in ('yes', 'no'):
            answer = await self.prompt(question)
        return answer == 'yes'

    async def get_number(self):
        """Ask the user for a numeric PIN, re-prompting until it parses as int."""
        await self.update_peer_name()

        while True:
            try:
                self.print('###-----------------------------------')
                self.print(f'### Pairing with {self.peer_name}')
                self.print('###-----------------------------------')
                pin = int(await self.prompt('>>> Enter PIN: '))
            except ValueError:
                continue
            return pin

    async def display_number(self, number, digits):
        """Show the PIN *number* (zero-padded to *digits*) for the user to enter."""
        await self.update_peer_name()

        self.print('###-----------------------------------')
        self.print(f'### Pairing with {self.peer_name}')
        self.print(f'### PIN: {number:0{digits}}')
        self.print('###-----------------------------------')
159
160
161# -----------------------------------------------------------------------------
async def get_peer_name(peer, mode):
    """Return the remote device's name, or ``None`` if it cannot be found.

    In ``'classic'`` mode the name comes from ``peer.request_name()``.
    Otherwise it is read from the Device Name characteristic of the peer's
    GATT Generic Access service.

    Args:
        peer: the ``Peer`` to query.
        mode: ``'le'`` or ``'classic'``.
    """
    if mode == 'classic':
        return await peer.request_name()

    # Try to get the peer name from GATT
    services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
    if not services:
        return None

    values = await peer.read_characteristics_by_uuid(
        GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
    )
    if not values:
        return None

    # The value is supplied by the remote device: a malformed UTF-8 name must
    # not raise and abort the pairing prompt that only wants to display it.
    return values[0].decode('utf-8', errors='replace')
178
179
180# -----------------------------------------------------------------------------
# Index 0: read handler has already answered with an authentication error.
# Index 1: write handler has already answered with an authentication error.
# Shared module-wide, so each error is returned once per process.
AUTHENTICATION_ERROR_RETURNED = [False, False]


def read_with_error(connection):
    """Read handler that nudges the peer into pairing.

    Raises an insufficient-encryption ATT error on an unencrypted link. On an
    encrypted link the first read fails with insufficient authentication and
    every later read returns ``bytes([1])``.
    """
    if not connection.is_encrypted:
        raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)

    already_returned = AUTHENTICATION_ERROR_RETURNED[0]
    AUTHENTICATION_ERROR_RETURNED[0] = True
    if not already_returned:
        raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
    return bytes([1])


def write_with_error(connection, _value):
    """Write handler mirroring ``read_with_error``; the written value is ignored.

    Raises an insufficient-encryption ATT error on an unencrypted link; on an
    encrypted link only the first write fails with insufficient authentication.
    """
    if not connection.is_encrypted:
        raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)

    if AUTHENTICATION_ERROR_RETURNED[1]:
        return

    AUTHENTICATION_ERROR_RETURNED[1] = True
    raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
202
203
204# -----------------------------------------------------------------------------
def on_connection(connection, request):
    """Subscribe to a new connection's pairing/encryption events.

    Args:
        connection: the newly established connection.
        request: when true, also call ``connection.request_pairing()``.
    """
    print(color(f'<<< Connection: {connection}', 'green'))

    # Listen for pairing events
    pairing_handlers = (
        ('pairing_start', on_pairing_start),
        ('pairing', on_pairing),
        ('pairing_failure', on_pairing_failure),
    )
    for event_name, handler in pairing_handlers:
        connection.on(event_name, handler)

    # Listen for encryption changes; the handler takes no arguments, so the
    # connection is captured by the closure.
    def report_encryption_change():
        on_connection_encryption_change(connection)

    connection.on('connection_encryption_change', report_encryption_change)

    if not request:
        return

    # Request pairing
    print(color('>>> Requesting pairing', 'green'))
    connection.request_pairing()
223
224
225# -----------------------------------------------------------------------------
def on_connection_encryption_change(connection):
    """Print a blue banner stating whether *connection* is now encrypted."""
    # Build the whole phrase so the negative case reads "not encrypted"
    # (the previous template produced "notencrypted").
    state = 'encrypted' if connection.is_encrypted else 'not encrypted'

    print(color('@@@-----------------------------------', 'blue'))
    print(color(f'@@@ Connection is {state}', 'blue'))
    print(color('@@@-----------------------------------', 'blue'))
235
236
237# -----------------------------------------------------------------------------
def on_pairing_start():
    """Announce, in a magenta banner, that a pairing procedure has begun."""
    for line in (
        '***-----------------------------------',
        '*** Pairing starting',
        '***-----------------------------------',
    ):
        print(color(line, 'magenta'))
242
243
244# -----------------------------------------------------------------------------
def on_pairing(keys):
    """Report a successful pairing, print the resulting keys, release the waiter."""
    rule = '***-----------------------------------'

    print(color(rule, 'cyan'))
    print(color('*** Paired!', 'cyan'))
    keys.print(prefix=color('*** ', 'cyan'))
    print(color(rule, 'cyan'))

    # Let pair() stop waiting.
    Waiter.instance.terminate()
251
252
253# -----------------------------------------------------------------------------
def on_pairing_failure(reason):
    """Report a failed pairing with the SMP error name and release the waiter."""
    rule = '***-----------------------------------'

    print(color(rule, 'red'))
    print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
    print(color(rule, 'red'))

    # Let pair() stop waiting.
    Waiter.instance.terminate()
259
260
261# -----------------------------------------------------------------------------
async def pair(
    mode,
    sc,
    mitm,
    bond,
    io,
    prompt,
    request,
    print_keys,
    keystore_file,
    device_config,
    hci_transport,
    address_or_name,
):
    """Open the HCI transport, set up a device, and run one pairing session.

    Args:
        mode: ``'le'`` or ``'classic'``.
        sc, mitm, bond: passed positionally to ``PairingConfig``.
        io: I/O capability name handed to ``Delegate``.
        prompt: whether ``Delegate`` asks before accepting a pairing request.
        request: when true, ``on_connection`` requests pairing on each new
            connection and this function waits for the outcome instead of
            initiating pairing itself.
        print_keys: print the keystore contents before pairing.
        keystore_file: optional JSON keystore file replacing the device's store.
        device_config: device configuration file.
        hci_transport: spec passed to ``open_transport_or_link``.
        address_or_name: peer to connect to; when ``None``, wait for an
            incoming connection instead.
    """
    Waiter.instance = Waiter()

    print('<<< connecting to HCI...')
    async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
        print('<<< connected')

        # Create a device to manage the host
        device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)

        # Set a custom keystore if specified on the command line
        if keystore_file:
            device.keystore = JsonKeyStore(namespace=None, filename=keystore_file)

        # Print the existing keys before pairing
        if print_keys and device.keystore:
            print(color('@@@-----------------------------------', 'blue'))
            print(color('@@@ Pairing Keys:', 'blue'))
            await device.keystore.print(prefix=color('@@@ ', 'blue'))
            print(color('@@@-----------------------------------', 'blue'))

        # Expose a GATT characteristic that can be used to trigger pairing by
        # responding with an authentication error when read
        if mode == 'le':
            device.add_service(
                Service(
                    '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
                    [
                        Characteristic(
                            '552957FB-CF1F-4A31-9535-E78847E1A714',
                            Characteristic.READ | Characteristic.WRITE,
                            Characteristic.READABLE | Characteristic.WRITEABLE,
                            CharacteristicValue(
                                read=read_with_error, write=write_with_error
                            ),
                        )
                    ],
                )
            )

        # Select LE or Classic
        if mode == 'classic':
            device.classic_enabled = True
            device.le_enabled = False

        # Get things going
        await device.power_on()

        # Set up a pairing config factory
        device.pairing_config_factory = lambda connection: PairingConfig(
            sc, mitm, bond, Delegate(mode, connection, io, prompt)
        )

        # Connect to a peer or wait for a connection
        device.on('connection', lambda connection: on_connection(connection, request))
        if address_or_name is not None:
            print(color(f'=== Connecting to {address_or_name}...', 'green'))
            # LE is disabled above in classic mode, so the link transport must
            # be chosen explicitly. NOTE(review): assumes Device.connect()
            # otherwise defaults to an LE link — confirm against bumble.
            connection = await device.connect(
                address_or_name,
                transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
            )

            if not request:
                # Initiate pairing ourselves and exit once it completes.
                try:
                    if mode == 'le':
                        await connection.pair()
                    else:
                        await connection.authenticate()
                except ProtocolError as error:
                    print(color(f'Pairing failed: {error}', 'red'))
                return
        elif mode == 'le':
            # Advertise so that peers can find us and connect
            await device.start_advertising(auto_restart=True)
        else:
            # LE is disabled in classic mode, so LE advertising cannot make us
            # reachable; become discoverable and connectable instead.
            await device.set_discoverable(True)
            await device.set_connectable(True)

        # Run until pairing succeeds or fails
        await Waiter.instance.wait_until_terminated()
350
351
352# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
    """Logging handler that writes each formatted record to stdout via print().

    Records are rendered as ``LEVEL:logger.name:message``.
    """

    def __init__(self):
        super().__init__()
        formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
        self.setFormatter(formatter)

    def emit(self, record):
        print(self.format(record))
361
362
363# -----------------------------------------------------------------------------
@click.command()
@click.option(
    '--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
)
@click.option(
    '--sc',
    type=bool,
    default=True,
    help='Use the Secure Connections protocol',
    show_default=True,
)
@click.option(
    '--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
    '--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
    '--io',
    type=click.Choice(
        ['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
    ),
    default='display+keyboard',
    show_default=True,
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
    '--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option(
    '--keystore-file',
    metavar='<filename>',
    help='File in which to store the pairing keys',
)
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
    mode,
    sc,
    mitm,
    bond,
    io,
    prompt,
    request,
    print_keys,
    keystore_file,
    device_config,
    hci_transport,
    address_or_name,
):
    """Command-line entry point: configure logging, then run ``pair``."""
    # Send every log record through LogHandler, at the level named by the
    # BUMBLE_LOGLEVEL environment variable (INFO when unset).
    root_logger = logging.getLogger()
    root_logger.addHandler(LogHandler())
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO')
    root_logger.setLevel(level_name.upper())

    # Pair
    session = pair(
        mode=mode,
        sc=sc,
        mitm=mitm,
        bond=bond,
        io=io,
        prompt=prompt,
        request=request,
        print_keys=print_keys,
        keystore_file=keystore_file,
        device_config=device_config,
        hci_transport=hci_transport,
        address_or_name=address_or_name,
    )
    asyncio.run(session)


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    main()
444