"""Knowledge about the PSA key store as implemented in Mbed TLS.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import struct
from typing import Dict, List, Optional, Set, Union
import unittest

from mbedtls_dev import c_build_helper


class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An expression built from an integer has a known value. An expression
    built from a string is registered (in normalized form) in the
    class-wide `unknown_values` set; `value()` resolves it on demand,
    either directly if it is a plain C integer literal, or by evaluating
    all pending expressions at once through `update_cache`.
    """

    def __init__(self, content: Union[int, str]):
        if isinstance(content, int):
            # Render as zero-padded hex: at least 4 hex digits, or at
            # least 8 for values above 0xffff.
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            # Shared across all instances, so that a single call to
            # update_cache() can resolve every pending expression.
            self.unknown_values.add(self.normalize(content))
            self.value_if_known = None

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are evaluated in one batch with
        `c_build_helper.get_c_expression_values`, using the declarations
        from ``psa/crypto.h`` in the ``include`` directory. Each result is
        printed as an ``unsigned long``. `unknown_values` is emptied
        afterwards.
        """
        expressions = sorted(self.unknown_values)
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=['include']) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    @staticmethod
    def _parse_c_integer_literal(string: str) -> Optional[int]:
        """Return the value of `string` if it is a plain C integer literal.

        Recognizes unsuffixed hexadecimal (``0x``/``0X`` prefix), octal
        (leading ``0``) and decimal literals. Returns None for anything
        else, including literals with a type suffix.

        C reads a leading ``0`` as an octal prefix, whereas Python's
        ``int(s, 0)`` rejects decimal strings with a leading zero, so the
        three forms are distinguished explicitly here.
        """
        if re.match(r'0x[0-9a-f]+\Z', string, re.I):
            return int(string, 16)
        if re.match(r'0[0-7]*\Z', string):
            return int(string, 8)
        if re.match(r'[1-9][0-9]*\Z', string):
            return int(string, 10)
        return None

    def value(self) -> int:
        """Return the numerical value of the expression.

        Plain integer literals are parsed directly. Other expressions are
        looked up in `value_cache`, refreshing the cache via
        `update_cache` if needed. The result is memoized in
        `value_if_known`.
        """
        if self.value_if_known is None:
            normalized = self.normalize(self.string)
            literal_value = self._parse_c_integer_literal(normalized)
            if literal_value is not None:
                self.value_if_known = literal_value
            else:
                if normalized not in self.value_cache:
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known


Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""


def as_expr(thing: Exprable) -> Expr:
    """Return an `Expr` object for `thing`.

    If `thing` is already an `Expr` object, return it. Otherwise build a new
    `Expr` object from `thing`. `thing` can be an integer or a string that
    contains a C expression.
    """
    if isinstance(thing, Expr):
        return thing
    else:
        return Expr(thing)


class Key:
    """Representation of a PSA crypto key object and its storage encoding.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        self.version = self.LATEST_VERSION if version is None else version
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        self.material = material #type: bytes

    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Pack the given arguments into a byte string according to the given format.

        This function is similar to `struct.pack`, but with the following differences:
        * All integer values are encoded with standard sizes and in
          little-endian representation. `fmt` must not include an endianness
          prefix.
        * Arguments can be `Expr` objects instead of integers.
        * Only integer-valued elements are supported.
        """
        return struct.pack('<' + fmt, # little-endian, standard sizes
                           *[arg.value() if isinstance(arg, Expr) else arg
                             for arg in args])

    def bytes(self) -> bytes:
        """Return the representation of the key in storage as a byte array.

        This is the content of the PSA storage file. When PSA storage is
        implemented over stdio files, this does not include any wrapping made
        by the PSA-storage-over-stdio-file implementation.

        Raises NotImplementedError if `self.version` is not a supported
        storage format version (currently only version 0).
        """
        header = self.MAGIC + self.pack('L', self.version)
        if self.version == 0:
            # lifetime (32 bits), type (16), bits (16), usage (32),
            # alg (32), alg2 (32); then a 32-bit length and the material.
            attributes = self.pack('LHHLLL',
                                   self.lifetime, self.type, self.bits,
                                   self.usage, self.alg, self.alg2)
            material = self.pack('L', len(self.material)) + self.material
        else:
            raise NotImplementedError(
                'Unsupported key storage format version: {}'.format(self.version))
        return header + attributes + material

    def hex(self) -> str:
        """Return the representation of the key as a hexadecimal string.

        This is the hexadecimal representation of `self.bytes`.
        """
        return self.bytes().hex()

    def location_value(self) -> int:
        """The numerical value of the location encoded in the key's lifetime.

        The location is stored in the lifetime above its low 8 bits.
        """
        return self.lifetime.value() >> 8


class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """A few smoke tests for the functionality of the `Key` class."""

    def test_numerical(self):
        key = Key(version=0,
                  id=1, lifetime=0x00000001,
                  type=0x2400, bits=128,
                  usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                  material=b'@ABCDEFGHIJKLMNO')
        expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_c_literals(self):
        # Plain C integer literals (decimal, hexadecimal in either case,
        # and octal) are evaluated directly by Expr.value(), without
        # calling update_cache(). Same encoding as test_numerical.
        key = Key(version=0,
                  id=1, lifetime='1',
                  type='0x2400', bits=128,
                  usage='01400', alg='0X05500200', alg2='79695872',
                  material=b'@ABCDEFGHIJKLMNO')
        expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_names(self):
        length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        key = Key(version=0,
                  id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                  type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x00' * length)
        expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_defaults(self):
        key = Key(type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)