• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Knowledge about cryptographic mechanisms implemented in Mbed TLS.
2
3This module is entirely based on the PSA API.
4"""
5
6# Copyright The Mbed TLS Contributors
7# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
8#
9
10import enum
11import re
12from typing import FrozenSet, Iterable, List, Optional, Tuple, Dict
13
14from .asymmetric_key_data import ASYMMETRIC_KEY_DATA
15
16
def short_expression(original: str, level: int = 0) -> str:
    """Return a human-readable abbreviation of a PSA C expression.

    At every level, the ``PSA_ALG_``, ``PSA_ECC_FAMILY_`` and
    ``PSA_KEY_<WORD>_`` prefixes (such as ``PSA_KEY_TYPE_``) are dropped
    because they are implied by context, and space characters are removed
    (other whitespace is left alone).

    When `level` is 1 or more, a fixed list of long names is additionally
    replaced by ad hoc but unambiguous abbreviations, for example
    ``KEY_PAIR`` becomes ``PAIR`` and ``KEY_AGREEMENT`` becomes ``KA``.
    """
    abbreviated = re.sub(r'\bPSA_(?:ALG|ECC_FAMILY|KEY_[A-Z]+)_', r'',
                         original)
    abbreviated = re.sub(r' +', r'', abbreviated)
    if level < 1:
        return abbreviated
    # The rules are applied in this order; each one sees the output of the
    # previous one.
    for pattern, replacement in (
            (r'PUBLIC_KEY\b', r'PUB'),
            (r'KEY_PAIR\b', r'PAIR'),
            (r'\bBRAINPOOL_P', r'BP'),
            (r'\bMONTGOMERY\b', r'MGM'),
            (r'AEAD_WITH_SHORTENED_TAG\b', r'AEAD_SHORT'),
            (r'\bDETERMINISTIC_', r'DET_'),
            (r'\bKEY_AGREEMENT\b', r'KA'),
            (r'_PSK_TO_MS\b', r'_PSK2MS'),
    ):
        abbreviated = re.sub(pattern, replacement, abbreviated)
    return abbreviated
38
39
# Key type heads (PSA_KEY_TYPE_xxx without the prefix) of block ciphers.
BLOCK_CIPHERS = frozenset({'AES', 'ARIA', 'CAMELLIA', 'DES'})
# Algorithm heads of MAC constructions built on a block cipher.
BLOCK_MAC_MODES = frozenset({'CBC_MAC', 'CMAC'})
# Algorithm heads of unauthenticated cipher modes built on a block cipher.
BLOCK_CIPHER_MODES = frozenset({
    'CBC_NO_PADDING',
    'CBC_PKCS7',
    'CCM_STAR_NO_TAG',
    'CFB',
    'CTR',
    'ECB_NO_PADDING',
    'OFB',
    'XTS',
})
# Algorithm heads of AEAD modes built on a block cipher.
BLOCK_AEAD_MODES = frozenset({'CCM', 'GCM'})
47
class EllipticCurveCategory(enum.Enum):
    """Broad category of a PSA elliptic curve family.

    The category of a curve determines which algorithms are defined over it.
    """

    SHORT_WEIERSTRASS = 0
    MONTGOMERY = 1
    TWISTED_EDWARDS = 2

    @staticmethod
    def from_family(family: str) -> 'EllipticCurveCategory':
        """Return the category of the curve family named `family`.

        `family` is a ``PSA_ECC_FAMILY_xxx`` macro name. Only the Montgomery
        and twisted Edwards families are singled out; any other name,
        including an unknown one, maps to SHORT_WEIERSTRASS because most
        curve families belong to that category.
        """
        # Kept local so that it does not become an enum member.
        special_families = (
            ('PSA_ECC_FAMILY_MONTGOMERY', EllipticCurveCategory.MONTGOMERY),
            ('PSA_ECC_FAMILY_TWISTED_EDWARDS',
             EllipticCurveCategory.TWISTED_EDWARDS),
        )
        for family_name, category in special_families:
            if family == family_name:
                return category
        return EllipticCurveCategory.SHORT_WEIERSTRASS
66
67
class KeyType:
    """Knowledge about a PSA key type."""

    def __init__(self, name: str, params: Optional[Iterable[str]] = None) -> None:
        """Analyze a key type.

        The key type must be specified in PSA syntax. In its simplest form,
        `name` is a string 'PSA_KEY_TYPE_xxx' which is the name of a PSA key
        type macro. For key types that take arguments, the arguments can
        be passed either through the optional argument `params` or by
        passing an expression of the form 'PSA_KEY_TYPE_xxx(param1, ...)'
        in `name` as a string.

        Malformed names are rejected by ``assert`` statements, so they raise
        AssertionError unless Python runs with assertions disabled.
        """

        self.name = name.strip()
        """The key type macro name (``PSA_KEY_TYPE_xxx``).

        For key types constructed from a macro with arguments, this is the
        name of the macro, and the arguments are in `self.params`.
        """
        if params is None:
            if '(' in self.name:
                # Split 'MACRO(arg1, arg2, ...)' into the macro name and its
                # arguments. The split is on every comma, so the arguments
                # must not themselves contain commas.
                m = re.match(r'(\w+)\s*\((.*)\)\Z', self.name)
                assert m is not None
                self.name = m.group(1)
                params = m.group(2).split(',')
        self.params = (None if params is None else
                       [param.strip() for param in params])
        """The parameters of the key type, if there are any.

        None if the key type is a macro without arguments.
        """
        assert re.match(r'PSA_KEY_TYPE_\w+\Z', self.name)

        self.expression = self.name
        """A C expression whose value is the key type encoding."""
        if self.params is not None:
            self.expression += '(' + ', '.join(self.params) + ')'

        m = re.match(r'PSA_KEY_TYPE_(\w+)', self.name)
        assert m
        self.head = re.sub(r'_(?:PUBLIC_KEY|KEY_PAIR)\Z', r'', m.group(1))
        """The key type macro name, with common prefixes and suffixes stripped."""

        self.private_type = re.sub(r'_PUBLIC_KEY\Z', r'_KEY_PAIR', self.name)
        """The key type macro name for the corresponding key pair type.

        For everything other than a public key type, this is the same as
        `self.name`.
        """

    def short_expression(self, level: int = 0) -> str:
        """Abbreviate the expression, keeping it human-readable.

        See `crypto_knowledge.short_expression`.
        """
        return short_expression(self.expression, level=level)

    def is_public(self) -> bool:
        """Whether the key type is for public keys."""
        return self.name.endswith('_PUBLIC_KEY')

    # Key sizes to test for each elliptic curve family, in bits.
    # The PSA key size of an ECC key is the bit-size of its private value,
    # which is not always the nominal curve size: secp224k1 (SECP_K1) has
    # 225-bit private values, whereas secp224r1 (SECP_R1) has 224-bit ones.
    ECC_KEY_SIZES = {
        'PSA_ECC_FAMILY_SECP_K1': (192, 225, 256),
        'PSA_ECC_FAMILY_SECP_R1': (224, 256, 384, 521),
        'PSA_ECC_FAMILY_SECP_R2': (160,),
        'PSA_ECC_FAMILY_SECT_K1': (163, 233, 239, 283, 409, 571),
        'PSA_ECC_FAMILY_SECT_R1': (163, 233, 283, 409, 571),
        'PSA_ECC_FAMILY_SECT_R2': (163,),
        'PSA_ECC_FAMILY_BRAINPOOL_P_R1': (160, 192, 224, 256, 320, 384, 512),
        'PSA_ECC_FAMILY_MONTGOMERY': (255, 448),
        'PSA_ECC_FAMILY_TWISTED_EDWARDS': (255, 448),
    } # type: Dict[str, Tuple[int, ...]]
    # Key sizes to test for key types that are not parametrized by a curve,
    # indexed by the key pair type name (see `private_type`).
    KEY_TYPE_SIZES = {
        'PSA_KEY_TYPE_AES': (128, 192, 256), # exhaustive
        'PSA_KEY_TYPE_ARC4': (8, 128, 2048), # extremes + sensible
        'PSA_KEY_TYPE_ARIA': (128, 192, 256), # exhaustive
        'PSA_KEY_TYPE_CAMELLIA': (128, 192, 256), # exhaustive
        'PSA_KEY_TYPE_CHACHA20': (256,), # exhaustive
        'PSA_KEY_TYPE_DERIVE': (120, 128), # sample
        'PSA_KEY_TYPE_DES': (64, 128, 192), # exhaustive
        'PSA_KEY_TYPE_HMAC': (128, 160, 224, 256, 384, 512), # standard size for each supported hash
        'PSA_KEY_TYPE_RAW_DATA': (8, 40, 128), # sample
        'PSA_KEY_TYPE_RSA_KEY_PAIR': (1024, 1536), # small sample
    } # type: Dict[str, Tuple[int, ...]]
    def sizes_to_test(self) -> Tuple[int, ...]:
        """Return a tuple of key sizes to test.

        For key types that only allow a single size, or only a small set of
        sizes, these are all the possible sizes. For key types that allow a
        wide range of sizes, these are a representative sample of sizes,
        excluding large sizes for which a typical resource-constrained platform
        may run out of memory.

        Public key types use the sizes of the corresponding key pair type.
        Raise KeyError for a key type or curve family with no entry in
        `ECC_KEY_SIZES` / `KEY_TYPE_SIZES`.
        """
        if self.private_type == 'PSA_KEY_TYPE_ECC_KEY_PAIR':
            # ECC sizes depend on the curve family, the first parameter.
            assert self.params is not None
            return self.ECC_KEY_SIZES[self.params[0]]
        return self.KEY_TYPE_SIZES[self.private_type]

    # "48657265006973206b6579a064617461"
    DATA_BLOCK = b'Here\000is key\240data'
    def key_material(self, bits: int) -> bytes:
        """Return a byte string containing suitable key material with the given bit length.

        Use the PSA export representation. The resulting byte string is one that
        can be obtained with the following code:
        ```
        psa_set_key_type(&attributes, `self.expression`);
        psa_set_key_bits(&attributes, `bits`);
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_EXPORT);
        psa_generate_key(&attributes, &id);
        psa_export_key(id, `material`, ...);
        ```

        Raise ValueError if there is no recorded key data for an asymmetric
        key of this size, or if `bits` is not a multiple of 8 for any other
        key type.
        """
        if self.expression in ASYMMETRIC_KEY_DATA:
            if bits not in ASYMMETRIC_KEY_DATA[self.expression]:
                raise ValueError('No key data for {}-bit {}'
                                 .format(bits, self.expression))
            return ASYMMETRIC_KEY_DATA[self.expression][bits]
        if bits % 8 != 0:
            raise ValueError('Non-integer number of bytes: {} bits for {}'
                             .format(bits, self.expression))
        length = bits // 8
        if self.name == 'PSA_KEY_TYPE_DES':
            # "644573206b457901644573206b457902644573206b457904"
            # Every byte has odd parity (the DES key parity convention) and
            # the three 8-byte parts are distinct.
            des3 = b'dEs kEy\001dEs kEy\002dEs kEy\004'
            return des3[:length]
        # Repeat DATA_BLOCK as many whole times as fit, then a partial copy.
        return b''.join([self.DATA_BLOCK] * (length // len(self.DATA_BLOCK)) +
                        [self.DATA_BLOCK[:length % len(self.DATA_BLOCK)]])

    def can_do(self, alg: 'Algorithm') -> bool:
        """Whether this key type can be used for operations with the given algorithm.

        This function does not currently handle key derivation or PAKE.
        Algorithms that are not valid for an operation (wildcards, invalid
        truncations) are never usable.
        """
        #pylint: disable=too-many-branches,too-many-return-statements
        if not alg.is_valid_for_operation():
            return False
        if self.head == 'HMAC' and alg.head == 'HMAC':
            return True
        if self.head == 'DES':
            # 64-bit block ciphers only allow a reduced set of modes.
            return alg.head in [
                'CBC_NO_PADDING', 'CBC_PKCS7',
                'ECB_NO_PADDING',
            ]
        if self.head in BLOCK_CIPHERS and \
           alg.head in frozenset.union(BLOCK_MAC_MODES,
                                       BLOCK_CIPHER_MODES,
                                       BLOCK_AEAD_MODES):
            if alg.head in ['CMAC', 'OFB'] and \
               self.head in ['ARIA', 'CAMELLIA']:
                return False # not implemented in Mbed TLS
            return True
        if self.head == 'CHACHA20' and alg.head == 'CHACHA20_POLY1305':
            return True
        if self.head in {'ARC4', 'CHACHA20'} and \
           alg.head == 'STREAM_CIPHER':
            return True
        if self.head == 'RSA' and alg.head.startswith('RSA_'):
            return True
        if alg.category == AlgorithmCategory.KEY_AGREEMENT and \
           self.is_public():
            # The PSA API does not use public key objects in key agreement
            # operations: it imports the public key as a formatted byte string.
            # So a public key object with a key agreement algorithm is not
            # a valid combination.
            return False
        if self.head == 'ECC':
            assert self.params is not None
            eccc = EllipticCurveCategory.from_family(self.params[0])
            if alg.head == 'ECDH' and \
               eccc in {EllipticCurveCategory.SHORT_WEIERSTRASS,
                        EllipticCurveCategory.MONTGOMERY}:
                return True
            if alg.head == 'ECDSA' and \
               eccc == EllipticCurveCategory.SHORT_WEIERSTRASS:
                return True
            if alg.head in {'PURE_EDDSA', 'EDDSA_PREHASH'} and \
               eccc == EllipticCurveCategory.TWISTED_EDWARDS:
                return True
        return False
250
251
class AlgorithmCategory(enum.Enum):
    """Categories of PSA algorithms."""
    # Each value equals the category bits found in the numerical encoding
    # of the algorithms in that category.
    HASH = 2
    MAC = 3
    CIPHER = 4
    AEAD = 5
    SIGN = 6
    ASYMMETRIC_ENCRYPTION = 7
    KEY_DERIVATION = 8
    KEY_AGREEMENT = 9
    PAKE = 10

    def requires_key(self) -> bool:
        """Return True if operations in this category are set up with a key.

        Only hashing and key derivation operations are keyless here.
        """
        keyless = (AlgorithmCategory.HASH, AlgorithmCategory.KEY_DERIVATION)
        return not any(self is category for category in keyless)

    def is_asymmetric(self) -> bool:
        """Return True if operations in this category use asymmetric keys."""
        asymmetric = (
            AlgorithmCategory.SIGN,
            AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
            AlgorithmCategory.KEY_AGREEMENT,
        )
        return any(self is category for category in asymmetric)
277
278
class AlgorithmNotRecognized(Exception):
    """Raised when an algorithm expression cannot be analyzed.

    The offending expression is available as the `expr` attribute.
    """

    def __init__(self, expr: str) -> None:
        message = 'Algorithm not recognized: ' + expr
        super().__init__(message)
        self.expr = expr
283
284
class Algorithm:
    """Knowledge about a PSA algorithm."""

    @staticmethod
    def determine_base(expr: str) -> str:
        """Return an expression for the "base" of the algorithm.

        This strips off variants of algorithms such as MAC truncation:
        ``PSA_ALG_TRUNCATED_MAC(base, n)``,
        ``PSA_ALG_AT_LEAST_THIS_LENGTH_MAC(base, n)``,
        ``PSA_ALG_AEAD_WITH_SHORTENED_TAG(base, n)`` and
        ``PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(base, n)`` all yield
        ``base``. Any other expression is returned unchanged.

        This function does not attempt to detect invalid inputs.
        """
        # The greedy '(.*)' extends up to the last comma, so a base that is
        # itself a call with several arguments is kept intact.
        m = re.match(r'PSA_ALG_(?:'
                     r'(?:TRUNCATED|AT_LEAST_THIS_LENGTH)_MAC|'
                     r'AEAD_WITH_(?:SHORTENED|AT_LEAST_THIS_LENGTH)_TAG'
                     r')\((.*),[^,]+\)\Z', expr)
        if m:
            expr = m.group(1)
        return expr

    @staticmethod
    def determine_head(expr: str) -> str:
        """Return the head of an algorithm expression.

        The head is the first (outermost) constructor, without its PSA_ALG_
        prefix, and with some normalization of similar algorithms:

        * a ``DETERMINISTIC_`` prefix is dropped;
        * for ``PSA_ALG_KEY_AGREEMENT(ka, kdf)``, the head is that of the
          raw key agreement `ka`;
        * a trailing ``_ANY`` is dropped;
        * ``EDnnnPH`` becomes ``EDDSA_PREHASH``.

        Raise AlgorithmNotRecognized if `expr` does not start with a
        ``PSA_ALG_`` name (or, for a key agreement, if its first argument
        does not).
        """
        m = re.match(r'PSA_ALG_(?:DETERMINISTIC_)?(\w+)', expr)
        if not m:
            raise AlgorithmNotRecognized(expr)
        head = m.group(1)
        if head == 'KEY_AGREEMENT':
            m = re.match(r'PSA_ALG_KEY_AGREEMENT\s*\(\s*PSA_ALG_(\w+)', expr)
            if not m:
                raise AlgorithmNotRecognized(expr)
            head = m.group(1)
        head = re.sub(r'_ANY\Z', r'', head)
        if re.match(r'ED[0-9]+PH\Z', head):
            head = 'EDDSA_PREHASH'
        return head

    # Category of each algorithm head, or of a prefix of the head (see
    # `determine_category`).
    CATEGORY_FROM_HEAD = {
        'SHA': AlgorithmCategory.HASH,
        'SHAKE256_512': AlgorithmCategory.HASH,
        'MD': AlgorithmCategory.HASH,
        'RIPEMD': AlgorithmCategory.HASH,
        'ANY_HASH': AlgorithmCategory.HASH,
        'HMAC': AlgorithmCategory.MAC,
        'STREAM_CIPHER': AlgorithmCategory.CIPHER,
        'CHACHA20_POLY1305': AlgorithmCategory.AEAD,
        'DSA': AlgorithmCategory.SIGN,
        'ECDSA': AlgorithmCategory.SIGN,
        'EDDSA': AlgorithmCategory.SIGN,
        'PURE_EDDSA': AlgorithmCategory.SIGN,
        'RSA_PSS': AlgorithmCategory.SIGN,
        'RSA_PKCS1V15_SIGN': AlgorithmCategory.SIGN,
        'RSA_PKCS1V15_CRYPT': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
        'RSA_OAEP': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
        'HKDF': AlgorithmCategory.KEY_DERIVATION,
        'TLS12_PRF': AlgorithmCategory.KEY_DERIVATION,
        'TLS12_PSK_TO_MS': AlgorithmCategory.KEY_DERIVATION,
        'PBKDF': AlgorithmCategory.KEY_DERIVATION,
        'ECDH': AlgorithmCategory.KEY_AGREEMENT,
        'FFDH': AlgorithmCategory.KEY_AGREEMENT,
        # KEY_AGREEMENT(...) is a key derivation with a key agreement component
        'KEY_AGREEMENT': AlgorithmCategory.KEY_DERIVATION,
        'JPAKE': AlgorithmCategory.PAKE,
    } # type: Dict[str, AlgorithmCategory]
    # dict.fromkeys avoids leaking a loop variable into the class namespace.
    CATEGORY_FROM_HEAD.update(dict.fromkeys(BLOCK_MAC_MODES,
                                            AlgorithmCategory.MAC))
    CATEGORY_FROM_HEAD.update(dict.fromkeys(BLOCK_CIPHER_MODES,
                                            AlgorithmCategory.CIPHER))
    CATEGORY_FROM_HEAD.update(dict.fromkeys(BLOCK_AEAD_MODES,
                                            AlgorithmCategory.AEAD))

    def determine_category(self, expr: str, head: str) -> AlgorithmCategory:
        """Return the category of the given algorithm expression.

        Look `head` up in `CATEGORY_FROM_HEAD`. While it is not found, drop
        a trailing run of digits (with any underscores before it), or
        otherwise the last underscore-separated word, and retry. For example
        ``SHA_256`` resolves via ``SHA`` and ``EDDSA_PREHASH`` via ``EDDSA``.
        `expr` is only used in the error message.

        Raise AlgorithmNotRecognized if no prefix of `head` is known.

        This function does not attempt to detect invalid inputs.
        """
        prefix = head
        while prefix:
            if prefix in self.CATEGORY_FROM_HEAD:
                return self.CATEGORY_FROM_HEAD[prefix]
            if re.match(r'.*[0-9]\Z', prefix):
                prefix = re.sub(r'_*[0-9]+\Z', r'', prefix)
            else:
                prefix = re.sub(r'_*[^_]*\Z', r'', prefix)
        raise AlgorithmNotRecognized(expr)

    @staticmethod
    def determine_wildcard(expr: str) -> bool:
        """Whether the given algorithm expression is a wildcard.

        Wildcards are expressions involving ``PSA_ALG_ANY_HASH`` or an
        ``_AT_LEAST_`` length policy.

        This function does not attempt to detect invalid inputs.
        """
        if re.search(r'\bPSA_ALG_ANY_HASH\b', expr):
            return True
        if re.search(r'_AT_LEAST_', expr):
            return True
        return False

    def __init__(self, expr: str) -> None:
        """Analyze an algorithm value.

        The algorithm must be expressed as a C expression containing only
        calls to PSA algorithm constructor macros and numeric literals.

        This class is only programmed to handle valid expressions. Invalid
        expressions may result in exceptions or in nonsensical results.
        """
        # The expression with all whitespace removed.
        self.expression = re.sub(r'\s+', r'', expr)
        # The expression with any truncation/length-policy wrapper removed.
        self.base_expression = self.determine_base(self.expression)
        # Normalized outermost constructor name (see determine_head).
        self.head = self.determine_head(self.base_expression)
        self.category = self.determine_category(self.base_expression, self.head)
        self.is_wildcard = self.determine_wildcard(self.expression)

    def is_key_agreement_with_derivation(self) -> bool:
        """Whether this is a combined key agreement and key derivation algorithm.

        True for ``PSA_ALG_KEY_AGREEMENT(ka, kdf)`` unless `kdf` is a zero
        literal; False for anything else.
        """
        if self.category != AlgorithmCategory.KEY_AGREEMENT:
            return False
        m = re.match(r'PSA_ALG_KEY_AGREEMENT\(\w+,\s*(.*)\)\Z', self.expression)
        if not m:
            return False
        kdf_alg = m.group(1)
        # Assume kdf_alg is either a valid KDF or 0.
        return not re.match(r'(?:0[Xx])?0+\s*\Z', kdf_alg)


    def short_expression(self, level: int = 0) -> str:
        """Abbreviate the expression, keeping it human-readable.

        See `crypto_knowledge.short_expression`.
        """
        return short_expression(self.expression, level=level)

    # Hash algorithms whose output length is not the trailing number of
    # their name.
    HASH_LENGTH = {
        'PSA_ALG_MD5': 16,
        'PSA_ALG_SHA_1': 20,
    }
    HASH_LENGTH_BITS_RE = re.compile(r'([0-9]+)\Z')
    @classmethod
    def hash_length(cls, alg: str) -> int:
        """The length of the given hash algorithm, in bytes.

        `alg` is a ``PSA_ALG_xxx`` hash name. Names not listed in
        `HASH_LENGTH` must end with their output size in bits (for example
        ``PSA_ALG_SHA_256`` gives 32). Raise ValueError otherwise.
        """
        if alg in cls.HASH_LENGTH:
            return cls.HASH_LENGTH[alg]
        m = cls.HASH_LENGTH_BITS_RE.search(alg)
        if m:
            return int(m.group(1)) // 8
        raise ValueError('Unknown hash length for ' + alg)

    # Tag lengths (in bytes) allowed by each AEAD algorithm's specification.
    PERMITTED_TAG_LENGTHS = {
        'PSA_ALG_CCM': frozenset([4, 6, 8, 10, 12, 14, 16]),
        'PSA_ALG_CHACHA20_POLY1305': frozenset([16]),
        'PSA_ALG_GCM': frozenset([4, 8, 12, 13, 14, 15, 16]),
    }
    # Full (untruncated) output length in bytes of non-HMAC MAC algorithms.
    MAC_LENGTH = {
        'PSA_ALG_CBC_MAC': 16, # actually the block cipher length
        'PSA_ALG_CMAC': 16, # actually the block cipher length
    }
    HMAC_RE = re.compile(r'PSA_ALG_HMAC\((.*)\)\Z')
    @classmethod
    def permitted_truncations(cls, base: str) -> FrozenSet[int]:
        """Permitted output lengths for the given MAC or AEAD base algorithm.

        For a MAC algorithm, this is the set of truncation lengths that
        Mbed TLS supports: from 4 bytes up to the full MAC length.
        For an AEAD algorithm, this is the set of truncation lengths that
        are permitted by the algorithm specification.

        Raise ValueError if the permitted lengths of `base` are unknown.
        """
        if base in cls.PERMITTED_TAG_LENGTHS:
            return cls.PERMITTED_TAG_LENGTHS[base]
        max_length = cls.MAC_LENGTH.get(base, None)
        if max_length is None:
            m = cls.HMAC_RE.match(base)
            if m:
                max_length = cls.hash_length(m.group(1))
        if max_length is None:
            raise ValueError('Unknown permitted lengths for ' + base)
        return frozenset(range(4, max_length + 1))

    # A truncated MAC or AEAD whose length is a C integer literal
    # (hexadecimal, decimal or octal) with an optional integer suffix.
    TRUNCATED_ALG_RE = re.compile(
        r'(?P<face>PSA_ALG_(?:AEAD_WITH_SHORTENED_TAG|TRUNCATED_MAC))'
        r'\((?P<base>.*),'
        r'(?P<length>0[Xx][0-9A-Fa-f]+|[1-9][0-9]*|0[0-7]*)[LUlu]*\)\Z')

    @staticmethod
    def _c_integer_literal_value(literal: str) -> int:
        """Return the value of a C integer literal without its suffix.

        Python's ``int(literal, 0)`` rejects C octal literals with a leading
        zero such as ``010``, so those are converted explicitly in base 8.
        """
        if len(literal) > 1 and literal[0] == '0' and literal[1] not in 'xX':
            return int(literal, 8)
        return int(literal, 0)

    def is_invalid_truncation(self) -> bool:
        """True for a MAC or AEAD algorithm truncated to an invalid length.

        Return True if the expression is a truncated MAC or AEAD whose
        length is not permitted for its base algorithm (see
        `permitted_truncations`).
        Return False for a truncation to a permitted length, for a
        truncation whose length is not a plain integer literal, and for
        anything other than a truncated MAC or AEAD.

        Raise ValueError if the permitted lengths of the base algorithm
        are unknown.
        """
        m = self.TRUNCATED_ALG_RE.match(self.expression)
        if m:
            base = m.group('base')
            to_length = self._c_integer_literal_value(m.group('length'))
            permitted_lengths = self.permitted_truncations(base)
            if to_length not in permitted_lengths:
                return True
        return False

    def is_valid_for_operation(self) -> bool:
        """Whether this algorithm construction is valid for an operation.

        This function assumes that the algorithm is constructed in a
        "grammatically" correct way, and only rejects semantically invalid
        combinations: wildcards and truncations to a non-permitted length.
        """
        if self.is_wildcard:
            return False
        if self.is_invalid_truncation():
            return False
        return True

    def can_do(self, category: AlgorithmCategory) -> bool:
        """Whether this algorithm can perform operations in the given category.

        A combined key agreement and key derivation algorithm can also
        perform key derivation.
        """
        if category == self.category:
            return True
        if category == AlgorithmCategory.KEY_DERIVATION and \
           self.is_key_agreement_with_derivation():
            return True
        return False

    def usage_flags(self, public: bool = False) -> List[str]:
        """The list of usage flags describing operations that can perform this algorithm.

        If public is true, only return public-key operations, not private-key operations.
        The flags are returned as ``PSA_KEY_USAGE_xxx`` macro names.
        Raise AlgorithmNotRecognized for a category with no known usage
        (PAKE).
        """
        if self.category == AlgorithmCategory.HASH:
            flags = [] # type: List[str]
        elif self.category == AlgorithmCategory.MAC:
            flags = ['SIGN_HASH', 'SIGN_MESSAGE',
                     'VERIFY_HASH', 'VERIFY_MESSAGE']
        elif self.category == AlgorithmCategory.CIPHER or \
             self.category == AlgorithmCategory.AEAD:
            flags = ['DECRYPT', 'ENCRYPT']
        elif self.category == AlgorithmCategory.SIGN:
            flags = ['VERIFY_HASH', 'VERIFY_MESSAGE']
            if not public:
                flags += ['SIGN_HASH', 'SIGN_MESSAGE']
        elif self.category == AlgorithmCategory.ASYMMETRIC_ENCRYPTION:
            flags = ['ENCRYPT']
            if not public:
                flags += ['DECRYPT']
        elif self.category == AlgorithmCategory.KEY_DERIVATION or \
             self.category == AlgorithmCategory.KEY_AGREEMENT:
            flags = ['DERIVE']
        else:
            raise AlgorithmNotRecognized(self.expression)
        return ['PSA_KEY_USAGE_' + flag for flag in flags]
535