• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import asyncio
16import grpc
17import logging
18
19from . import utils
20from bumble import hci
21from bumble.core import BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT, BT_PERIPHERAL_ROLE, ProtocolError
22from bumble.device import Connection as BumbleConnection, Device
23from bumble.hci import HCI_Error
24from bumble.smp import PairingConfig, PairingDelegate as BasePairingDelegate
25from contextlib import suppress
26from google.protobuf import any_pb2, empty_pb2, wrappers_pb2  # pytype: disable=pyi-error
27from google.protobuf.wrappers_pb2 import BoolValue  # pytype: disable=pyi-error
28from pandora.host_pb2 import Connection
29from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
30from pandora.security_pb2 import (
31    LE_LEVEL1,
32    LE_LEVEL2,
33    LE_LEVEL3,
34    LE_LEVEL4,
35    LEVEL0,
36    LEVEL1,
37    LEVEL2,
38    LEVEL3,
39    LEVEL4,
40    DeleteBondRequest,
41    IsBondedRequest,
42    LESecurityLevel,
43    PairingEvent,
44    PairingEventAnswer,
45    SecureRequest,
46    SecureResponse,
47    SecurityLevel,
48    WaitSecurityRequest,
49    WaitSecurityResponse,
50)
51from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union, cast
52
53
class PairingDelegate(BasePairingDelegate):
    """Bumble pairing delegate that forwards pairing prompts to the Pandora client.

    Each prompt is published as a `PairingEvent` on the owning `SecurityService`'s
    `OnPairing` stream. Except for passkey display, the delegate then waits for the
    client's `PairingEventAnswer`.
    """

    def __init__(
        self,
        connection: BumbleConnection,
        service: "SecurityService",
        io_capability: int = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
        local_initiator_key_distribution: int = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
        local_responder_key_distribution: int = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
    ) -> None:
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'Security', 'device': connection.device}
        )
        self.connection = connection
        self.service = service
        super().__init__(io_capability, local_initiator_key_distribution, local_responder_key_distribution)

    async def accept(self) -> bool:
        """Accept every pairing request without consulting the client."""
        return True

    def add_origin(self, ev: PairingEvent) -> PairingEvent:
        """Tag `ev` (in place) with the connection it originates from and return it.

        A complete connection is identified by a cookie holding its handle as 4
        big-endian bytes. An incomplete one, which is only expected on BR/EDR, is
        identified by the peer address with its bytes reversed.
        """
        if not self.connection.is_incomplete:
            assert ev.connection
            ev.connection.CopyFrom(Connection(cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))))
        else:
            # In BR/EDR, connection may not be complete,
            # use address instead
            assert self.connection.transport == BT_BR_EDR_TRANSPORT
            ev.address = bytes(reversed(bytes(self.connection.peer_address)))

        return ev

    async def _exchange(self, event: PairingEvent) -> PairingEventAnswer:
        """Publish `event` on the `OnPairing` stream and return the client's answer.

        Callers must already have checked that the stream is active (both
        `event_queue` and `event_answer` set).
        """
        assert self.service.event_queue is not None and self.service.event_answer is not None
        event = self.add_origin(event)
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # pytype: disable=name-error
        # The client is expected to answer events in the order they were sent.
        assert answer.event == event
        return answer

    async def confirm(self) -> bool:
        """Ask the client to confirm a `just_works` pairing.

        Returns True without asking when no `OnPairing` stream is active.
        """
        self.log.info(f"Pairing event: `just_works` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            return True

        answer = await self._exchange(PairingEvent(just_works=empty_pb2.Empty()))
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def compare_numbers(self, number: int, digits: int = 6) -> bool:
        """Ask the client whether `number` matches the one shown on the peer.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
        """
        self.log.info(f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number comparison request')

        answer = await self._exchange(PairingEvent(numeric_comparison=number))
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def get_number(self) -> Optional[int]:
        """Ask the client for the passkey to enter.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
        """
        self.log.info(f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number request')

        answer = await self._exchange(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
        assert answer.answer_variant() == 'passkey'
        return answer.passkey

    async def get_string(self, max_length: int) -> Optional[str]:
        """Ask the client for a PIN code of at most `max_length` UTF-8 bytes.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
            ValueError: the PIN is empty, longer than `max_length` bytes, or not valid UTF-8.
        """
        self.log.info(f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled pin_code request')

        answer = await self._exchange(PairingEvent(pin_code_request=empty_pb2.Empty()))
        assert answer.answer_variant() == 'pin'

        if answer.pin is None:
            return None

        # The limit is expressed in encoded bytes, so measure the raw bytes: a
        # character count under-counts multi-byte UTF-8 PINs.
        if not answer.pin or len(answer.pin) > max_length:
            raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')

        # UnicodeDecodeError is a ValueError subclass, matching the contract above.
        return answer.pin.decode('utf-8')

    async def display_number(self, number: int, digits: int = 6) -> None:
        """Notify the client of the passkey to display; no answer is awaited.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
        """
        self.log.info(f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})")

        if self.service.event_queue is None:
            raise RuntimeError('security: unhandled number display request')

        event = self.add_origin(PairingEvent(passkey_entry_notification=number))
        self.service.event_queue.put_nowait(event)
153
154
# The two authenticated combination link key types (P-192 and P-256 variants).
_BR_AUTHENTICATED_LINK_KEY_TYPES = (
    hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
    hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
)


def _br_encrypted_and_authenticated(connection: BumbleConnection) -> bool:
    """Truthy when the BR/EDR link is both encrypted and authenticated."""
    return connection.encryption != 0 and connection.authenticated


# Predicates telling whether a BR/EDR connection currently satisfies each `SecurityLevel`.
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
    # LEVEL0: always satisfied.
    LEVEL0: lambda connection: True,
    # LEVEL1: satisfied by an unencrypted link, or by an authenticated one.
    LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
    # LEVEL2: encrypted and authenticated.
    LEVEL2: _br_encrypted_and_authenticated,
    # LEVEL3: LEVEL2, plus the link key is an authenticated combination key.
    LEVEL3: lambda connection: (
        _br_encrypted_and_authenticated(connection)
        and connection.link_key_type in _BR_AUTHENTICATED_LINK_KEY_TYPES
    ),
    # LEVEL4: AES-CCM encryption, authenticated, with a P-256 authenticated combination key.
    LEVEL4: lambda connection: (
        connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
        and connection.authenticated
        and connection.link_key_type == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
    ),
}
170
def _le_encrypted_and_authenticated(connection: BumbleConnection) -> bool:
    """Truthy when the LE link is both encrypted and authenticated."""
    return connection.encryption != 0 and connection.authenticated


# Predicates telling whether an LE connection currently satisfies each `LESecurityLevel`.
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
    # LE_LEVEL1: always satisfied.
    LE_LEVEL1: lambda connection: True,
    # LE_LEVEL2: encrypted.
    LE_LEVEL2: lambda connection: connection.encryption != 0,
    # LE_LEVEL3: encrypted and authenticated.
    LE_LEVEL3: _le_encrypted_and_authenticated,
    # LE_LEVEL4: LE_LEVEL3, plus the connection's `sc` flag.
    LE_LEVEL4: lambda connection: _le_encrypted_and_authenticated(connection) and connection.sc,
}
177
178
class SecurityService(SecurityServicer):
    """Pandora `Security` service implemented on top of a Bumble `Device`.

    Pairing prompts raised by Bumble are forwarded to the client through the
    `OnPairing` stream. `Secure` actively raises the security of a connection,
    and `WaitSecurity` waits until a requested level is reached.
    """

    def __init__(self, device: Device, io_capability: int) -> None:
        self.log = utils.BumbleServerLoggerAdapter(logging.getLogger(), {'service_name': 'Security', 'device': device})
        # Set only while an `OnPairing` stream is active: events sent to the
        # client, and the stream of the client's answers.
        self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
        self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
        self.device = device

        def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
            # The IO capability is read from the device each time a config is
            # built, rather than captured from the constructor argument.
            return PairingConfig(
                sc=True,
                mitm=True,
                bonding=True,
                delegate=PairingDelegate(
                    connection, self, io_capability=cast(int, getattr(self.device, 'io_capability'))
                ),
            )

        setattr(device, 'io_capability', io_capability)
        self.device.pairing_config_factory = pairing_config_factory

    @utils.rpc
    async def OnPairing(
        self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
    ) -> AsyncGenerator[PairingEvent, None]:
        """Stream pairing events to the client and consume its answers from `request`.

        Raises:
            RuntimeError: a stream is already active, or the device already has connections.
        """
        self.log.info('OnPairing')

        if self.event_queue is not None:
            raise RuntimeError('already streaming pairing events')

        if self.device.connections:
            raise RuntimeError('the `OnPairing` method shall be initiated before establishing any connections.')

        self.event_queue = asyncio.Queue()
        self.event_answer = request

        try:
            while event := await self.event_queue.get():
                yield event

        finally:
            self.event_queue = None
            self.event_answer = None

    async def _request_pairing_as_peripheral(self, connection: BumbleConnection) -> bool:
        """Ask the central to start pairing and wait for the outcome.

        Returns True once `pairing` is emitted and False on `pairing_failure`.
        A `disconnection` cancels the wait, so the caller sees
        `asyncio.CancelledError`. The listeners are always removed before returning.
        """
        outcome: asyncio.Future[bool] = asyncio.get_running_loop().create_future()

        # Several of these events may fire; only the first one settles the future
        # (settling a done future would raise InvalidStateError inside the emitter).
        def on_pairing(*_: Any) -> None:
            if not outcome.done():
                outcome.set_result(True)

        def on_pairing_failure(*args: Any) -> None:
            self.log.warning(f"Pairing failure: {args}")
            if not outcome.done():
                outcome.set_result(False)

        def on_disconnection(*_: Any) -> None:
            if not outcome.done():
                outcome.cancel()

        listeners: Dict[str, Callable[..., None]] = {
            'pairing': on_pairing,
            'pairing_failure': on_pairing_failure,
            'disconnection': on_disconnection,
        }
        for event, listener in listeners.items():
            connection.on(event, listener)
        try:
            connection.request_pairing()
            return await outcome
        finally:
            for event, listener in listeners.items():
                connection.remove_listener(event, listener)  # type: ignore

    @utils.rpc
    async def Secure(self, request: SecureRequest, context: grpc.ServicerContext) -> SecureResponse:
        """Raise the connection's security to the requested level.

        Pairing, then authentication, then encryption are triggered as needed.
        The first failing step determines the error variant of the response.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.info(f"Secure: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        oneof = request.WhichOneof('level')
        level = getattr(request, oneof)
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[connection.transport] == oneof

        # security level already reached
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())

        # trigger pairing if needed
        if self.need_pairing(connection, level):
            try:
                self.log.info('Pair...')

                if connection.transport == BT_LE_TRANSPORT and connection.role == BT_PERIPHERAL_ROLE:
                    # A peripheral cannot drive pairing itself: request it and wait.
                    paired = await self._request_pairing_as_peripheral(connection)
                else:
                    await connection.pair()
                    paired = True
            except asyncio.CancelledError:
                self.log.warning("Connection died during pairing")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Pairing failure: {e}")
                return SecureResponse(pairing_failure=empty_pb2.Empty())

            if not paired:
                return SecureResponse(pairing_failure=empty_pb2.Empty())
            self.log.info('Paired')

        # trigger authentication if needed
        if self.need_authentication(connection, level):
            try:
                self.log.info('Authenticate...')
                await connection.authenticate()
                self.log.info('Authenticated')
            except asyncio.CancelledError:
                self.log.warning("Connection died during authentication")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Authentication failure: {e}")
                return SecureResponse(authentication_failure=empty_pb2.Empty())

        # trigger encryption if needed
        if self.need_encryption(connection, level):
            try:
                self.log.info('Encrypt...')
                await connection.encrypt()
                self.log.info('Encrypted')
            except asyncio.CancelledError:
                self.log.warning("Connection died during encryption")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Encryption failure: {e}")
                return SecureResponse(encryption_failure=empty_pb2.Empty())

        # security level has been reached ?
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())
        return SecureResponse(not_reached=empty_pb2.Empty())

    @utils.rpc
    async def WaitSecurity(self, request: WaitSecurityRequest, context: grpc.ServicerContext) -> WaitSecurityResponse:
        """Wait until the connection reaches the requested level, or a failure event occurs.

        On BR/EDR, an encryption change that leaves authentication missing starts
        a background task. That task disables encryption, authenticates, then
        re-enables encryption.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.info(f"WaitSecurity: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        assert request.level
        level = request.level
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[connection.transport] == request.level_variant()

        # Security level already reached: return before any listener is registered.
        if self.reached_security_level(connection, level):
            return WaitSecurityResponse(success=empty_pb2.Empty())

        wait_for_security: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        authenticate_task: Optional[asyncio.Task[None]] = None

        def resolve(result: str) -> None:
            # Several events may fire before this coroutine resumes (e.g. `pairing`
            # and `connection_encryption_change`); only the first one decides.
            if not wait_for_security.done():
                wait_for_security.set_result(result)

        async def authenticate() -> None:
            assert connection
            if (encryption := connection.encryption) != 0:
                self.log.debug('Disable encryption...')
                try:
                    await connection.encrypt(enable=False)
                except Exception as e:
                    # Best effort: authentication is attempted regardless.
                    self.log.debug(f'Disable encryption failed: {e}')
                self.log.debug('Disable encryption: done')

            self.log.debug('Authenticate...')
            await connection.authenticate()
            self.log.debug('Authenticate: done')

            if encryption != 0 and connection.encryption != encryption:
                self.log.debug('Re-enable encryption...')
                await connection.encrypt()
                self.log.debug('Re-enable encryption: done')

        def set_failure(name: str) -> Callable[..., None]:
            def wrapper(*args: Any) -> None:
                self.log.info(f'Wait for security: error `{name}`: {args}')
                resolve(name)

            return wrapper

        def try_set_success(*_: Any) -> None:
            assert connection
            if self.reached_security_level(connection, level):
                self.log.info('Wait for security: done')
                resolve('success')

        def on_encryption_change(*_: Any) -> None:
            nonlocal authenticate_task
            assert connection
            if self.reached_security_level(connection, level):
                self.log.info('Wait for security: done')
                resolve('success')
            elif connection.transport == BT_BR_EDR_TRANSPORT and self.need_authentication(connection, level):
                if authenticate_task is None:
                    authenticate_task = asyncio.create_task(authenticate())

        listeners: Dict[str, Callable[..., None]] = {
            'disconnection': set_failure('connection_died'),
            'pairing_failure': set_failure('pairing_failure'),
            'connection_authentication_failure': set_failure('authentication_failure'),
            'connection_encryption_failure': set_failure('encryption_failure'),
            'pairing': try_set_success,
            'connection_authentication': try_set_success,
            'connection_encryption_change': on_encryption_change,
        }

        # Register event handlers, and always remove them, even if this RPC is cancelled.
        for event, listener in listeners.items():
            connection.on(event, listener)
        try:
            self.log.info('Wait for security...')
            result = await wait_for_security
        finally:
            for event, listener in listeners.items():
                connection.remove_listener(event, listener)  # type: ignore

        # Wait for `authenticate` to finish if any. asyncio.wait does not re-raise
        # the task's own failure or cancellation, so the response is still returned.
        if authenticate_task is not None:
            self.log.info('Wait for authentication...')
            await asyncio.wait([authenticate_task])
            if not authenticate_task.cancelled() and (error := authenticate_task.exception()) is not None:
                self.log.warning(f'Authentication task failed: {error}')
            self.log.info('Authenticated')

        return WaitSecurityResponse(**{result: empty_pb2.Empty()})

    def reached_security_level(
        self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
    ) -> bool:
        """Return whether `connection` currently satisfies `level`.

        The level table is selected from the connection transport. A plain enum
        value is an int and does not reliably identify which of the two level
        enums it came from. Callers assert that the level variant matches the
        transport.
        """
        self.log.debug(
            str(
                {
                    'level': level,
                    'encryption': connection.encryption,
                    'authenticated': connection.authenticated,
                    'sc': connection.sc,
                    'link_key_type': connection.link_key_type,
                }
            )
        )

        if connection.transport == BT_LE_TRANSPORT:
            return LE_LEVEL_REACHED[level](connection)  # type: ignore[index]

        return BR_LEVEL_REACHED[level](connection)  # type: ignore[index]

    def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
        """Pairing is only triggered explicitly on LE, for LE_LEVEL3+ on an unauthenticated link."""
        if connection.transport == BT_LE_TRANSPORT:
            return level >= LE_LEVEL3 and not connection.authenticated
        return False

    def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
        """Authentication is only triggered explicitly on BR/EDR, for LEVEL2+ on an unauthenticated link."""
        if connection.transport == BT_LE_TRANSPORT:
            return False
        return level >= LEVEL2 and not connection.authenticated

    def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
        """Return whether encryption must be enabled to progress toward `level`."""
        # TODO(abel): need to support MITM
        if connection.transport == BT_LE_TRANSPORT:
            return level == LE_LEVEL2 and not connection.encryption
        return level >= LEVEL2 and not connection.encryption
424
425
class SecurityStorageService(SecurityStorageServicer):
    """Pandora `SecurityStorage` service backed by the Bumble device keystore."""

    def __init__(self, device: Device) -> None:
        self.device = device
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
        )

    @utils.rpc
    async def IsBonded(self, request: IsBondedRequest, context: grpc.ServicerContext) -> wrappers_pb2.BoolValue:
        """Report whether the keystore holds keys for the requested address.

        Always False when the device has no keystore.
        """
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.info(f"IsBonded: {address}")

        keystore = self.device.keystore
        if keystore is None:
            return BoolValue(value=False)

        stored_keys = await keystore.get(str(address))
        return BoolValue(value=stored_keys is not None)

    @utils.rpc
    async def DeleteBond(self, request: DeleteBondRequest, context: grpc.ServicerContext) -> empty_pb2.Empty:
        """Remove the bond for the requested address from the keystore, if any."""
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.info(f"DeleteBond: {address}")

        keystore = self.device.keystore
        if keystore is not None:
            try:
                await keystore.delete(str(address))
            except KeyError:
                # Deleting an address that has no stored bond is not an error.
                pass

        return empty_pb2.Empty()
455