• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import struct
8
9from cryptography import utils
10from cryptography.exceptions import (
11    AlreadyFinalized,
12    InvalidKey,
13    UnsupportedAlgorithm,
14    _Reasons,
15)
16from cryptography.hazmat.backends import _get_backend
17from cryptography.hazmat.backends.interfaces import HMACBackend
18from cryptography.hazmat.backends.interfaces import HashBackend
19from cryptography.hazmat.primitives import constant_time, hashes, hmac
20from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
21
22
23def _int_to_u32be(n):
24    return struct.pack(">I", n)
25
26
27def _common_args_checks(algorithm, length, otherinfo):
28    max_length = algorithm.digest_size * (2 ** 32 - 1)
29    if length > max_length:
30        raise ValueError(
31            "Can not derive keys larger than {} bits.".format(max_length)
32        )
33    if otherinfo is not None:
34        utils._check_bytes("otherinfo", otherinfo)
35
36
def _concatkdf_derive(key_material, length, auxfn, otherinfo):
    """Run the ConcatKDF round loop and return exactly ``length`` bytes.

    Each round obtains a fresh context from ``auxfn()``, feeds it the 32-bit
    big-endian round counter (starting at 1), ``key_material`` and
    ``otherinfo`` in that order, and appends the finalized digest. Rounds
    continue until at least ``length`` bytes exist; the concatenation is then
    truncated to ``length``.
    """
    utils._check_byteslike("key_material", key_material)

    blocks = []
    produced = 0
    round_no = 1
    while produced < length:
        ctx = auxfn()
        for piece in (_int_to_u32be(round_no), key_material, otherinfo):
            ctx.update(piece)
        digest = ctx.finalize()
        blocks.append(digest)
        produced += len(digest)
        round_no += 1

    return b"".join(blocks)[:length]
53
54
@utils.register_interface(KeyDerivationFunction)
class ConcatKDFHash(object):
    """Single-use ConcatKDF whose per-round function is a plain hash.

    Round ``i`` (starting at 1) hashes ``u32be(i) || key_material ||
    otherinfo``; the digests are concatenated and truncated to ``length``
    bytes by ``_concatkdf_derive``.
    """

    def __init__(self, algorithm, length, otherinfo, backend=None):
        """Validate and store the derivation parameters.

        :param algorithm: hash algorithm instance handed to ``hashes.Hash``.
        :param length: number of bytes ``derive`` returns.
        :param otherinfo: bytes mixed into every round; ``None`` means ``b""``.
        :param backend: optional; resolved with ``_get_backend`` and required
            to implement ``HashBackend``.
        :raises ValueError: if ``length`` is too large for ``algorithm``.
        :raises UnsupportedAlgorithm: if the backend is not a ``HashBackend``.
        """
        resolved = _get_backend(backend)
        _common_args_checks(algorithm, length, otherinfo)

        self._algorithm = algorithm
        self._length = length
        self._otherinfo = b"" if otherinfo is None else otherinfo

        if not isinstance(resolved, HashBackend):
            raise UnsupportedAlgorithm(
                "Backend object does not implement HashBackend.",
                _Reasons.BACKEND_MISSING_INTERFACE,
            )
        self._backend = resolved
        self._used = False

    def _hash(self):
        # Factory for a fresh hash context; _concatkdf_derive calls it once
        # per output block.
        algorithm, backend = self._algorithm, self._backend
        return hashes.Hash(algorithm, backend)

    def derive(self, key_material):
        """Return ``length`` bytes derived from ``key_material``.

        The instance is single-use: any later ``derive``/``verify`` call
        raises ``AlreadyFinalized``.
        """
        already_used, self._used = self._used, True
        if already_used:
            raise AlreadyFinalized
        return _concatkdf_derive(
            key_material, self._length, self._hash, self._otherinfo
        )

    def verify(self, key_material, expected_key):
        """Raise ``InvalidKey`` unless ``key_material`` derives ``expected_key``.

        The comparison uses ``constant_time.bytes_eq``. Like ``derive``, this
        consumes the instance.
        """
        derived = self.derive(key_material)
        if constant_time.bytes_eq(derived, expected_key):
            return
        raise InvalidKey
89
90
@utils.register_interface(KeyDerivationFunction)
class ConcatKDFHMAC(object):
    """Single-use ConcatKDF whose per-round function is HMAC keyed by a salt.

    Round ``i`` (starting at 1) computes
    ``HMAC(salt, u32be(i) || key_material || otherinfo)``; the outputs are
    concatenated and truncated to ``length`` bytes by ``_concatkdf_derive``.
    """

    def __init__(self, algorithm, length, salt, otherinfo, backend=None):
        """Validate and store the derivation parameters.

        :param algorithm: hash algorithm instance used for HMAC.
        :param length: number of bytes ``derive`` returns.
        :param salt: HMAC key as bytes, or ``None`` for an all-zero key of
            ``algorithm.block_size`` bytes.
        :param otherinfo: bytes mixed into every round; ``None`` means ``b""``.
        :param backend: optional; resolved with ``_get_backend`` and required
            to implement ``HMACBackend``.
        :raises ValueError: if ``length`` is too large for ``algorithm``.
        :raises TypeError: if ``salt`` is ``None`` and ``algorithm`` has no
            ``block_size`` to size the default salt.
        :raises UnsupportedAlgorithm: if the backend is not an ``HMACBackend``.
        """
        backend = _get_backend(backend)

        _common_args_checks(algorithm, length, otherinfo)
        self._algorithm = algorithm
        self._length = length
        self._otherinfo = otherinfo
        if self._otherinfo is None:
            self._otherinfo = b""

        if salt is None:
            block_size = algorithm.block_size
            if block_size is None:
                # Some hash algorithms (e.g. the SHA-3 family) report no block
                # size, so the default all-zero salt cannot be sized. Fail with
                # a clear message rather than the opaque ``bytes * None``
                # TypeError; callers can still pass an explicit salt.
                raise TypeError(
                    "{} is unsupported for ConcatKDF without an explicit "
                    "salt".format(algorithm.name)
                )
            salt = b"\x00" * block_size
        else:
            utils._check_bytes("salt", salt)

        self._salt = salt

        if not isinstance(backend, HMACBackend):
            raise UnsupportedAlgorithm(
                "Backend object does not implement HMACBackend.",
                _Reasons.BACKEND_MISSING_INTERFACE,
            )
        self._backend = backend
        self._used = False

    def _hmac(self):
        # Factory for a fresh HMAC context keyed with the salt;
        # _concatkdf_derive calls it once per output block.
        return hmac.HMAC(self._salt, self._algorithm, self._backend)

    def derive(self, key_material):
        """Return ``length`` bytes derived from ``key_material``.

        The instance is single-use: any later ``derive``/``verify`` call
        raises ``AlreadyFinalized``.
        """
        if self._used:
            raise AlreadyFinalized
        self._used = True
        return _concatkdf_derive(
            key_material, self._length, self._hmac, self._otherinfo
        )

    def verify(self, key_material, expected_key):
        """Raise ``InvalidKey`` unless ``key_material`` derives ``expected_key``.

        The comparison uses ``constant_time.bytes_eq``. Like ``derive``, this
        consumes the instance.
        """
        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
            raise InvalidKey
132