# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from __future__ import absolute_import, division, print_function

import six

from cryptography import utils
from cryptography.exceptions import (
    AlreadyFinalized, InvalidKey, UnsupportedAlgorithm, _Reasons
)
from cryptography.hazmat.backends.interfaces import HMACBackend
from cryptography.hazmat.primitives import constant_time, hmac
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction


@utils.register_interface(KeyDerivationFunction)
class HKDF(object):
    """HMAC-based extract-then-expand key derivation (RFC 5869).

    The input key material is first condensed into a pseudorandom key by
    computing ``HMAC(salt, key_material)`` (the "extract" step). That key
    is then stretched to ``length`` bytes by an internal :class:`HKDFExpand`
    (the "expand" step).

    An instance is single use. A second call to :meth:`derive` raises
    ``AlreadyFinalized``. The same applies to :meth:`verify`, which calls
    :meth:`derive` internally.
    """

    def __init__(self, algorithm, length, salt, info, backend):
        """Configure the derivation.

        :param algorithm: hash algorithm used for every HMAC computation.
            Its ``digest_size`` sets the default salt length and the maximum
            output length.
        :param length: number of bytes to derive. Must be between 0 and
            ``255 * algorithm.digest_size`` inclusive.
        :param salt: salt bytes, checked with ``utils._check_bytes``.
            ``None`` means ``algorithm.digest_size`` zero bytes.
        :param info: context bytes mixed into the expand step, or ``None``
            for ``b""``. Validation is done by :class:`HKDFExpand`.
        :param backend: object implementing ``HMACBackend``.
        :raises UnsupportedAlgorithm: if ``backend`` is not an
            ``HMACBackend``.
        :raises ValueError: if ``length`` is negative or too large.
        """
        if not isinstance(backend, HMACBackend):
            raise UnsupportedAlgorithm(
                "Backend object does not implement HMACBackend.",
                _Reasons.BACKEND_MISSING_INTERFACE
            )

        self._algorithm = algorithm

        if salt is None:
            # RFC 5869: an absent salt is a string of HashLen zero bytes.
            salt = b"\x00" * self._algorithm.digest_size
        else:
            utils._check_bytes("salt", salt)

        self._salt = salt

        self._backend = backend

        # Length and info validation (and the single-use guard) live in the
        # expand stage; constructing it here surfaces errors immediately.
        self._hkdf_expand = HKDFExpand(self._algorithm, length, info, backend)

    def _extract(self, key_material):
        """Return the pseudorandom key ``HMAC(salt, key_material)``."""
        # The salt is the HMAC *key*; the input key material is the message.
        h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
        h.update(key_material)
        return h.finalize()

    def derive(self, key_material):
        """Derive ``length`` bytes from ``key_material``.

        :param key_material: input keying material, checked with
            ``utils._check_byteslike``.
        :return: the derived key as ``bytes``.
        :raises AlreadyFinalized: if this instance has already derived a key.
        """
        utils._check_byteslike("key_material", key_material)
        return self._hkdf_expand.derive(self._extract(key_material))

    def verify(self, key_material, expected_key):
        """Derive a key and compare it to ``expected_key`` in constant time.

        This consumes the instance, just like :meth:`derive`.

        :raises InvalidKey: if the derived key differs from ``expected_key``.
        :raises AlreadyFinalized: if this instance has already derived a key.
        """
        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
            raise InvalidKey


@utils.register_interface(KeyDerivationFunction)
class HKDFExpand(object):
    """The expand step of HKDF (RFC 5869) on its own.

    The ``key_material`` given to :meth:`derive` is used directly as the
    HMAC key; no extract step is performed. It should therefore already be
    a pseudorandom key, such as the output of an HKDF extract.

    An instance is single use. A second call to :meth:`derive` raises
    ``AlreadyFinalized``.
    """

    def __init__(self, algorithm, length, info, backend):
        """Configure the expansion.

        :param algorithm: hash algorithm used for every HMAC computation.
        :param length: number of bytes to produce. Must be between 0 and
            ``255 * algorithm.digest_size`` inclusive.
        :param info: context bytes, checked with ``utils._check_bytes``.
            ``None`` means ``b""``.
        :param backend: object implementing ``HMACBackend``.
        :raises UnsupportedAlgorithm: if ``backend`` is not an
            ``HMACBackend``.
        :raises ValueError: if ``length`` is negative or too large.
        """
        if not isinstance(backend, HMACBackend):
            raise UnsupportedAlgorithm(
                "Backend object does not implement HMACBackend.",
                _Reasons.BACKEND_MISSING_INTERFACE
            )

        self._algorithm = algorithm

        self._backend = backend

        # The block counter is encoded as a single byte (see _expand), so at
        # most 255 HMAC blocks can be produced.
        max_length = 255 * algorithm.digest_size

        if length > max_length:
            raise ValueError(
                "Can not derive keys larger than {0} octets.".format(
                    max_length
                ))

        # A negative length would otherwise pass the check above and make
        # derive() silently return b"". verify(..., b"") would then accept
        # any key material.
        if length < 0:
            raise ValueError("length must not be negative.")

        self._length = length

        if info is None:
            info = b""
        else:
            utils._check_bytes("info", info)

        self._info = info

        self._used = False

    def _expand(self, key_material):
        """Return the first ``length`` bytes of T(1) | T(2) | ...

        Each block is ``T(i) = HMAC(key_material, T(i-1) | info | i)``, with
        ``T(0) = b""``.
        """
        output = [b""]  # T(0); it contributes no bytes to the result.
        counter = 1

        # len(output) - 1 real blocks have been produced so far.
        while self._algorithm.digest_size * (len(output) - 1) < self._length:
            h = hmac.HMAC(key_material, self._algorithm, backend=self._backend)
            h.update(output[-1])
            h.update(self._info)
            h.update(six.int2byte(counter))
            output.append(h.finalize())
            counter += 1

        return b"".join(output)[:self._length]

    def derive(self, key_material):
        """Expand ``key_material`` into ``length`` bytes.

        :param key_material: pseudorandom key, checked with
            ``utils._check_byteslike``.
        :return: the derived key as ``bytes``.
        :raises AlreadyFinalized: if this instance has already derived a key.
        """
        utils._check_byteslike("key_material", key_material)
        if self._used:
            raise AlreadyFinalized

        # Mark as used before expanding, so the instance cannot be reused
        # even if the expansion raises.
        self._used = True
        return self._expand(key_material)

    def verify(self, key_material, expected_key):
        """Derive a key and compare it to ``expected_key`` in constant time.

        This consumes the instance, just like :meth:`derive`.

        :raises InvalidKey: if the derived key differs from ``expected_key``.
        :raises AlreadyFinalized: if this instance has already derived a key.
        """
        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
            raise InvalidKey