1# This file is dual licensed under the terms of the Apache License, Version 2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 3# for complete details. 4 5from __future__ import absolute_import, division, print_function 6 7from enum import Enum 8 9from six.moves import range 10 11from cryptography import utils 12from cryptography.exceptions import ( 13 AlreadyFinalized, 14 InvalidKey, 15 UnsupportedAlgorithm, 16 _Reasons, 17) 18from cryptography.hazmat.backends import _get_backend 19from cryptography.hazmat.backends.interfaces import HMACBackend 20from cryptography.hazmat.primitives import constant_time, hashes, hmac 21from cryptography.hazmat.primitives.kdf import KeyDerivationFunction 22 23 24class Mode(Enum): 25 CounterMode = "ctr" 26 27 28class CounterLocation(Enum): 29 BeforeFixed = "before_fixed" 30 AfterFixed = "after_fixed" 31 32 33@utils.register_interface(KeyDerivationFunction) 34class KBKDFHMAC(object): 35 def __init__( 36 self, 37 algorithm, 38 mode, 39 length, 40 rlen, 41 llen, 42 location, 43 label, 44 context, 45 fixed, 46 backend=None, 47 ): 48 backend = _get_backend(backend) 49 if not isinstance(backend, HMACBackend): 50 raise UnsupportedAlgorithm( 51 "Backend object does not implement HMACBackend.", 52 _Reasons.BACKEND_MISSING_INTERFACE, 53 ) 54 55 if not isinstance(algorithm, hashes.HashAlgorithm): 56 raise UnsupportedAlgorithm( 57 "Algorithm supplied is not a supported hash algorithm.", 58 _Reasons.UNSUPPORTED_HASH, 59 ) 60 61 if not backend.hmac_supported(algorithm): 62 raise UnsupportedAlgorithm( 63 "Algorithm supplied is not a supported hmac algorithm.", 64 _Reasons.UNSUPPORTED_HASH, 65 ) 66 67 if not isinstance(mode, Mode): 68 raise TypeError("mode must be of type Mode") 69 70 if not isinstance(location, CounterLocation): 71 raise TypeError("location must be of type CounterLocation") 72 73 if (label or context) and fixed: 74 raise ValueError( 75 "When supplying fixed data, " "label and context are ignored." 76 ) 77 78 if rlen is None or not self._valid_byte_length(rlen): 79 raise ValueError("rlen must be between 1 and 4") 80 81 if llen is None and fixed is None: 82 raise ValueError("Please specify an llen") 83 84 if llen is not None and not isinstance(llen, int): 85 raise TypeError("llen must be an integer") 86 87 if label is None: 88 label = b"" 89 90 if context is None: 91 context = b"" 92 93 utils._check_bytes("label", label) 94 utils._check_bytes("context", context) 95 self._algorithm = algorithm 96 self._mode = mode 97 self._length = length 98 self._rlen = rlen 99 self._llen = llen 100 self._location = location 101 self._label = label 102 self._context = context 103 self._backend = backend 104 self._used = False 105 self._fixed_data = fixed 106 107 def _valid_byte_length(self, value): 108 if not isinstance(value, int): 109 raise TypeError("value must be of type int") 110 111 value_bin = utils.int_to_bytes(1, value) 112 if not 1 <= len(value_bin) <= 4: 113 return False 114 return True 115 116 def derive(self, key_material): 117 if self._used: 118 raise AlreadyFinalized 119 120 utils._check_byteslike("key_material", key_material) 121 self._used = True 122 123 # inverse floor division (equivalent to ceiling) 124 rounds = -(-self._length // self._algorithm.digest_size) 125 126 output = [b""] 127 128 # For counter mode, the number of iterations shall not be 129 # larger than 2^r-1, where r <= 32 is the binary length of the counter 130 # This ensures that the counter values used as an input to the 131 # PRF will not repeat during a particular call to the KDF function. 132 r_bin = utils.int_to_bytes(1, self._rlen) 133 if rounds > pow(2, len(r_bin) * 8) - 1: 134 raise ValueError("There are too many iterations.") 135 136 for i in range(1, rounds + 1): 137 h = hmac.HMAC(key_material, self._algorithm, backend=self._backend) 138 139 counter = utils.int_to_bytes(i, self._rlen) 140 if self._location == CounterLocation.BeforeFixed: 141 h.update(counter) 142 143 h.update(self._generate_fixed_input()) 144 145 if self._location == CounterLocation.AfterFixed: 146 h.update(counter) 147 148 output.append(h.finalize()) 149 150 return b"".join(output)[: self._length] 151 152 def _generate_fixed_input(self): 153 if self._fixed_data and isinstance(self._fixed_data, bytes): 154 return self._fixed_data 155 156 l_val = utils.int_to_bytes(self._length * 8, self._llen) 157 158 return b"".join([self._label, b"\x00", self._context, l_val]) 159 160 def verify(self, key_material, expected_key): 161 if not constant_time.bytes_eq(self.derive(key_material), expected_key): 162 raise InvalidKey 163