# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Keys and Key Storage
#
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import os
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self

from .colors import color
from .hci import Address

if TYPE_CHECKING:
    from .device import Device


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
class PairingKeys:
    """
    Set of pairing keys, convertible to and from a JSON-friendly dictionary.

    Each key attribute (`ltk`, `ltk_central`, `ltk_peripheral`, `irk`, `csrk`,
    `link_key`) is either `None` or a `PairingKeys.Key`. `address_type` is
    either `None` or whatever value was stored under `'address_type'`.
    """

    class Key:
        """
        A single key value and its attributes.

        `value` and `rand` are `bytes` objects in memory and hex strings in the
        dictionary form; `authenticated` and `ediv` are stored as-is.
        """

        def __init__(self, value, authenticated=False, ediv=None, rand=None):
            self.value = value
            self.authenticated = authenticated
            self.ediv = ediv
            self.rand = rand

        @classmethod
        def from_dict(cls, key_dict):
            """
            Build a key from its dictionary form.

            `key_dict['value']` is required (hex string); `'authenticated'`
            defaults to `False`; `'ediv'` and `'rand'` (hex string) are
            optional.
            """
            rand_hex = key_dict.get('rand')
            return cls(
                bytes.fromhex(key_dict['value']),
                key_dict.get('authenticated', False),
                key_dict.get('ediv'),
                None if rand_hex is None else bytes.fromhex(rand_hex),
            )

        def to_dict(self):
            """
            Return the dictionary form of this key.

            `'ediv'` and `'rand'` are only present when they are not `None`.
            """
            serialized = {
                'value': self.value.hex(),
                'authenticated': self.authenticated,
            }
            optional_fields = (
                ('ediv', self.ediv),
                ('rand', None if self.rand is None else self.rand.hex()),
            )
            serialized.update(
                (field_name, field_value)
                for field_name, field_value in optional_fields
                if field_value is not None
            )
            return serialized

    # Names of the attributes that hold a `Key`, in serialization order.
    _KEY_NAMES = (
        'ltk',
        'ltk_central',
        'ltk_peripheral',
        'irk',
        'csrk',
        'link_key',
    )

    def __init__(self):
        self.address_type = None
        self.ltk = None
        self.ltk_central = None
        self.ltk_peripheral = None
        self.irk = None
        self.csrk = None
        self.link_key = None  # Classic

    @staticmethod
    def key_from_dict(keys_dict, key_name):
        """Return the `Key` stored under `key_name`, or `None` if absent."""
        key_dict = keys_dict.get(key_name)
        return None if key_dict is None else PairingKeys.Key.from_dict(key_dict)

    @staticmethod
    def from_dict(keys_dict):
        """Build a `PairingKeys` object from its dictionary form."""
        keys = PairingKeys()
        keys.address_type = keys_dict.get('address_type')
        for key_name in PairingKeys._KEY_NAMES:
            setattr(keys, key_name, PairingKeys.key_from_dict(keys_dict, key_name))
        return keys

    def to_dict(self):
        """
        Return the dictionary form of this key set.

        Only the attributes that are not `None` are included.
        """
        serialized = {}
        if self.address_type is not None:
            serialized['address_type'] = self.address_type
        for key_name in self._KEY_NAMES:
            key = getattr(self, key_name)
            if key is not None:
                serialized[key_name] = key.to_dict()
        return serialized

    def print(self, prefix=''):
        """
        Print the key set to the standard output, one property per line.

        Each line starts with `prefix`; the properties of a key are printed
        under the name of that key.
        """
        for container_property, value in self.to_dict().items():
            if not isinstance(value, dict):
                # Plain property: name and value on a single line
                print(f'{prefix}{color(container_property, "cyan")}: {value}')
                continue

            # Key: name on its own line, followed by the key's properties
            print(f'{prefix}{color(container_property, "cyan")}:')
            for key_property, key_value in value.items():
                print(f'{prefix} {color(key_property, "green")}: {key_value}')


# -----------------------------------------------------------------------------
class KeyStore:
    """
    Base class for pairing key stores.

    A key store maps a name to a `PairingKeys` object. This base
    implementation stores nothing: `get` returns `None`, `get_all` returns an
    empty list, and `update` / `delete` do nothing. Subclasses override those
    methods to provide actual storage.
    """

    async def delete(self, name: str) -> None:
        """Delete the keys stored under `name` (does nothing in this class)."""

    async def update(self, name: str, keys: PairingKeys) -> None:
        """Store `keys` under `name` (does nothing in this class)."""

    async def get(self, _name: str) -> Optional[PairingKeys]:
        """Return the keys stored under a name (always `None` in this class)."""
        return None

    async def get_all(self) -> List[Tuple[str, PairingKeys]]:
        """Return all the `(name, keys)` entries (always empty in this class)."""
        return []

    async def delete_all(self) -> None:
        """Delete every entry returned by `get_all`, one `delete` call each."""
        all_keys = await self.get_all()
        await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))

    async def get_resolving_keys(self):
        """
        Return a list of `(irk_value, address)` tuples, one for each entry
        that has an IRK.

        The address is built from the entry name and the stored address type,
        using `Address.RANDOM_DEVICE_ADDRESS` when no address type is stored.
        """
        resolving_keys = []
        for name, keys in await self.get_all():
            if keys.irk is None:
                continue
            address_type = (
                Address.RANDOM_DEVICE_ADDRESS
                if keys.address_type is None
                else keys.address_type
            )
            resolving_keys.append((keys.irk.value, Address(name, address_type)))
        return resolving_keys

    async def print(self, prefix=''):
        """Print all the entries, separated by an empty line."""
        entries = await self.get_all()
        separator = ''
        for name, keys in entries:
            print(separator + prefix + color(name, 'yellow'))
            keys.print(prefix=prefix + ' ')
            separator = '\n'

    @staticmethod
    def create_for_device(device: Device) -> KeyStore:
        """
        Create the key store selected by `device.config.keystore`.

        The configuration value has the form `<type>[:<parameters>]`. Only the
        `JsonKeyStore` type is recognized; in all other cases (including no
        configuration value) a `MemoryKeyStore` is returned.
        """
        if device.config.keystore is None:
            return MemoryKeyStore()

        keystore_type = device.config.keystore.split(':', 1)[0]
        if keystore_type == 'JsonKeyStore':
            return JsonKeyStore.from_device(device)

        return MemoryKeyStore()


# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
    """
    KeyStore implementation that is backed by a JSON file.

    This implementation supports storing a hierarchy of key sets in a single
    file.
    A key set is a representation of a PairingKeys object. Each key set is
    stored in a map, with the address of paired peer as the key. Maps are
    themselves grouped into namespaces, grouping pairing keys by controller
    addresses.
    The JSON object model looks like:
    {
        "<namespace>": {
            "peer-address": {
                "address_type": <n>,
                "irk" : {
                    "authenticated": <true/false>,
                    "value": "hex-encoded-key"
                },
                ... other keys ...
            },
            ... other peers ...
        }
        ... other namespaces ...
    }

    A namespace is typically the BD_ADDR of a controller, since that is a
    convenient unique identifier, but it may be something else.
    A special namespace, called the "default" namespace, is used when
    instantiating this class without a namespace. With the default namespace,
    reading from a file will load an existing namespace if there is only one,
    which may be convenient for reading from a file with a single key set and
    for which the namespace isn't known. If the file does not include any
    existing key set, or if there are more than one and none has the default
    name, a new one will be created with the name "__DEFAULT__".
    """

    APP_NAME = 'Bumble'
    APP_AUTHOR = 'Google'
    KEYS_DIR = 'Pairing'
    DEFAULT_NAMESPACE = '__DEFAULT__'
    DEFAULT_BASE_NAME = "keys"

    def __init__(self, namespace, filename=None):
        """
        Args:
            namespace: name of the namespace to use in the file, or `None` to
                use the default namespace.
            filename: path of the JSON file. When `None`, a file in the
                per-user application data directory is used, named after the
                namespace (or `keys.json` when `namespace` is `None`).
        """
        self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE

        if filename is None:
            # Use a default for the current user

            # Import here because this may not exist on all platforms
            # pylint: disable=import-outside-toplevel
            import appdirs

            self.directory_name = os.path.join(
                appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
            )
            base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
            # Derive a file name from the base name: lowercase, with ':' and
            # the '/' of a '/p' suffix replaced by '-'
            json_filename = (
                f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
            )
            self.filename = os.path.join(self.directory_name, json_filename)
        else:
            self.filename = filename
            self.directory_name = os.path.dirname(os.path.abspath(self.filename))

        logger.debug(f'JSON keystore: {self.filename}')

    @classmethod
    def from_device(
        cls: Type[Self], device: Device, filename: Optional[str] = None
    ) -> Self:
        """
        Create a key store for a device.

        When `filename` is not specified, it is taken from the parameter part
        of `device.config.keystore` (the text after the first ':'), if any.
        The namespace is the device's public address if it is set, else its
        random address if it is set, else the default namespace.
        """
        if not filename:
            # Extract the filename from the config if there is one
            if device.config.keystore is not None:
                params = device.config.keystore.split(':', 1)[1:]
                if params:
                    filename = params[0]

        # Use a namespace based on the device address
        if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
            namespace = str(device.public_address)
        elif device.random_address != Address.ANY_RANDOM:
            namespace = str(device.random_address)
        else:
            namespace = JsonKeyStore.DEFAULT_NAMESPACE

        return cls(namespace, filename)

    async def load(self):
        """
        Load the file and select the key map for this store's namespace.

        Returns:
            A `(db, key_map)` tuple, where `db` is the complete content of the
            file (a dict of namespaces) and `key_map` is the dict of `db` that
            holds the entries of the selected namespace. `key_map` is always
            an item of `db`, so changes made to it are written by `save(db)`.
        """
        # Try to open the file, without failing. If the file does not exist, it
        # will be created upon saving.
        try:
            with open(self.filename, 'r', encoding='utf-8') as json_file:
                db = json.load(json_file)
        except FileNotFoundError:
            db = {}

        # First, look for a namespace match
        if self.namespace in db:
            return (db, db[self.namespace])

        # Then, if the namespace is the default namespace, and there's
        # only one entry in the db, use that.
        # NOTE: the first item returned must be the db itself, not the name of
        # the namespace, because callers pass it back to `save`.
        if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
            return (db, next(iter(db.values())))

        # Finally, just create an empty key map for the namespace
        key_map = {}
        db[self.namespace] = key_map
        return (db, key_map)

    async def save(self, db):
        """Write `db` to the file, replacing the previous content atomically."""
        # Create the directory if it doesn't exist
        if not os.path.exists(self.directory_name):
            os.makedirs(self.directory_name, exist_ok=True)

        # Save to a temporary file
        temp_filename = self.filename + '.tmp'
        with open(temp_filename, 'w', encoding='utf-8') as output:
            json.dump(db, output, sort_keys=True, indent=4)

        # Atomically replace the previous file
        os.replace(temp_filename, self.filename)

    async def delete(self, name: str) -> None:
        """
        Delete the entry stored under `name`.

        Raises:
            KeyError: if there is no entry for `name`.
        """
        db, key_map = await self.load()
        del key_map[name]
        await self.save(db)

    async def update(self, name: str, keys: PairingKeys) -> None:
        """
        Merge `keys` into the entry stored under `name`, creating the entry
        if needed. Properties of the entry that are not in `keys.to_dict()`
        are left unchanged.
        """
        db, key_map = await self.load()
        key_map.setdefault(name, {}).update(keys.to_dict())
        await self.save(db)

    async def get_all(self) -> List[Tuple[str, PairingKeys]]:
        """Return all the `(name, keys)` entries of the selected namespace."""
        _, key_map = await self.load()
        return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]

    async def delete_all(self) -> None:
        """Remove all the entries of the selected namespace."""
        db, key_map = await self.load()
        key_map.clear()
        await self.save(db)

    async def get(self, name: str) -> Optional[PairingKeys]:
        """Return the keys stored under `name`, or `None` if there are none."""
        _, key_map = await self.load()
        if name not in key_map:
            return None

        return PairingKeys.from_dict(key_map[name])


# -----------------------------------------------------------------------------
class MemoryKeyStore(KeyStore):
    """KeyStore implementation that keeps the entries in an in-memory dict."""

    # Entries of the store, by name
    all_keys: Dict[str, PairingKeys]

    def __init__(self) -> None:
        self.all_keys = {}

    async def delete(self, name: str) -> None:
        """Delete the entry stored under `name`, if there is one."""
        if name in self.all_keys:
            del self.all_keys[name]

    async def update(self, name: str, keys: PairingKeys) -> None:
        """Store `keys` under `name`, replacing any previous entry."""
        self.all_keys[name] = keys

    async def get(self, name: str) -> Optional[PairingKeys]:
        """Return the keys stored under `name`, or `None` if there are none."""
        return self.all_keys.get(name)

    async def get_all(self) -> List[Tuple[str, PairingKeys]]:
        """Return a list of all the `(name, keys)` entries."""
        return list(self.all_keys.items())