1# This file is dual licensed under the terms of the Apache License, Version 2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 3# for complete details. 4 5from __future__ import absolute_import, division, print_function 6 7import struct 8 9from cryptography.hazmat.primitives.ciphers import Cipher 10from cryptography.hazmat.primitives.ciphers.algorithms import AES 11from cryptography.hazmat.primitives.ciphers.modes import ECB 12from cryptography.hazmat.primitives.constant_time import bytes_eq 13 14 15def _wrap_core(wrapping_key, a, r, backend): 16 # RFC 3394 Key Wrap - 2.2.1 (index method) 17 encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor() 18 n = len(r) 19 for j in range(6): 20 for i in range(n): 21 # every encryption operation is a discrete 16 byte chunk (because 22 # AES has a 128-bit block size) and since we're using ECB it is 23 # safe to reuse the encryptor for the entire operation 24 b = encryptor.update(a + r[i]) 25 # pack/unpack are safe as these are always 64-bit chunks 26 a = struct.pack( 27 ">Q", struct.unpack(">Q", b[:8])[0] ^ ((n * j) + i + 1) 28 ) 29 r[i] = b[-8:] 30 31 assert encryptor.finalize() == b"" 32 33 return a + b"".join(r) 34 35 36def aes_key_wrap(wrapping_key, key_to_wrap, backend): 37 if len(wrapping_key) not in [16, 24, 32]: 38 raise ValueError("The wrapping key must be a valid AES key length") 39 40 if len(key_to_wrap) < 16: 41 raise ValueError("The key to wrap must be at least 16 bytes") 42 43 if len(key_to_wrap) % 8 != 0: 44 raise ValueError("The key to wrap must be a multiple of 8 bytes") 45 46 a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" 47 r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)] 48 return _wrap_core(wrapping_key, a, r, backend) 49 50 51def _unwrap_core(wrapping_key, a, r, backend): 52 # Implement RFC 3394 Key Unwrap - 2.2.2 (index method) 53 decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor() 54 n = len(r) 55 for j in reversed(range(6)): 56 for i in reversed(range(n)): 57 # pack/unpack are safe as these are always 64-bit chunks 58 atr = struct.pack( 59 ">Q", struct.unpack(">Q", a)[0] ^ ((n * j) + i + 1) 60 ) + r[i] 61 # every decryption operation is a discrete 16 byte chunk so 62 # it is safe to reuse the decryptor for the entire operation 63 b = decryptor.update(atr) 64 a = b[:8] 65 r[i] = b[-8:] 66 67 assert decryptor.finalize() == b"" 68 return a, r 69 70 71def aes_key_wrap_with_padding(wrapping_key, key_to_wrap, backend): 72 if len(wrapping_key) not in [16, 24, 32]: 73 raise ValueError("The wrapping key must be a valid AES key length") 74 75 aiv = b"\xA6\x59\x59\xA6" + struct.pack(">i", len(key_to_wrap)) 76 # pad the key to wrap if necessary 77 pad = (8 - (len(key_to_wrap) % 8)) % 8 78 key_to_wrap = key_to_wrap + b"\x00" * pad 79 if len(key_to_wrap) == 8: 80 # RFC 5649 - 4.1 - exactly 8 octets after padding 81 encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor() 82 b = encryptor.update(aiv + key_to_wrap) 83 assert encryptor.finalize() == b"" 84 return b 85 else: 86 r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)] 87 return _wrap_core(wrapping_key, aiv, r, backend) 88 89 90def aes_key_unwrap_with_padding(wrapping_key, wrapped_key, backend): 91 if len(wrapped_key) < 16: 92 raise InvalidUnwrap("Must be at least 16 bytes") 93 94 if len(wrapping_key) not in [16, 24, 32]: 95 raise ValueError("The wrapping key must be a valid AES key length") 96 97 if len(wrapped_key) == 16: 98 # RFC 5649 - 4.2 - exactly two 64-bit blocks 99 decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor() 100 b = decryptor.update(wrapped_key) 101 assert decryptor.finalize() == b"" 102 a = b[:8] 103 data = b[8:] 104 n = 1 105 else: 106 r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)] 107 encrypted_aiv = r.pop(0) 108 n = len(r) 109 a, r = _unwrap_core(wrapping_key, encrypted_aiv, r, backend) 110 data = b"".join(r) 111 112 # 1) Check that MSB(32,A) = A65959A6. 113 # 2) Check that 8*(n-1) < LSB(32,A) <= 8*n. If so, let 114 # MLI = LSB(32,A). 115 # 3) Let b = (8*n)-MLI, and then check that the rightmost b octets of 116 # the output data are zero. 117 (mli,) = struct.unpack(">I", a[4:]) 118 b = (8 * n) - mli 119 if ( 120 not bytes_eq(a[:4], b"\xa6\x59\x59\xa6") or not 121 8 * (n - 1) < mli <= 8 * n or ( 122 b != 0 and not bytes_eq(data[-b:], b"\x00" * b) 123 ) 124 ): 125 raise InvalidUnwrap() 126 127 if b == 0: 128 return data 129 else: 130 return data[:-b] 131 132 133def aes_key_unwrap(wrapping_key, wrapped_key, backend): 134 if len(wrapped_key) < 24: 135 raise InvalidUnwrap("Must be at least 24 bytes") 136 137 if len(wrapped_key) % 8 != 0: 138 raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes") 139 140 if len(wrapping_key) not in [16, 24, 32]: 141 raise ValueError("The wrapping key must be a valid AES key length") 142 143 aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" 144 r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)] 145 a = r.pop(0) 146 a, r = _unwrap_core(wrapping_key, a, r, backend) 147 if not bytes_eq(a, aiv): 148 raise InvalidUnwrap() 149 150 return b"".join(r) 151 152 153class InvalidUnwrap(Exception): 154 pass 155