• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Keys and Key Storage
17#
18# -----------------------------------------------------------------------------
19
20# -----------------------------------------------------------------------------
21# Imports
22# -----------------------------------------------------------------------------
23import asyncio
24import logging
25import os
26import json
27from typing import Optional
28
29from .colors import color
30from .hci import Address
31
32
33# -----------------------------------------------------------------------------
34# Logging
35# -----------------------------------------------------------------------------
# Module-level logger, named after this module so callers can tune its level.
logger: logging.Logger = logging.getLogger(__name__)
37
38
39# -----------------------------------------------------------------------------
class PairingKeys:
    """Set of keys exchanged with one peer during pairing.

    Every key slot is optional and holds either ``None`` or a
    :class:`PairingKeys.Key`. Instances round-trip through plain dictionaries
    via :meth:`from_dict` / :meth:`to_dict`; byte values are stored there as
    hex strings and absent slots are omitted.
    """

    class Key:
        """A single key value together with the metadata stored next to it."""

        def __init__(self, value, authenticated=False, ediv=None, rand=None):
            # `value` and `rand` are bytes (serialized as hex by to_dict);
            # `ediv` is serialized as-is.
            self.value = value
            self.authenticated = authenticated
            self.ediv = ediv
            self.rand = rand

        @classmethod
        def from_dict(cls, key_dict):
            """Build a Key from the dictionary form produced by ``to_dict``.

            ``value`` is required; ``authenticated`` defaults to False and
            ``ediv`` / ``rand`` default to None when missing.
            """
            encoded_rand = key_dict.get('rand')
            return cls(
                bytes.fromhex(key_dict['value']),
                key_dict.get('authenticated', False),
                key_dict.get('ediv'),
                None if encoded_rand is None else bytes.fromhex(encoded_rand),
            )

        def to_dict(self):
            """Serialize to a JSON-friendly dict, omitting unset optional fields."""
            serialized = {'value': self.value.hex(), 'authenticated': self.authenticated}
            if self.ediv is not None:
                serialized['ediv'] = self.ediv
            if self.rand is not None:
                serialized['rand'] = self.rand.hex()
            return serialized

    def __init__(self):
        self.address_type = None
        self.ltk = self.ltk_central = self.ltk_peripheral = None
        self.irk = self.csrk = None
        self.link_key = None  # Classic

    @staticmethod
    def key_from_dict(keys_dict, key_name):
        """Return the Key stored under ``key_name`` in ``keys_dict``, or None."""
        key_dict = keys_dict.get(key_name)
        return None if key_dict is None else PairingKeys.Key.from_dict(key_dict)

    @staticmethod
    def from_dict(keys_dict):
        """Build a PairingKeys from the dictionary form produced by ``to_dict``."""
        keys = PairingKeys()
        keys.address_type = keys_dict.get('address_type')
        for slot in ('ltk', 'ltk_central', 'ltk_peripheral', 'irk', 'csrk', 'link_key'):
            setattr(keys, slot, PairingKeys.key_from_dict(keys_dict, slot))
        return keys

    def to_dict(self):
        """Serialize to a JSON-friendly dict containing only the slots that are set."""
        serialized = {}
        if self.address_type is not None:
            serialized['address_type'] = self.address_type
        for slot in ('ltk', 'ltk_central', 'ltk_peripheral', 'irk', 'csrk', 'link_key'):
            key = getattr(self, slot)
            if key is not None:
                serialized[slot] = key.to_dict()
        return serialized

    def print(self, prefix=''):
        """Print the serialized keys, one field per line, indented by ``prefix``."""
        for section, content in self.to_dict().items():
            if not isinstance(content, dict):
                # Scalar entries (e.g. address_type) fit on a single line.
                print(f'{prefix}{color(section, "cyan")}: {content}')
                continue
            print(f'{prefix}{color(section, "cyan")}:')
            for field, field_value in content.items():
                print(f'{prefix}  {color(field, "green")}: {field_value}')
134
135
136# -----------------------------------------------------------------------------
class KeyStore:
    """Base class for pairing-key stores.

    This implementation stores nothing: ``update`` and ``delete`` are no-ops,
    ``get`` returns an empty :class:`PairingKeys` and ``get_all`` an empty
    list. Subclasses such as ``JsonKeyStore`` override those primitives; the
    remaining helpers are built on top of them.
    """

    async def delete(self, name):
        """Remove the keys stored under ``name`` (no-op in the base class)."""

    async def update(self, name, keys):
        """Store ``keys`` under ``name`` (no-op in the base class)."""

    async def get(self, _name):
        """Return the keys for a name; the base class always returns empty keys."""
        return PairingKeys()

    async def get_all(self):
        """Return all stored entries as a list of ``(name, keys)`` pairs."""
        return []

    async def delete_all(self):
        """Delete every entry returned by ``get_all`` using ``asyncio.gather``."""
        entries = await self.get_all()
        deletions = [self.delete(entry_name) for entry_name, _ in entries]
        await asyncio.gather(*deletions)

    async def get_resolving_keys(self):
        """Return ``(irk_value, Address)`` pairs for every entry that has an IRK.

        Entries with no recorded address type are treated as random device
        addresses.
        """
        resolving = []
        for name, keys in await self.get_all():
            if keys.irk is None:
                continue
            address_type = (
                Address.RANDOM_DEVICE_ADDRESS
                if keys.address_type is None
                else keys.address_type
            )
            resolving.append((keys.irk.value, Address(name, address_type)))
        return resolving

    async def print(self, prefix=''):
        """Print each entry's name and keys, with a blank line between entries."""
        for index, (name, keys) in enumerate(await self.get_all()):
            print(('\n' if index else '') + prefix + color(name, 'yellow'))
            keys.print(prefix=prefix + '  ')

    @staticmethod
    def create_for_device(device_config):
        """Instantiate the key store selected by ``device_config.keystore``.

        The setting has the form ``<type>[:<params>]``; only ``JsonKeyStore``
        is recognized. Returns None when no key store is configured or the
        type is unknown.
        """
        spec = device_config.keystore
        if spec is not None and spec.partition(':')[0] == 'JsonKeyStore':
            return JsonKeyStore.from_device_config(device_config)
        return None
185
186
187# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
    """Key store persisted in a single JSON file.

    The file contains one top-level object per namespace, and each namespace
    maps an entry name to the ``PairingKeys.to_dict()`` form of its keys::

        {namespace: {name: {...pairing keys...}}}

    NOTE: the coroutines below perform blocking file I/O directly (no
    executor), so they block the event loop for the duration of each access.
    """

    APP_NAME = 'Bumble'
    APP_AUTHOR = 'Google'
    KEYS_DIR = 'Pairing'
    DEFAULT_NAMESPACE = '__DEFAULT__'

    def __init__(self, namespace, filename=None):
        """Create a store for ``namespace`` backed by ``filename``.

        Args:
            namespace: top-level JSON key under which this store's entries
                live; None selects ``DEFAULT_NAMESPACE``.
            filename: path of the JSON file. When None, a per-user file named
                after the namespace is used inside ``KEYS_DIR`` under the
                appdirs user data directory.
        """
        self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE

        if filename is None:
            # Use a default for the current user

            # Import here because this may not exist on all platforms
            # pylint: disable=import-outside-toplevel
            import appdirs

            self.directory_name = os.path.join(
                appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
            )
            # ':' (e.g. from a device address) is not valid in Windows filenames.
            json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
            self.filename = os.path.join(self.directory_name, json_filename)
        else:
            self.filename = filename
            self.directory_name = os.path.dirname(os.path.abspath(self.filename))

        logger.debug(f'JSON keystore: {self.filename}')

    @staticmethod
    def from_device_config(device_config):
        """Create a store from a ``JsonKeyStore[:<filename>]`` keystore setting.

        The namespace is the device address; the optional text after the first
        ':' is used as the JSON file path.
        """
        params = device_config.keystore.split(':', 1)[1:]
        namespace = str(device_config.address)
        filename = params[0] if params else None

        return JsonKeyStore(namespace, filename)

    async def load(self):
        """Return the whole JSON database, or ``{}`` if the file does not exist yet."""
        try:
            with open(self.filename, 'r', encoding='utf-8') as json_file:
                return json.load(json_file)
        except FileNotFoundError:
            return {}

    async def save(self, db):
        """Write ``db`` as the new content of the key store file.

        The data is written to a sibling ``.tmp`` file first and then moved
        over the real file, so an interrupted write does not leave a truncated
        key file behind.
        """
        # Create the directory if it doesn't exist (no-op when it already does).
        os.makedirs(self.directory_name, exist_ok=True)

        # Save to a temporary file
        temp_filename = self.filename + '.tmp'
        with open(temp_filename, 'w', encoding='utf-8') as output:
            json.dump(db, output, sort_keys=True, indent=4)

        # Replace the previous file. Unlike os.rename, os.replace overwrites an
        # existing destination on Windows too (os.rename raises FileExistsError
        # there), and the rename is atomic on POSIX.
        os.replace(temp_filename, self.filename)

    async def delete(self, name: str) -> None:
        """Remove the entry ``name`` from this namespace.

        Raises:
            KeyError: if the namespace or the entry does not exist.
        """
        db = await self.load()

        namespace = db.get(self.namespace)
        if namespace is None:
            raise KeyError(name)

        del namespace[name]
        await self.save(db)

    async def update(self, name, keys):
        """Store ``keys`` (a PairingKeys) under ``name``, replacing any previous entry."""
        db = await self.load()

        namespace = db.setdefault(self.namespace, {})
        namespace[name] = keys.to_dict()

        await self.save(db)

    async def get_all(self):
        """Return all entries of this namespace as ``(name, PairingKeys)`` pairs."""
        db = await self.load()

        namespace = db.get(self.namespace)
        if namespace is None:
            return []

        return [
            (name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
        ]

    async def delete_all(self):
        """Remove this store's whole namespace; other namespaces are kept."""
        db = await self.load()

        db.pop(self.namespace, None)

        await self.save(db)

    async def get(self, name: str) -> Optional[PairingKeys]:
        """Return the keys stored under ``name``, or None if there is no such entry."""
        db = await self.load()

        namespace = db.get(self.namespace)
        if namespace is None:
            return None

        keys = namespace.get(name)
        if keys is None:
            return None

        return PairingKeys.from_dict(keys)
294