# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Keys and Key Storage
#
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import json
from typing import Optional

from .colors import color
from .hci import Address


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
class PairingKeys:
    """The set of keys stored for one key-store entry.

    Every attribute is optional and starts out as None. Instances convert to
    and from plain dicts (see ``to_dict`` / ``from_dict``) so that they can be
    persisted, e.g. as JSON by ``JsonKeyStore``.
    """

    # Names of the attributes that hold a ``PairingKeys.Key`` (or None). The
    # order here is the order used by ``to_dict`` and therefore by ``print``.
    _KEY_FIELDS = ('ltk', 'ltk_central', 'ltk_peripheral', 'irk', 'csrk', 'link_key')

    class Key:
        """A single key value together with its optional metadata."""

        def __init__(self, value, authenticated=False, ediv=None, rand=None):
            # Raw key bytes (serialized as a hex string by to_dict).
            self.value = value
            self.authenticated = authenticated
            # Optional values stored alongside the key; presumably the LE
            # legacy-pairing EDIV/Rand pair (verify against the pairing code).
            self.ediv = ediv
            # Optional bytes, serialized as a hex string by to_dict.
            self.rand = rand

        @classmethod
        def from_dict(cls, key_dict):
            """Build a Key from the dict form produced by ``to_dict``.

            Raises:
                KeyError: if ``key_dict`` has no ``'value'`` entry.
                ValueError: if ``'value'`` or ``'rand'`` is not valid hex.
            """
            value = bytes.fromhex(key_dict['value'])
            authenticated = key_dict.get('authenticated', False)
            ediv = key_dict.get('ediv')
            rand = key_dict.get('rand')
            if rand is not None:
                rand = bytes.fromhex(rand)

            return cls(value, authenticated, ediv, rand)

        def to_dict(self):
            """Return a JSON-friendly dict; bytes fields are hex-encoded.

            ``'ediv'`` and ``'rand'`` are only included when they are set.
            """
            key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated}
            if self.ediv is not None:
                key_dict['ediv'] = self.ediv
            if self.rand is not None:
                key_dict['rand'] = self.rand.hex()

            return key_dict

    def __init__(self):
        self.address_type = None
        self.ltk = None
        self.ltk_central = None
        self.ltk_peripheral = None
        self.irk = None
        self.csrk = None
        self.link_key = None  # Classic

    @staticmethod
    def key_from_dict(keys_dict, key_name):
        """Return ``keys_dict[key_name]`` decoded as a Key, or None if absent."""
        key_dict = keys_dict.get(key_name)
        if key_dict is None:
            return None

        return PairingKeys.Key.from_dict(key_dict)

    @staticmethod
    def from_dict(keys_dict):
        """Build a PairingKeys from the dict form produced by ``to_dict``.

        Missing entries are left as None.
        """
        keys = PairingKeys()

        keys.address_type = keys_dict.get('address_type')
        for key_name in PairingKeys._KEY_FIELDS:
            setattr(keys, key_name, PairingKeys.key_from_dict(keys_dict, key_name))

        return keys

    def to_dict(self):
        """Return a JSON-friendly dict containing only the fields that are set."""
        keys = {}

        if self.address_type is not None:
            keys['address_type'] = self.address_type

        for key_name in self._KEY_FIELDS:
            key = getattr(self, key_name)
            if key is not None:
                keys[key_name] = key.to_dict()

        return keys

    def print(self, prefix=''):
        """Print the keys to stdout, one property per line, each line prefixed."""
        keys_dict = self.to_dict()
        for (container_property, value) in keys_dict.items():
            if isinstance(value, dict):
                # A Key: print its name, then each of its properties indented.
                print(f'{prefix}{color(container_property, "cyan")}:')
                for (key_property, key_value) in value.items():
                    print(f'{prefix} {color(key_property, "green")}: {key_value}')
            else:
                print(f'{prefix}{color(container_property, "cyan")}: {value}')


# -----------------------------------------------------------------------------
class KeyStore:
    """Base class for key stores.

    This base implementation stores nothing: ``update`` and ``delete`` are
    no-ops, ``get`` finds nothing and ``get_all`` returns an empty list.
    Subclasses override these to provide real storage; ``delete_all``,
    ``get_resolving_keys`` and ``print`` are built on top of them.
    """

    async def delete(self, name: str) -> None:
        """Delete the entry stored under ``name`` (no-op in the base class)."""

    async def update(self, name: str, keys: PairingKeys) -> None:
        """Store ``keys`` under ``name`` (no-op in the base class)."""

    async def get(self, _name: str) -> Optional[PairingKeys]:
        """Return the keys stored under ``_name``, or None if there is no entry.

        The base class stores nothing, so this always returns None. Returning
        None (rather than an empty PairingKeys) matches ``JsonKeyStore.get``
        and lets callers detect a missing entry with ``is None``.
        """
        return None

    async def get_all(self):
        """Return a list of ``(name, PairingKeys)`` tuples for every entry."""
        return []

    async def delete_all(self):
        """Delete every entry returned by ``get_all``, concurrently."""
        all_keys = await self.get_all()
        await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))

    async def get_resolving_keys(self):
        """Return ``(irk_value, Address)`` tuples for every entry that has an IRK.

        The entry name is used as the address string. Entries that did not
        record an address type are assumed to use a random device address.
        """
        all_keys = await self.get_all()
        resolving_keys = []
        for (name, keys) in all_keys:
            if keys.irk is not None:
                if keys.address_type is None:
                    address_type = Address.RANDOM_DEVICE_ADDRESS
                else:
                    address_type = keys.address_type
                resolving_keys.append((keys.irk.value, Address(name, address_type)))

        return resolving_keys

    async def print(self, prefix=''):
        """Print every entry (name, then its keys), separated by blank lines."""
        entries = await self.get_all()
        separator = ''
        for (name, keys) in entries:
            print(separator + prefix + color(name, 'yellow'))
            keys.print(prefix=prefix + ' ')
            separator = '\n'

    @staticmethod
    def create_for_device(device_config):
        """Create the key store described by ``device_config.keystore``.

        The setting has the form ``<type>[:<params>]``; only ``JsonKeyStore``
        is currently recognized. Returns None when no key store is configured
        or the type is not recognized.
        """
        if device_config.keystore is None:
            return None

        keystore_type = device_config.keystore.split(':', 1)[0]
        if keystore_type == 'JsonKeyStore':
            return JsonKeyStore.from_device_config(device_config)

        return None


# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
    """Key store persisted in a JSON file.

    The file holds a single object mapping ``namespace -> {name -> keys}``,
    where ``keys`` is the ``PairingKeys.to_dict()`` form. One file can
    therefore be shared by several namespaces; each instance only reads and
    writes its own namespace.

    Note: file I/O is performed synchronously inside the ``async`` methods.
    """

    APP_NAME = 'Bumble'
    APP_AUTHOR = 'Google'
    KEYS_DIR = 'Pairing'
    DEFAULT_NAMESPACE = '__DEFAULT__'

    def __init__(self, namespace, filename=None):
        """Create a key store for ``namespace`` backed by ``filename``.

        Args:
            namespace: top-level key in the JSON file; ``DEFAULT_NAMESPACE``
                is used when None.
            filename: path of the JSON file. When None, a per-namespace file
                in the user's data directory (via ``appdirs``) is used.
        """
        self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE

        if filename is None:
            # Use a default for the current user

            # Import here because this may not exist on all platforms
            # pylint: disable=import-outside-toplevel
            import appdirs

            self.directory_name = os.path.join(
                appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
            )
            # ':' (e.g. in an address-based namespace) is replaced because it
            # is not a valid filename character on every platform.
            json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
            self.filename = os.path.join(self.directory_name, json_filename)
        else:
            self.filename = filename
            self.directory_name = os.path.dirname(os.path.abspath(self.filename))

        logger.debug('JSON keystore: %s', self.filename)

    @staticmethod
    def from_device_config(device_config):
        """Create a JsonKeyStore from a ``JsonKeyStore[:<filename>]`` setting.

        The namespace is the device address. Only the first ':' separates the
        type from the filename, so the filename itself may contain ':'.
        """
        params = device_config.keystore.split(':', 1)[1:]
        namespace = str(device_config.address)
        if params:
            filename = params[0]
        else:
            filename = None

        return JsonKeyStore(namespace, filename)

    async def load(self):
        """Return the whole parsed JSON database, or {} if the file is missing."""
        try:
            with open(self.filename, 'r', encoding='utf-8') as json_file:
                return json.load(json_file)
        except FileNotFoundError:
            return {}

    async def save(self, db):
        """Write ``db`` to the JSON file, replacing the previous contents.

        The data is first written to ``<filename>.tmp`` and then moved over
        the target file, so readers never observe a half-written file.
        """
        # Create the directory if it doesn't exist
        os.makedirs(self.directory_name, exist_ok=True)

        temp_filename = self.filename + '.tmp'
        try:
            # Save to a temporary file
            with open(temp_filename, 'w', encoding='utf-8') as output:
                json.dump(db, output, sort_keys=True, indent=4)

            # Replace the previous file. Unlike os.rename, os.replace also
            # overwrites an existing destination on Windows; on POSIX the
            # rename is atomic.
            os.replace(temp_filename, self.filename)
        except BaseException:
            # Best-effort cleanup so a failed save does not leave a stray
            # temporary file behind; the original error is re-raised.
            try:
                os.remove(temp_filename)
            except OSError:
                pass
            raise

    async def delete(self, name: str) -> None:
        """Delete the entry ``name`` from this namespace.

        Raises:
            KeyError: if the namespace or the entry does not exist.
        """
        db = await self.load()

        namespace = db.get(self.namespace)
        if namespace is None:
            raise KeyError(name)

        del namespace[name]
        await self.save(db)

    async def update(self, name, keys):
        """Store ``keys`` under ``name``, creating the namespace if needed."""
        db = await self.load()

        namespace = db.setdefault(self.namespace, {})
        namespace[name] = keys.to_dict()

        await self.save(db)

    async def get_all(self):
        """Return ``(name, PairingKeys)`` tuples for every entry in the namespace."""
        db = await self.load()

        namespace = db.get(self.namespace)
        if namespace is None:
            return []

        return [
            (name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
        ]

    async def delete_all(self):
        """Remove this whole namespace from the file (other namespaces are kept)."""
        db = await self.load()

        db.pop(self.namespace, None)

        await self.save(db)

    async def get(self, name: str) -> Optional[PairingKeys]:
        """Return the keys stored under ``name``, or None if there is no entry."""
        db = await self.load()

        namespace = db.get(self.namespace)
        if namespace is None:
            return None

        keys = namespace.get(name)
        if keys is None:
            return None

        return PairingKeys.from_dict(keys)