# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pandora `Security` and `SecurityStorage` gRPC services backed by a Bumble device."""

import asyncio
import grpc
import logging

from . import utils
from bumble import hci
from bumble.core import BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT, BT_PERIPHERAL_ROLE, ProtocolError
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.smp import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2, empty_pb2, wrappers_pb2  # pytype: disable=pyi-error
from google.protobuf.wrappers_pb2 import BoolValue  # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
    LE_LEVEL1,
    LE_LEVEL2,
    LE_LEVEL3,
    LE_LEVEL4,
    LEVEL0,
    LEVEL1,
    LEVEL2,
    LEVEL3,
    LEVEL4,
    DeleteBondRequest,
    IsBondedRequest,
    LESecurityLevel,
    PairingEvent,
    PairingEventAnswer,
    SecureRequest,
    SecureResponse,
    SecurityLevel,
    WaitSecurityRequest,
    WaitSecurityResponse,
)
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union, cast


class PairingDelegate(BasePairingDelegate):
    """Bumble pairing delegate that forwards pairing prompts to the `OnPairing` stream.

    Each prompt is pushed as a `PairingEvent` onto `service.event_queue`; when an
    answer is required, the next `PairingEventAnswer` is read from `service.event_answer`.
    """

    def __init__(
        self,
        connection: BumbleConnection,
        service: "SecurityService",
        io_capability: int = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
        local_initiator_key_distribution: int = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
        local_responder_key_distribution: int = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
    ) -> None:
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'Security', 'device': connection.device}
        )
        self.connection = connection
        self.service = service
        super().__init__(io_capability, local_initiator_key_distribution, local_responder_key_distribution)

    async def accept(self) -> bool:
        """Accept every pairing request."""
        return True

    def add_origin(self, ev: PairingEvent) -> PairingEvent:
        """Tag `ev` with the connection (or, if incomplete, the peer address) it comes from.

        Returns:
            The same `ev` object, modified in place.
        """
        if not self.connection.is_incomplete:
            assert ev.connection
            ev.connection.CopyFrom(Connection(cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))))
        else:
            # In BR/EDR, connection may not be complete,
            # use address instead
            assert self.connection.transport == BT_BR_EDR_TRANSPORT
            # NOTE(review): the Bumble address bytes are reversed, presumably to match the
            # byte order expected by Pandora clients — confirm.
            ev.address = bytes(reversed(bytes(self.connection.peer_address)))

        return ev

    async def _ask(self, event: PairingEvent) -> PairingEventAnswer:
        """Publish `event` (tagged with its origin) and return the client's next answer.

        Callers must have checked that `event_queue` and `event_answer` are set.
        """
        assert self.service.event_queue is not None and self.service.event_answer is not None
        event = self.add_origin(event)
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # pytype: disable=name-error
        # The answer must echo the event it responds to.
        assert answer.event == event
        return answer

    async def confirm(self) -> bool:
        """Ask the client to confirm a `just_works` pairing; accept if no stream is open."""
        self.log.info(f"Pairing event: `just_works` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            return True

        answer = await self._ask(PairingEvent(just_works=empty_pb2.Empty()))
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def compare_numbers(self, number: int, digits: int = 6) -> bool:
        """Ask the client whether `number` matches the peer's value (numeric comparison).

        Raises:
            RuntimeError: no `OnPairing` stream is open.
        """
        self.log.info(f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number comparison request')

        answer = await self._ask(PairingEvent(numeric_comparison=number))
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def get_number(self) -> Optional[int]:
        """Ask the client for the passkey to enter.

        Raises:
            RuntimeError: no `OnPairing` stream is open.
        """
        self.log.info(f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number request')

        answer = await self._ask(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
        assert answer.answer_variant() == 'passkey'
        return answer.passkey

    async def get_string(self, max_length: int) -> Optional[str]:
        """Ask the client for a PIN code.

        Args:
            max_length: maximum PIN length, in UTF-8 encoded bytes.

        Returns:
            The decoded PIN, or None if the answer carries no PIN.

        Raises:
            RuntimeError: no `OnPairing` stream is open.
            ValueError: the PIN is empty, longer than `max_length` bytes, or not valid UTF-8.
        """
        self.log.info(f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})")

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled pin_code request')

        answer = await self._ask(PairingEvent(pin_code_request=empty_pb2.Empty()))
        assert answer.answer_variant() == 'pin'

        if answer.pin is None:
            return None

        # The limit is in bytes: check the encoded length, not the decoded character count.
        if not answer.pin or len(answer.pin) > max_length:
            raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')

        # A UnicodeDecodeError (a ValueError) is raised for invalid UTF-8.
        return answer.pin.decode('utf-8')

    async def display_number(self, number: int, digits: int = 6) -> None:
        """Notify the client of a passkey to display; no answer is awaited.

        Raises:
            RuntimeError: no `OnPairing` stream is open.
        """
        self.log.info(f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})")

        if self.service.event_queue is None:
            raise RuntimeError('security: unhandled number display request')

        event = self.add_origin(PairingEvent(passkey_entry_notification=number))
        self.service.event_queue.put_nowait(event)


# Predicates telling whether a BR/EDR connection currently satisfies each `SecurityLevel`.
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
    LEVEL0: lambda connection: True,
    LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
    LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
    LEVEL3: lambda connection: connection.encryption != 0
    and connection.authenticated
    and connection.link_key_type
    in (
        hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
        hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
    ),
    LEVEL4: lambda connection: connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
    and connection.authenticated
    and connection.link_key_type == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}

# Predicates telling whether an LE connection currently satisfies each `LESecurityLevel`.
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
    LE_LEVEL1: lambda connection: True,
    LE_LEVEL2: lambda connection: connection.encryption != 0,
    LE_LEVEL3: lambda connection: connection.encryption != 0 and connection.authenticated,
    LE_LEVEL4: lambda connection: connection.encryption != 0 and connection.authenticated and connection.sc,
}


class SecurityService(SecurityServicer):
    """Pandora `Security` service backed by a Bumble device.

    Installs a pairing-config factory on the device so that every pairing uses
    `sc=True, mitm=True, bonding=True` and a `PairingDelegate` bound to this service.
    """

    def __init__(self, device: Device, io_capability: int) -> None:
        self.log = utils.BumbleServerLoggerAdapter(logging.getLogger(), {'service_name': 'Security', 'device': device})
        # Set only while an `OnPairing` stream is open.
        self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
        self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
        self.device = device

        def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
            return PairingConfig(
                sc=True,
                mitm=True,
                bonding=True,
                delegate=PairingDelegate(
                    connection, self, io_capability=cast(int, getattr(self.device, 'io_capability'))
                ),
            )

        # The IO capability is stored on the device and read by the factory at each pairing.
        # NOTE(review): presumably so other code can change it at runtime — confirm.
        setattr(device, 'io_capability', io_capability)
        self.device.pairing_config_factory = pairing_config_factory

    @utils.rpc
    async def OnPairing(
        self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
    ) -> AsyncGenerator[PairingEvent, None]:
        """Stream pairing events to the client and receive its answers through `request`.

        Raises:
            RuntimeError: a stream is already open, or connections already exist.
        """
        self.log.info('OnPairing')

        if self.event_queue is not None:
            raise RuntimeError('already streaming pairing events')

        if len(self.device.connections):
            raise RuntimeError('the `OnPairing` method shall be initiated before establishing any connections.')

        self.event_queue = asyncio.Queue()
        self.event_answer = request

        try:
            while event := await self.event_queue.get():
                yield event

        finally:
            self.event_queue = None
            self.event_answer = None

    async def _request_pairing_as_peripheral(self, connection: BumbleConnection) -> None:
        """Send a pairing request from an LE peripheral and wait for the pairing outcome.

        Raises:
            ProtocolError: pairing failed; a non-exception failure reason is wrapped.
            asyncio.CancelledError: the connection was lost before pairing completed.
        """
        wait_for_pairing: asyncio.Future[None] = asyncio.get_running_loop().create_future()

        def on_pairing(*_: Any) -> None:
            if not wait_for_pairing.done():
                wait_for_pairing.set_result(None)

        def on_pairing_failure(reason: Any = None, *_: Any) -> None:
            if wait_for_pairing.done():
                return
            # `set_exception` requires an exception instance; the failure reason may be a plain code.
            if isinstance(reason, BaseException):
                wait_for_pairing.set_exception(reason)
            else:
                wait_for_pairing.set_exception(ProtocolError(reason, 'smp'))

        def on_disconnection(*_: Any) -> None:
            # No-op if the future is already done.
            wait_for_pairing.cancel()

        listeners: Dict[str, Callable[..., None]] = {
            'pairing': on_pairing,
            'pairing_failure': on_pairing_failure,
            'disconnection': on_disconnection,
        }
        for event, listener in listeners.items():
            connection.on(event, listener)

        try:
            connection.request_pairing()
            await wait_for_pairing
        finally:
            for event, listener in listeners.items():
                connection.remove_listener(event, listener)  # type: ignore

    @utils.rpc
    async def Secure(self, request: SecureRequest, context: grpc.ServicerContext) -> SecureResponse:
        """Raise the security of a connection to the requested level.

        Pairing, authentication and encryption are triggered as needed. The response
        reports `success`, `not_reached`, `connection_died`, or the step that failed.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.info(f"Secure: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        oneof = request.WhichOneof('level')
        level = getattr(request, oneof)
        # The requested level family must match the connection transport.
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[connection.transport] == oneof

        # security level already reached
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())

        # trigger pairing if needed
        if self.need_pairing(connection, level):
            try:
                self.log.info('Pair...')

                if connection.transport == BT_LE_TRANSPORT and connection.role == BT_PERIPHERAL_ROLE:
                    await self._request_pairing_as_peripheral(connection)
                else:
                    await connection.pair()

                self.log.info('Paired')
            except asyncio.CancelledError:
                self.log.warning("Connection died during pairing")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Pairing failure: {e}")
                return SecureResponse(pairing_failure=empty_pb2.Empty())

        # trigger authentication if needed
        if self.need_authentication(connection, level):
            try:
                self.log.info('Authenticate...')
                await connection.authenticate()
                self.log.info('Authenticated')
            except asyncio.CancelledError:
                self.log.warning("Connection died during authentication")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Authentication failure: {e}")
                return SecureResponse(authentication_failure=empty_pb2.Empty())

        # trigger encryption if needed
        if self.need_encryption(connection, level):
            try:
                self.log.info('Encrypt...')
                await connection.encrypt()
                self.log.info('Encrypted')
            except asyncio.CancelledError:
                self.log.warning("Connection died during encryption")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Encryption failure: {e}")
                return SecureResponse(encryption_failure=empty_pb2.Empty())

        # security level has been reached ?
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())
        return SecureResponse(not_reached=empty_pb2.Empty())

    @utils.rpc
    async def WaitSecurity(self, request: WaitSecurityRequest, context: grpc.ServicerContext) -> WaitSecurityResponse:
        """Wait until a connection reaches the requested security level, or a failure occurs.

        On BR/EDR, when an encryption change does not reach the level and authentication is
        needed, authentication is started in the background and awaited before returning.
        The response field set is `success` or the name of the observed failure.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.info(f"WaitSecurity: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        assert request.level
        level = request.level
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[connection.transport] == request.level_variant()

        wait_for_security: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        authenticate_task: Optional[asyncio.Future[None]] = None

        async def authenticate() -> None:
            """Authenticate the link, disabling encryption first and restoring it afterwards."""
            assert connection
            if (encryption := connection.encryption) != 0:
                self.log.debug('Disable encryption...')
                try:
                    await connection.encrypt(enable=False)
                except Exception as e:
                    # Best effort: authentication is attempted even if this step fails.
                    self.log.debug(f'Disable encryption failed: {e}')
                self.log.debug('Disable encryption: done')

            self.log.debug('Authenticate...')
            await connection.authenticate()
            self.log.debug('Authenticate: done')

            if encryption != 0 and connection.encryption != encryption:
                self.log.debug('Re-enable encryption...')
                await connection.encrypt()
                self.log.debug('Re-enable encryption: done')

        def resolve(result: str) -> None:
            # Several events may fire before the listeners are removed; only the first one
            # counts (a second `set_result` would raise `InvalidStateError` in the emitter).
            if not wait_for_security.done():
                wait_for_security.set_result(result)

        def set_failure(name: str) -> Callable[..., None]:
            def wrapper(*args: Any) -> None:
                self.log.info(f'Wait for security: error `{name}`: {args}')
                resolve(name)

            return wrapper

        def try_set_success(*_: Any) -> None:
            assert connection
            if self.reached_security_level(connection, level):
                self.log.info('Wait for security: done')
                resolve('success')

        def on_encryption_change(*_: Any) -> None:
            nonlocal authenticate_task
            assert connection
            if self.reached_security_level(connection, level):
                self.log.info('Wait for security: done')
                resolve('success')
            elif connection.transport == BT_BR_EDR_TRANSPORT and self.need_authentication(connection, level):
                if authenticate_task is None:
                    authenticate_task = asyncio.create_task(authenticate())

        listeners: Dict[str, Callable[..., None]] = {
            'disconnection': set_failure('connection_died'),
            'pairing_failure': set_failure('pairing_failure'),
            'connection_authentication_failure': set_failure('authentication_failure'),
            'connection_encryption_failure': set_failure('encryption_failure'),
            'pairing': try_set_success,
            'connection_authentication': try_set_success,
            'connection_encryption_change': on_encryption_change,
        }

        # register event handlers
        for event, listener in listeners.items():
            connection.on(event, listener)

        try:
            # security level already reached
            if self.reached_security_level(connection, level):
                return WaitSecurityResponse(success=empty_pb2.Empty())

            self.log.info('Wait for security...')
            result = await wait_for_security
        finally:
            # remove event handlers on every exit path (early return, cancellation, error)
            for event, listener in listeners.items():
                connection.remove_listener(event, listener)  # type: ignore

        # wait for `authenticate` to finish if any
        if authenticate_task is not None:
            self.log.info('Wait for authentication...')
            # `asyncio.wait` neither raises the task's exception nor cancels the task when
            # this coroutine is cancelled; the outcome is inspected explicitly below.
            await asyncio.wait([authenticate_task])
            if authenticate_task.cancelled():
                self.log.warning('Authentication cancelled')
            elif (error := authenticate_task.exception()) is not None:
                self.log.warning(f'Authentication failed: {error}')
            else:
                self.log.info('Authenticated')

        return WaitSecurityResponse(**{result: empty_pb2.Empty()})

    def reached_security_level(
        self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
    ) -> bool:
        """Tell whether `connection` currently satisfies `level` (LE or BR/EDR table)."""
        self.log.debug(
            str(
                {
                    'level': level,
                    'encryption': connection.encryption,
                    'authenticated': connection.authenticated,
                    'sc': connection.sc,
                    'link_key_type': connection.link_key_type,
                }
            )
        )

        if isinstance(level, LESecurityLevel):
            return LE_LEVEL_REACHED[level](connection)

        return BR_LEVEL_REACHED[level](connection)

    def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
        """Tell whether `Secure` should pair: only on LE, for level 3+ when not authenticated."""
        if connection.transport == BT_LE_TRANSPORT:
            return level >= LE_LEVEL3 and not connection.authenticated
        return False

    def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
        """Tell whether `Secure` should authenticate: only on BR/EDR, for level 2+ when not authenticated."""
        if connection.transport == BT_LE_TRANSPORT:
            return False
        return level >= LEVEL2 and not connection.authenticated

    def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
        """Tell whether `Secure` should enable encryption on a not-yet-encrypted link."""
        # TODO(abel): need to support MITM
        if connection.transport == BT_LE_TRANSPORT:
            return level == LE_LEVEL2 and not connection.encryption
        return level >= LEVEL2 and not connection.encryption


class SecurityStorageService(SecurityStorageServicer):
    """Pandora `SecurityStorage` service backed by the Bumble device keystore."""

    def __init__(self, device: Device) -> None:
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
        )
        self.device = device

    @utils.rpc
    async def IsBonded(self, request: IsBondedRequest, context: grpc.ServicerContext) -> wrappers_pb2.BoolValue:
        """Tell whether the keystore holds an entry for the requested address (False without a keystore)."""
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.info(f"IsBonded: {address}")

        if self.device.keystore is not None:
            is_bonded = await self.device.keystore.get(str(address)) is not None
        else:
            is_bonded = False

        return BoolValue(value=is_bonded)

    @utils.rpc
    async def DeleteBond(self, request: DeleteBondRequest, context: grpc.ServicerContext) -> empty_pb2.Empty:
        """Delete the keystore entry for the requested address; a missing entry is ignored."""
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.info(f"DeleteBond: {address}")

        if self.device.keystore is not None:
            with suppress(KeyError):
                await self.device.keystore.delete(str(address))

        return empty_pb2.Empty()