"""Knowledge about cryptographic mechanisms implemented in Mbed TLS.

This module is entirely based on the PSA API.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
#

import enum
import re
from typing import FrozenSet, Iterable, List, Optional, Tuple, Dict

from .asymmetric_key_data import ASYMMETRIC_KEY_DATA


def short_expression(original: str, level: int = 0) -> str:
    """Abbreviate the expression, keeping it human-readable.

    At every level, remove parts that are implicit from context (a leading
    ``PSA_ALG_``, ``PSA_ECC_FAMILY_`` or ``PSA_KEY_xxx_`` such as
    ``PSA_KEY_TYPE_``) and remove all spaces.
    For larger values of `level`, also abbreviate some names in an
    unambiguous, but ad hoc way (e.g. ``KEY_PAIR`` becomes ``PAIR``).
    """
    short = original
    short = re.sub(r'\bPSA_(?:ALG|ECC_FAMILY|KEY_[A-Z]+)_', r'', short)
    short = re.sub(r' +', r'', short)
    if level >= 1:
        short = re.sub(r'PUBLIC_KEY\b', r'PUB', short)
        short = re.sub(r'KEY_PAIR\b', r'PAIR', short)
        short = re.sub(r'\bBRAINPOOL_P', r'BP', short)
        short = re.sub(r'\bMONTGOMERY\b', r'MGM', short)
        short = re.sub(r'AEAD_WITH_SHORTENED_TAG\b', r'AEAD_SHORT', short)
        short = re.sub(r'\bDETERMINISTIC_', r'DET_', short)
        short = re.sub(r'\bKEY_AGREEMENT\b', r'KA', short)
        short = re.sub(r'_PSK_TO_MS\b', r'_PSK2MS', short)
    return short


# Block cipher key type heads (key type names without PSA_KEY_TYPE_).
BLOCK_CIPHERS = frozenset(['AES', 'ARIA', 'CAMELLIA', 'DES'])
# Algorithm heads (names without PSA_ALG_) of MACs built on a block cipher.
BLOCK_MAC_MODES = frozenset(['CBC_MAC', 'CMAC'])
# Algorithm heads of unauthenticated cipher modes built on a block cipher.
BLOCK_CIPHER_MODES = frozenset([
    'CTR', 'CFB', 'OFB', 'XTS', 'CCM_STAR_NO_TAG',
    'ECB_NO_PADDING', 'CBC_NO_PADDING', 'CBC_PKCS7',
])
# Algorithm heads of AEAD modes built on a block cipher.
BLOCK_AEAD_MODES = frozenset(['CCM', 'GCM'])

class EllipticCurveCategory(enum.Enum):
    """Categorization of elliptic curve families.

    The category of a curve determines what algorithms are defined over it.
    """

    SHORT_WEIERSTRASS = 0
    MONTGOMERY = 1
    TWISTED_EDWARDS = 2

    @staticmethod
    def from_family(family: str) -> 'EllipticCurveCategory':
        """Return the category of the curves in the given family.

        `family` is a ``PSA_ECC_FAMILY_xxx`` macro name. Any family other
        than Montgomery or twisted Edwards is treated as short Weierstrass.
        """
        if family == 'PSA_ECC_FAMILY_MONTGOMERY':
            return EllipticCurveCategory.MONTGOMERY
        if family == 'PSA_ECC_FAMILY_TWISTED_EDWARDS':
            return EllipticCurveCategory.TWISTED_EDWARDS
        # Default to SW, which most curves belong to.
        return EllipticCurveCategory.SHORT_WEIERSTRASS


class KeyType:
    """Knowledge about a PSA key type."""

    def __init__(self, name: str, params: Optional[Iterable[str]] = None) -> None:
        """Analyze a key type.

        The key type must be specified in PSA syntax. In its simplest form,
        `name` is a string 'PSA_KEY_TYPE_xxx' which is the name of a PSA key
        type macro. For key types that take arguments, the arguments can
        be passed either through the optional argument `params` or by
        passing an expression of the form 'PSA_KEY_TYPE_xxx(param1, ...)'
        in `name` as a string.
        """

        self.name = name.strip()
        """The key type macro name (``PSA_KEY_TYPE_xxx``).

        For key types constructed from a macro with arguments, this is the
        name of the macro, and the arguments are in `self.params`.
        """
        if params is None:
            if '(' in self.name:
                # Split 'MACRO(arg1, arg2)' into the macro name and its
                # comma-separated arguments.
                m = re.match(r'(\w+)\s*\((.*)\)\Z', self.name)
                assert m is not None
                self.name = m.group(1)
                params = m.group(2).split(',')
        self.params = (None if params is None else
                       [param.strip() for param in params])
        """The parameters of the key type, if there are any.

        None if the key type is a macro without arguments.
        """
        assert re.match(r'PSA_KEY_TYPE_\w+\Z', self.name)

        self.expression = self.name
        """A C expression whose value is the key type encoding."""
        if self.params is not None:
            self.expression += '(' + ', '.join(self.params) + ')'

        m = re.match(r'PSA_KEY_TYPE_(\w+)', self.name)
        assert m
        self.head = re.sub(r'_(?:PUBLIC_KEY|KEY_PAIR)\Z', r'', m.group(1))
        """The key type macro name, with common prefixes and suffixes stripped."""

        self.private_type = re.sub(r'_PUBLIC_KEY\Z', r'_KEY_PAIR', self.name)
        """The key type macro name for the corresponding key pair type.

        For everything other than a public key type, this is the same as
        `self.name`.
        """

    def short_expression(self, level: int = 0) -> str:
        """Abbreviate the expression, keeping it human-readable.

        See `crypto_knowledge.short_expression`.
        """
        return short_expression(self.expression, level=level)

    def is_public(self) -> bool:
        """Whether the key type is for public keys."""
        return self.name.endswith('_PUBLIC_KEY')

    # Key sizes in bits, per ECC family. The PSA key size of a curve is the
    # bit size of its group order: secp224k1 has a 225-bit order, so its
    # key size is 225, whereas secp224r1 is a 224-bit curve.
    ECC_KEY_SIZES = {
        'PSA_ECC_FAMILY_SECP_K1': (192, 225, 256),
        'PSA_ECC_FAMILY_SECP_R1': (224, 256, 384, 521),
        'PSA_ECC_FAMILY_SECP_R2': (160,),
        'PSA_ECC_FAMILY_SECT_K1': (163, 233, 239, 283, 409, 571),
        'PSA_ECC_FAMILY_SECT_R1': (163, 233, 283, 409, 571),
        'PSA_ECC_FAMILY_SECT_R2': (163,),
        'PSA_ECC_FAMILY_BRAINPOOL_P_R1': (160, 192, 224, 256, 320, 384, 512),
        'PSA_ECC_FAMILY_MONTGOMERY': (255, 448),
        'PSA_ECC_FAMILY_TWISTED_EDWARDS': (255, 448),
    } # type: Dict[str, Tuple[int, ...]]
    # Key sizes in bits to test for non-ECC key types, indexed by the
    # key pair (private) type name.
    KEY_TYPE_SIZES = {
        'PSA_KEY_TYPE_AES': (128, 192, 256), # exhaustive
        'PSA_KEY_TYPE_ARC4': (8, 128, 2048), # extremes + sensible
        'PSA_KEY_TYPE_ARIA': (128, 192, 256), # exhaustive
        'PSA_KEY_TYPE_CAMELLIA': (128, 192, 256), # exhaustive
        'PSA_KEY_TYPE_CHACHA20': (256,), # exhaustive
        'PSA_KEY_TYPE_DERIVE': (120, 128), # sample
        'PSA_KEY_TYPE_DES': (64, 128, 192), # exhaustive
        'PSA_KEY_TYPE_HMAC': (128, 160, 224, 256, 384, 512), # standard size for each supported hash
        'PSA_KEY_TYPE_RAW_DATA': (8, 40, 128), # sample
        'PSA_KEY_TYPE_RSA_KEY_PAIR': (1024, 1536), # small sample
    } # type: Dict[str, Tuple[int, ...]]
    def sizes_to_test(self) -> Tuple[int, ...]:
        """Return a tuple of key sizes to test.

        For key types that only allow a single size, or only a small set of
        sizes, these are all the possible sizes. For key types that allow a
        wide range of sizes, these are a representative sample of sizes,
        excluding large sizes for which a typical resource-constrained platform
        may run out of memory.

        Raise KeyError if no sizes are known for this key type or ECC family.
        """
        if self.private_type == 'PSA_KEY_TYPE_ECC_KEY_PAIR':
            assert self.params is not None
            return self.ECC_KEY_SIZES[self.params[0]]
        return self.KEY_TYPE_SIZES[self.private_type]

    # Repeating pattern used as key material for symmetric key types.
    # "48657265006973206b6579a064617461"
    DATA_BLOCK = b'Here\000is key\240data'
    def key_material(self, bits: int) -> bytes:
        """Return a byte string containing suitable key material with the given bit length.

        Use the PSA export representation. The resulting byte string is one that
        can be obtained with the following code:
        ```
        psa_set_key_type(&attributes, `self.expression`);
        psa_set_key_bits(&attributes, `bits`);
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_EXPORT);
        psa_generate_key(&attributes, &id);
        psa_export_key(id, `material`, ...);
        ```

        Raise ValueError if there is no canned data for this asymmetric key
        type and size, or if `bits` is not a multiple of 8 for other types.
        """
        if self.expression in ASYMMETRIC_KEY_DATA:
            if bits not in ASYMMETRIC_KEY_DATA[self.expression]:
                raise ValueError('No key data for {}-bit {}'
                                 .format(bits, self.expression))
            return ASYMMETRIC_KEY_DATA[self.expression][bits]
        if bits % 8 != 0:
            raise ValueError('Non-integer number of bytes: {} bits for {}'
                             .format(bits, self.expression))
        length = bits // 8
        if self.name == 'PSA_KEY_TYPE_DES':
            # "644573206b457901644573206b457902644573206b457904"
            des3 = b'dEs kEy\001dEs kEy\002dEs kEy\004'
            return des3[:length]
        # Repeat DATA_BLOCK as many whole times as fit, then append a
        # truncated copy to reach exactly `length` bytes.
        return b''.join([self.DATA_BLOCK] * (length // len(self.DATA_BLOCK)) +
                        [self.DATA_BLOCK[:length % len(self.DATA_BLOCK)]])

    def can_do(self, alg: 'Algorithm') -> bool:
        """Whether this key type can be used for operations with the given algorithm.

        This function does not currently handle key derivation or PAKE.
        """
        #pylint: disable=too-many-branches,too-many-return-statements
        if not alg.is_valid_for_operation():
            return False
        if self.head == 'HMAC' and alg.head == 'HMAC':
            return True
        if self.head == 'DES':
            # 64-bit block ciphers only allow a reduced set of modes.
            return alg.head in [
                'CBC_NO_PADDING', 'CBC_PKCS7',
                'ECB_NO_PADDING',
            ]
        if self.head in BLOCK_CIPHERS and \
           alg.head in frozenset.union(BLOCK_MAC_MODES,
                                       BLOCK_CIPHER_MODES,
                                       BLOCK_AEAD_MODES):
            if alg.head in ['CMAC', 'OFB'] and \
               self.head in ['ARIA', 'CAMELLIA']:
                return False # not implemented in Mbed TLS
            return True
        if self.head == 'CHACHA20' and alg.head == 'CHACHA20_POLY1305':
            return True
        if self.head in {'ARC4', 'CHACHA20'} and \
           alg.head == 'STREAM_CIPHER':
            return True
        if self.head == 'RSA' and alg.head.startswith('RSA_'):
            return True
        if alg.category == AlgorithmCategory.KEY_AGREEMENT and \
           self.is_public():
            # The PSA API does not use public key objects in key agreement
            # operations: it imports the public key as a formatted byte string.
            # So a public key object with a key agreement algorithm is not
            # a valid combination.
            return False
        if self.head == 'ECC':
            assert self.params is not None
            eccc = EllipticCurveCategory.from_family(self.params[0])
            if alg.head == 'ECDH' and \
               eccc in {EllipticCurveCategory.SHORT_WEIERSTRASS,
                        EllipticCurveCategory.MONTGOMERY}:
                return True
            if alg.head == 'ECDSA' and \
               eccc == EllipticCurveCategory.SHORT_WEIERSTRASS:
                return True
            if alg.head in {'PURE_EDDSA', 'EDDSA_PREHASH'} and \
               eccc == EllipticCurveCategory.TWISTED_EDWARDS:
                return True
        return False


class AlgorithmCategory(enum.Enum):
    """PSA algorithm categories."""
    # The numbers are aligned with the category bits in numerical values of
    # algorithms.
    HASH = 2
    MAC = 3
    CIPHER = 4
    AEAD = 5
    SIGN = 6
    ASYMMETRIC_ENCRYPTION = 7
    KEY_DERIVATION = 8
    KEY_AGREEMENT = 9
    PAKE = 10

    def requires_key(self) -> bool:
        """Whether operations in this category are set up with a key."""
        return self not in {AlgorithmCategory.HASH,
                            AlgorithmCategory.KEY_DERIVATION}

    def is_asymmetric(self) -> bool:
        """Whether operations in this category involve asymmetric keys."""
        return self in {
            AlgorithmCategory.SIGN,
            AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
            AlgorithmCategory.KEY_AGREEMENT
        }


class AlgorithmNotRecognized(Exception):
    """Raised when an algorithm expression cannot be analyzed."""

    def __init__(self, expr: str) -> None:
        super().__init__('Algorithm not recognized: ' + expr)
        self.expr = expr


class Algorithm:
    """Knowledge about a PSA algorithm."""

    @staticmethod
    def determine_base(expr: str) -> str:
        """Return an expression for the "base" of the algorithm.

        This strips off variants of algorithms such as MAC truncation.

        This function does not attempt to detect invalid inputs.
        """
        # Strip one level of TRUNCATED_MAC / AT_LEAST_THIS_LENGTH_MAC /
        # AEAD_WITH_SHORTENED_TAG / AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG,
        # keeping the first argument (everything before the last comma).
        m = re.match(r'PSA_ALG_(?:'
                     r'(?:TRUNCATED|AT_LEAST_THIS_LENGTH)_MAC|'
                     r'AEAD_WITH_(?:SHORTENED|AT_LEAST_THIS_LENGTH)_TAG'
                     r')\((.*),[^,]+\)\Z', expr)
        if m:
            expr = m.group(1)
        return expr

    @staticmethod
    def determine_head(expr: str) -> str:
        """Return the head of an algorithm expression.

        The head is the first (outermost) constructor, without its PSA_ALG_
        prefix, and with some normalization of similar algorithms.

        Raise AlgorithmNotRecognized if `expr` does not start with a
        ``PSA_ALG_`` name.
        """
        m = re.match(r'PSA_ALG_(?:DETERMINISTIC_)?(\w+)', expr)
        if not m:
            raise AlgorithmNotRecognized(expr)
        head = m.group(1)
        if head == 'KEY_AGREEMENT':
            # For KEY_AGREEMENT(ka_alg, kdf_alg), the head is that of the
            # key agreement algorithm.
            m = re.match(r'PSA_ALG_KEY_AGREEMENT\s*\(\s*PSA_ALG_(\w+)', expr)
            if not m:
                raise AlgorithmNotRecognized(expr)
            head = m.group(1)
        head = re.sub(r'_ANY\Z', r'', head)
        if re.match(r'ED[0-9]+PH\Z', head):
            head = 'EDDSA_PREHASH'
        return head

    CATEGORY_FROM_HEAD = {
        'SHA': AlgorithmCategory.HASH,
        'SHAKE256_512': AlgorithmCategory.HASH,
        'MD': AlgorithmCategory.HASH,
        'RIPEMD': AlgorithmCategory.HASH,
        'ANY_HASH': AlgorithmCategory.HASH,
        'HMAC': AlgorithmCategory.MAC,
        'STREAM_CIPHER': AlgorithmCategory.CIPHER,
        'CHACHA20_POLY1305': AlgorithmCategory.AEAD,
        'DSA': AlgorithmCategory.SIGN,
        'ECDSA': AlgorithmCategory.SIGN,
        'EDDSA': AlgorithmCategory.SIGN,
        'PURE_EDDSA': AlgorithmCategory.SIGN,
        'RSA_PSS': AlgorithmCategory.SIGN,
        'RSA_PKCS1V15_SIGN': AlgorithmCategory.SIGN,
        'RSA_PKCS1V15_CRYPT': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
        'RSA_OAEP': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
        'HKDF': AlgorithmCategory.KEY_DERIVATION,
        'TLS12_PRF': AlgorithmCategory.KEY_DERIVATION,
        'TLS12_PSK_TO_MS': AlgorithmCategory.KEY_DERIVATION,
        'PBKDF': AlgorithmCategory.KEY_DERIVATION,
        'ECDH': AlgorithmCategory.KEY_AGREEMENT,
        'FFDH': AlgorithmCategory.KEY_AGREEMENT,
        # KEY_AGREEMENT(...) is a key derivation with a key agreement component
        'KEY_AGREEMENT': AlgorithmCategory.KEY_DERIVATION,
        'JPAKE': AlgorithmCategory.PAKE,
    }
    for x in BLOCK_MAC_MODES:
        CATEGORY_FROM_HEAD[x] = AlgorithmCategory.MAC
    for x in BLOCK_CIPHER_MODES:
        CATEGORY_FROM_HEAD[x] = AlgorithmCategory.CIPHER
    for x in BLOCK_AEAD_MODES:
        CATEGORY_FROM_HEAD[x] = AlgorithmCategory.AEAD

    def determine_category(self, expr: str, head: str) -> AlgorithmCategory:
        """Return the category of the given algorithm expression.

        Look `head` up in `CATEGORY_FROM_HEAD`, repeatedly stripping a
        trailing numeric suffix (e.g. ``SHA_256`` -> ``SHA``) or a trailing
        ``_word`` until a known prefix is found.

        This function does not attempt to detect invalid inputs.
        Raise AlgorithmNotRecognized if no prefix of `head` is known.
        """
        prefix = head
        while prefix:
            if prefix in self.CATEGORY_FROM_HEAD:
                return self.CATEGORY_FROM_HEAD[prefix]
            if re.match(r'.*[0-9]\Z', prefix):
                prefix = re.sub(r'_*[0-9]+\Z', r'', prefix)
            else:
                prefix = re.sub(r'_*[^_]*\Z', r'', prefix)
        raise AlgorithmNotRecognized(expr)

    @staticmethod
    def determine_wildcard(expr: str) -> bool:
        """Whether the given algorithm expression is a wildcard.

        This function does not attempt to detect invalid inputs.
        """
        if re.search(r'\bPSA_ALG_ANY_HASH\b', expr):
            return True
        if re.search(r'_AT_LEAST_', expr):
            return True
        return False

    def __init__(self, expr: str) -> None:
        """Analyze an algorithm value.

        The algorithm must be expressed as a C expression containing only
        calls to PSA algorithm constructor macros and numeric literals.

        This class is only programmed to handle valid expressions. Invalid
        expressions may result in exceptions or in nonsensical results.
        """
        # All whitespace is removed so that later regexes need not allow it.
        self.expression = re.sub(r'\s+', r'', expr)
        self.base_expression = self.determine_base(self.expression)
        self.head = self.determine_head(self.base_expression)
        self.category = self.determine_category(self.base_expression, self.head)
        self.is_wildcard = self.determine_wildcard(self.expression)

    def is_key_agreement_with_derivation(self) -> bool:
        """Whether this is a combined key agreement and key derivation algorithm."""
        if self.category != AlgorithmCategory.KEY_AGREEMENT:
            return False
        m = re.match(r'PSA_ALG_KEY_AGREEMENT\(\w+,\s*(.*)\)\Z', self.expression)
        if not m:
            return False
        kdf_alg = m.group(1)
        # Assume kdf_alg is either a valid KDF or 0.
        return not re.match(r'(?:0[Xx])?0+\s*\Z', kdf_alg)

    def short_expression(self, level: int = 0) -> str:
        """Abbreviate the expression, keeping it human-readable.

        See `crypto_knowledge.short_expression`.
        """
        return short_expression(self.expression, level=level)

    # Hash lengths in bytes for hashes whose name does not end with their
    # output size in bits (the digits in these names are version numbers).
    HASH_LENGTH = {
        'PSA_ALG_MD2': 16,
        'PSA_ALG_MD4': 16,
        'PSA_ALG_MD5': 16,
        'PSA_ALG_SHA_1': 20,
    }
    HASH_LENGTH_BITS_RE = re.compile(r'([0-9]+)\Z')
    @classmethod
    def hash_length(cls, alg: str) -> int:
        """The length of the given hash algorithm, in bytes.

        Names not in `HASH_LENGTH` are assumed to end with the output size
        in bits (e.g. ``PSA_ALG_SHA_256``). Raise ValueError otherwise.
        """
        if alg in cls.HASH_LENGTH:
            return cls.HASH_LENGTH[alg]
        m = cls.HASH_LENGTH_BITS_RE.search(alg)
        if m:
            return int(m.group(1)) // 8
        raise ValueError('Unknown hash length for ' + alg)

    PERMITTED_TAG_LENGTHS = {
        'PSA_ALG_CCM': frozenset([4, 6, 8, 10, 12, 14, 16]),
        'PSA_ALG_CHACHA20_POLY1305': frozenset([16]),
        'PSA_ALG_GCM': frozenset([4, 8, 12, 13, 14, 15, 16]),
    }
    MAC_LENGTH = {
        'PSA_ALG_CBC_MAC': 16, # actually the block cipher length
        'PSA_ALG_CMAC': 16, # actually the block cipher length
    }
    HMAC_RE = re.compile(r'PSA_ALG_HMAC\((.*)\)\Z')
    @classmethod
    def permitted_truncations(cls, base: str) -> FrozenSet[int]:
        """Permitted output lengths for the given MAC or AEAD base algorithm.

        For a MAC algorithm, this is the set of truncation lengths that
        Mbed TLS supports.
        For an AEAD algorithm, this is the set of truncation lengths that
        are permitted by the algorithm specification.

        Raise ValueError if the permitted lengths for `base` are unknown.
        """
        if base in cls.PERMITTED_TAG_LENGTHS:
            return cls.PERMITTED_TAG_LENGTHS[base]
        max_length = cls.MAC_LENGTH.get(base, None)
        if max_length is None:
            m = cls.HMAC_RE.match(base)
            if m:
                max_length = cls.hash_length(m.group(1))
        if max_length is None:
            raise ValueError('Unknown permitted lengths for ' + base)
        return frozenset(range(4, max_length + 1))

    TRUNCATED_ALG_RE = re.compile(
        r'(?P<face>PSA_ALG_(?:AEAD_WITH_SHORTENED_TAG|TRUNCATED_MAC))'
        r'\((?P<base>.*),'
        r'(?P<length>0[Xx][0-9A-Fa-f]+|[1-9][0-9]*|0[0-7]*)[LUlu]*\)\Z')
    def is_invalid_truncation(self) -> bool:
        """Whether this is a MAC or AEAD algorithm truncated to an invalid length.

        True for a MAC or AEAD algorithm truncated to a length that is not
        permitted for its base algorithm (see `permitted_truncations`).
        False for a MAC or AEAD algorithm truncated to a valid length or to
        a length that is not a C integer literal. False for anything other
        than a truncated MAC or AEAD.

        Raise ValueError if the permitted lengths of the base algorithm are
        unknown.
        """
        m = self.TRUNCATED_ALG_RE.match(self.expression)
        if m:
            base = m.group('base')
            length_text = m.group('length')
            if re.match(r'0[0-7]+\Z', length_text):
                # C octal literal such as 010. Python's int(..., 0) rejects
                # a leading zero on a nonzero number, so parse as base 8.
                to_length = int(length_text, 8)
            else:
                to_length = int(length_text, 0)
            permitted_lengths = self.permitted_truncations(base)
            if to_length not in permitted_lengths:
                return True
        return False

    def is_valid_for_operation(self) -> bool:
        """Whether this algorithm construction is valid for an operation.

        This function assumes that the algorithm is constructed in a
        "grammatically" correct way, and only rejects semantically invalid
        combinations.
        """
        if self.is_wildcard:
            return False
        if self.is_invalid_truncation():
            return False
        return True

    def can_do(self, category: AlgorithmCategory) -> bool:
        """Whether this algorithm can perform operations in the given category.
        """
        if category == self.category:
            return True
        if category == AlgorithmCategory.KEY_DERIVATION and \
           self.is_key_agreement_with_derivation():
            return True
        return False

    def usage_flags(self, public: bool = False) -> List[str]:
        """The list of usage flags describing operations that can use this algorithm.

        If public is true, only return public-key operations, not private-key operations.

        Raise AlgorithmNotRecognized for categories with no usage flags
        defined here (PAKE).
        """
        if self.category == AlgorithmCategory.HASH:
            flags = [] # type: List[str]
        elif self.category == AlgorithmCategory.MAC:
            flags = ['SIGN_HASH', 'SIGN_MESSAGE',
                     'VERIFY_HASH', 'VERIFY_MESSAGE']
        elif self.category == AlgorithmCategory.CIPHER or \
             self.category == AlgorithmCategory.AEAD:
            flags = ['DECRYPT', 'ENCRYPT']
        elif self.category == AlgorithmCategory.SIGN:
            flags = ['VERIFY_HASH', 'VERIFY_MESSAGE']
            if not public:
                flags += ['SIGN_HASH', 'SIGN_MESSAGE']
        elif self.category == AlgorithmCategory.ASYMMETRIC_ENCRYPTION:
            flags = ['ENCRYPT']
            if not public:
                flags += ['DECRYPT']
        elif self.category == AlgorithmCategory.KEY_DERIVATION or \
             self.category == AlgorithmCategory.KEY_AGREEMENT:
            flags = ['DERIVE']
        else:
            raise AlgorithmNotRecognized(self.expression)
        return ['PSA_KEY_USAGE_' + flag for flag in flags]