• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7from enum import Enum
8
9from six.moves import range
10
11from cryptography import utils
12from cryptography.exceptions import (
13    AlreadyFinalized,
14    InvalidKey,
15    UnsupportedAlgorithm,
16    _Reasons,
17)
18from cryptography.hazmat.backends import _get_backend
19from cryptography.hazmat.backends.interfaces import HMACBackend
20from cryptography.hazmat.primitives import constant_time, hashes, hmac
21from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
22
23
class Mode(Enum):
    """KBKDF modes of operation accepted by ``KBKDFHMAC``.

    Only counter mode is defined.
    """

    CounterMode = "ctr"
26
27
class CounterLocation(Enum):
    """Position of the round counter relative to the fixed input data.

    ``KBKDFHMAC.derive`` feeds the counter to HMAC either before or after
    the fixed input, depending on which member is selected.
    """

    # HMAC input: counter || fixed_input
    BeforeFixed = "before_fixed"
    # HMAC input: fixed_input || counter
    AfterFixed = "after_fixed"
31
32
@utils.register_interface(KeyDerivationFunction)
class KBKDFHMAC(object):
    """Key-based key derivation in counter mode with HMAC as the PRF.

    Each round ``i`` (starting at 1) computes
    ``HMAC(key_material, counter_i || fixed_input)`` or
    ``HMAC(key_material, fixed_input || counter_i)``, depending on
    ``location``. Here ``counter_i`` is ``i`` encoded with
    ``utils.int_to_bytes`` in ``rlen`` bytes. The round outputs are
    concatenated and truncated to ``length`` bytes.

    The fixed input is either the caller-supplied ``fixed`` bytes, or
    ``label || b"\\x00" || context || L``, where ``L`` is the output length
    in bits encoded in ``llen`` bytes.

    An instance can derive only once; a second ``derive``/``verify`` call
    raises ``AlreadyFinalized``.
    """

    def __init__(
        self,
        algorithm,
        mode,
        length,
        rlen,
        llen,
        location,
        label,
        context,
        fixed,
        backend=None,
    ):
        """Validate and store the KDF parameters.

        :param algorithm: ``hashes.HashAlgorithm`` instance used by HMAC.
        :param mode: a ``Mode`` member (only ``Mode.CounterMode`` exists).
        :param length: number of bytes ``derive`` returns.
        :param rlen: byte width of the round counter, 1 to 4 inclusive.
        :param llen: byte width of the encoded output length ``L``.
            Required unless non-empty ``fixed`` data is supplied.
        :param location: ``CounterLocation`` member placing the counter
            before or after the fixed input.
        :param label: bytes or ``None`` (treated as ``b""``).
        :param context: bytes or ``None`` (treated as ``b""``).
        :param fixed: bytes used verbatim as the fixed input, or ``None``.
            Mutually exclusive with ``label``/``context``.
        :param backend: optional backend, resolved with ``_get_backend``.
        :raises UnsupportedAlgorithm: if the backend lacks HMAC support, or
            ``algorithm`` is not a usable hash/HMAC algorithm.
        :raises TypeError: on a wrongly typed ``mode``, ``location``,
            ``rlen``, ``llen``, ``label``, ``context`` or ``fixed``.
        :raises ValueError: on out-of-range ``rlen``/``llen``, a missing
            ``llen``, or ``fixed`` combined with ``label``/``context``.
        """
        backend = _get_backend(backend)
        if not isinstance(backend, HMACBackend):
            raise UnsupportedAlgorithm(
                "Backend object does not implement HMACBackend.",
                _Reasons.BACKEND_MISSING_INTERFACE,
            )

        if not isinstance(algorithm, hashes.HashAlgorithm):
            raise UnsupportedAlgorithm(
                "Algorithm supplied is not a supported hash algorithm.",
                _Reasons.UNSUPPORTED_HASH,
            )

        if not backend.hmac_supported(algorithm):
            raise UnsupportedAlgorithm(
                "Algorithm supplied is not a supported hmac algorithm.",
                _Reasons.UNSUPPORTED_HASH,
            )

        if not isinstance(mode, Mode):
            raise TypeError("mode must be of type Mode")

        if not isinstance(location, CounterLocation):
            raise TypeError("location must be of type CounterLocation")

        if (label or context) and fixed:
            raise ValueError(
                "When supplying fixed data, label and context are ignored."
            )

        if rlen is None or not self._valid_byte_length(rlen):
            raise ValueError("rlen must be between 1 and 4")

        # Empty fixed data is treated as "not supplied" (both here and in
        # _generate_fixed_input), so the label/context/L encoding is used
        # and L needs an explicit width.
        if llen is None and not fixed:
            raise ValueError("Please specify an llen")

        if llen is not None and not isinstance(llen, int):
            raise TypeError("llen must be an integer")

        # llen=0 would let utils.int_to_bytes choose a minimal width for L
        # instead of a fixed one; a negative width would only fail later,
        # inside derive().
        if llen is not None and llen < 1:
            raise ValueError("llen must be greater than zero")

        if label is None:
            label = b""

        if context is None:
            context = b""

        utils._check_bytes("label", label)
        utils._check_bytes("context", context)
        if fixed is not None:
            # Non-bytes fixed data would otherwise be silently ignored by
            # _generate_fixed_input, deriving from label/context instead.
            utils._check_bytes("fixed", fixed)

        self._algorithm = algorithm
        self._mode = mode
        self._length = length
        self._rlen = rlen
        self._llen = llen
        self._location = location
        self._label = label
        self._context = context
        self._backend = backend
        self._used = False
        self._fixed_data = fixed

    def _valid_byte_length(self, value):
        """Return True if ``value`` is an allowed byte width (1 to 4).

        :raises TypeError: if ``value`` is not an int.
        """
        if not isinstance(value, int):
            raise TypeError("value must be of type int")

        # Compare the integer directly. Measuring utils.int_to_bytes(1, value)
        # let 0 through, because that helper treats a zero length as "use
        # the minimal width" and returns one byte.
        return 1 <= value <= 4

    def derive(self, key_material):
        """Derive ``length`` bytes from ``key_material``.

        :param key_material: bytes-like secret used as the HMAC key.
        :returns: the derived key as bytes.
        :raises AlreadyFinalized: if this instance has already derived.
        :raises TypeError: if ``key_material`` is not bytes-like.
        :raises ValueError: if more rounds are needed than an ``rlen``-byte
            counter may number.
        """
        if self._used:
            raise AlreadyFinalized

        utils._check_byteslike("key_material", key_material)
        self._used = True

        # Ceiling division: enough digest-sized rounds to cover the length.
        rounds = -(-self._length // self._algorithm.digest_size)

        # For counter mode, the number of iterations shall not be larger than
        # 2^r - 1, where r <= 32 is the binary length of the counter. This
        # ensures that the counter values used as an input to the PRF will
        # not repeat during a particular call to the KDF function.
        if rounds > pow(2, self._rlen * 8) - 1:
            raise ValueError("There are too many iterations.")

        # The fixed input does not depend on the round, so build it once.
        fixed_input = self._generate_fixed_input()

        output = []
        for i in range(1, rounds + 1):
            h = hmac.HMAC(key_material, self._algorithm, backend=self._backend)

            counter = utils.int_to_bytes(i, self._rlen)
            if self._location == CounterLocation.BeforeFixed:
                h.update(counter)

            h.update(fixed_input)

            if self._location == CounterLocation.AfterFixed:
                h.update(counter)

            output.append(h.finalize())

        return b"".join(output)[: self._length]

    def _generate_fixed_input(self):
        """Return the fixed input data hashed in every round.

        Non-empty caller-supplied ``fixed`` bytes are returned verbatim.
        Otherwise the result is ``label || b"\\x00" || context || L``, with
        ``L`` being the output length in bits encoded in ``llen`` bytes.
        """
        if self._fixed_data and isinstance(self._fixed_data, bytes):
            return self._fixed_data

        l_val = utils.int_to_bytes(self._length * 8, self._llen)

        return b"".join([self._label, b"\x00", self._context, l_val])

    def verify(self, key_material, expected_key):
        """Check that ``key_material`` derives ``expected_key``.

        The comparison uses ``constant_time.bytes_eq``.

        :raises InvalidKey: if the derived key differs from ``expected_key``.
        :raises AlreadyFinalized: if this instance has already derived.
        """
        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
            raise InvalidKey
163