1# This file is dual licensed under the terms of the Apache License, Version 2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 3# for complete details. 4 5from __future__ import absolute_import, division, print_function 6 7import struct 8 9from cryptography import utils 10from cryptography.exceptions import ( 11 AlreadyFinalized, 12 InvalidKey, 13 UnsupportedAlgorithm, 14 _Reasons, 15) 16from cryptography.hazmat.backends import _get_backend 17from cryptography.hazmat.backends.interfaces import HashBackend 18from cryptography.hazmat.primitives import constant_time, hashes 19from cryptography.hazmat.primitives.kdf import KeyDerivationFunction 20 21 22def _int_to_u32be(n): 23 return struct.pack(">I", n) 24 25 26@utils.register_interface(KeyDerivationFunction) 27class X963KDF(object): 28 def __init__(self, algorithm, length, sharedinfo, backend=None): 29 backend = _get_backend(backend) 30 31 max_len = algorithm.digest_size * (2 ** 32 - 1) 32 if length > max_len: 33 raise ValueError( 34 "Can not derive keys larger than {} bits.".format(max_len) 35 ) 36 if sharedinfo is not None: 37 utils._check_bytes("sharedinfo", sharedinfo) 38 39 self._algorithm = algorithm 40 self._length = length 41 self._sharedinfo = sharedinfo 42 43 if not isinstance(backend, HashBackend): 44 raise UnsupportedAlgorithm( 45 "Backend object does not implement HashBackend.", 46 _Reasons.BACKEND_MISSING_INTERFACE, 47 ) 48 self._backend = backend 49 self._used = False 50 51 def derive(self, key_material): 52 if self._used: 53 raise AlreadyFinalized 54 self._used = True 55 utils._check_byteslike("key_material", key_material) 56 output = [b""] 57 outlen = 0 58 counter = 1 59 60 while self._length > outlen: 61 h = hashes.Hash(self._algorithm, self._backend) 62 h.update(key_material) 63 h.update(_int_to_u32be(counter)) 64 if self._sharedinfo is not None: 65 h.update(self._sharedinfo) 66 output.append(h.finalize()) 67 outlen += len(output[-1]) 68 counter += 1 69 70 return b"".join(output)[: self._length] 71 72 def verify(self, key_material, expected_key): 73 if not constant_time.bytes_eq(self.derive(key_material), expected_key): 74 raise InvalidKey 75