• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: BSD-2-Clause
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (c) 2013, Marc Horowitz
5# Copyright (C) 2013, Massachusetts Institute of Technology
6# Copyright (C) 2022-2024, Gabriel Potter and the secdev/scapy community
7
8"""
9Implementation of cryptographic functions for Kerberos 5
10
11- RFC 3961: Encryption and Checksum Specifications for Kerberos 5
12- RFC 3962: Advanced Encryption Standard (AES) Encryption for Kerberos 5
13- RFC 4757: The RC4-HMAC Kerberos Encryption Types Used by Microsoft Windows
14- RFC 6113: A Generalized Framework for Kerberos Pre-Authentication
15- RFC 8009: AES Encryption with HMAC-SHA2 for Kerberos 5
16"""
17
18# TODO: support cipher states...
19
# Names exported by "from ... import *". Because __all__ is explicit,
# _rfc1964pad is exported despite its leading underscore.
__all__ = [
    "EncryptionType",
    "ChecksumType",
    "Key",
    "InvalidChecksum",
    "_rfc1964pad",
]
27
28# The following is a heavily modified version of
29# https://github.com/SecureAuthCorp/impacket/blob/3ec59074ec35c06bbd4312d1042f0e23f4a1b41f/impacket/krb5/crypto.py
30# itself heavily inspired from
31# https://github.com/mhorowitz/pykrb5/blob/master/krb5/crypto.py
32# Note that the following work is based only on THIS COMMIT from impacket,
33# which is therefore under mhorowitz's BSD 2-clause "simplified" license.
34
35import abc
36import enum
37import math
38import os
39import struct
40from scapy.compat import (
41    orb,
42    chb,
43    int_bytes,
44    bytes_int,
45    plain_str,
46)
47
48# Typing
49from typing import (
50    Any,
51    Callable,
52    List,
53    Optional,
54    Type,
55    Union,
56)
57
58# We end up using our own crypto module for hashes / hmac because
59# we need MD4 which was dropped everywhere. It's just a wrapper above
60# the builtin python ones (except for MD4).
61
62from scapy.layers.tls.crypto.hash import (
63    _GenericHash,
64    Hash_MD4,
65    Hash_MD5,
66    Hash_SHA,
67    Hash_SHA256,
68    Hash_SHA384,
69)
70from scapy.layers.tls.crypto.h_mac import (
71    Hmac,
72    Hmac_MD5,
73    Hmac_SHA,
74)
75
76# For everything else, use cryptography.
77
78try:
79    from cryptography.hazmat.primitives import hashes
80    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
81    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
82
83    try:
84        # cryptography > 43.0
85        from cryptography.hazmat.decrepit.ciphers import (
86            algorithms as decrepit_algorithms,
87        )
88    except ImportError:
89        decrepit_algorithms = algorithms
90except ImportError:
91    raise ImportError("To use kerberos cryptography, you need to install cryptography.")
92
93
# Single-DES alias: cryptography's TripleDES allows the usage of a single
# 64-bit DES key (56 effective bits), with which it behaves like single DES.
DES = decrepit_algorithms.TripleDES
96
97
98# https://go.microsoft.com/fwlink/?LinkId=186039
99# https://csrc.nist.gov/CSRC/media/Publications/sp/800-108/archive/2008-11-06/documents/sp800-108-Nov2008.pdf
100# [SP800-108] section 5.1 (used in [MS-SMB2] sect 3.1.4.2)
101
102
def SP800108_KDFCTR(
    K_I: bytes,
    Label: bytes,
    Context: bytes,
    L: int,
    hashmod: Type[_GenericHash] = Hash_SHA256,
) -> bytes:
    """
    KDF in Counter Mode as section 5.1 of [SP800-108]

    This assumes r=32, and defaults to SHA256 ([MS-SMB2] default).

    :param K_I: the key derivation key
    :param Label: the Label (a 0x00 separator byte is appended after it)
    :param Context: the Context
    :param L: length of the derived keying material, in bits
    :param hashmod: hash class used to build the HMAC PRF
    :returns: the first L // 8 bytes of derived keying material
    :raises ValueError: if more than 2^32 - 1 PRF invocations are required
    """
    PRF = Hmac(K_I, hashmod).digest
    # h is the PRF output length in *bits*, the same unit as L (L is encoded
    # as [L]_2 in every PRF input and the output is truncated to L // 8 bytes).
    h = hashmod.hash_len * 8
    n = -(-L // h)  # ceil(L / h) using integer arithmetic
    if n > 0xFFFFFFFF:
        # n must not exceed 2^r-1 = 0xffffffff with r=32 per [MS-SMB2]
        raise ValueError("Invalid n value in SP800108_KDFCTR")
    result = b"".join(
        PRF(struct.pack(">I", i) + Label + b"\x00" + Context + struct.pack(">I", L))
        for i in range(1, n + 1)
    )
    return result[: L // 8]
126
127
128# https://www.iana.org/assignments/kerberos-parameters/kerberos-parameters.xhtml#kerberos-parameters-1
129
130
class EncryptionType(enum.IntEnum):
    """
    Kerberos encryption type (etype) numbers, from the IANA
    "Kerberos Encryption Type Numbers" registry.

    Commented-out entries are registered values not exposed by this enum.
    """

    DES_CBC_CRC = 1
    DES_CBC_MD4 = 2
    DES_CBC_MD5 = 3
    # DES3_CBC_SHA1 = 7
    DES3_CBC_SHA1_KD = 16
    AES128_CTS_HMAC_SHA1_96 = 17
    AES256_CTS_HMAC_SHA1_96 = 18
    AES128_CTS_HMAC_SHA256_128 = 19
    AES256_CTS_HMAC_SHA384_192 = 20
    RC4_HMAC = 23
    RC4_HMAC_EXP = 24
    # CAMELLIA128-CTS-CMAC = 25
    # CAMELLIA256-CTS-CMAC = 26
145
146
147# https://www.iana.org/assignments/kerberos-parameters/kerberos-parameters.xhtml#kerberos-parameters-2
148
149
class ChecksumType(enum.IntEnum):
    """
    Kerberos checksum type numbers, from the IANA
    "Kerberos Checksum Type Numbers" registry.

    Commented-out entries are registered values not exposed by this enum.
    """

    CRC32 = 1
    # RSA_MD4 = 2
    RSA_MD4_DES = 3
    # RSA_MD5 = 7
    RSA_MD5_DES = 8
    # RSA_MD5_DES3 = 9
    # NOTE: 10 and 14 are both registered as unkeyed SHA-1; the two commented
    # names below would clash if both were enabled as-is.
    # SHA1 = 10
    HMAC_SHA1_DES3_KD = 12
    # HMAC_SHA1_DES3 = 13
    # SHA1 = 14
    HMAC_SHA1_96_AES128 = 15
    HMAC_SHA1_96_AES256 = 16
    # CMAC-CAMELLIA128 = 17
    # CMAC-CAMELLIA256 = 18
    HMAC_SHA256_128_AES128 = 19
    HMAC_SHA384_192_AES256 = 20
    # Keyed MD5 checksum of RFC 4757 (RC4-HMAC); registered as a negative value.
    HMAC_MD5 = -138
168
169
class InvalidChecksum(ValueError):
    """Raised when a checksum or ciphertext integrity check does not verify."""
172
173
174#########
175# Utils #
176#########
177
178
179# https://www.gnu.org/software/shishi/ides.pdf - APPENDIX B
180
181
182def _n_fold(s, n):
183    # type: (bytes, int) -> bytes
184    """
185    n-fold is an algorithm that takes m input bits and "stretches" them
186    to form n output bits with equal contribution from each input bit to
187    the output (quote from RFC 3961 sect 3.1).
188    """
189
190    def rot13(y, nb):
191        # type: (bytes, int) -> bytes
192        x = bytes_int(y)
193        mod = (1 << (nb * 8)) - 1
194        if nb == 0:
195            return y
196        elif nb == 1:
197            return int_bytes(((x >> 5) | (x << (nb * 8 - 5))) & mod, nb)
198        else:
199            return int_bytes(((x >> 13) | (x << (nb * 8 - 13))) & mod, nb)
200
201    def ocadd(x, y, nb):
202        # type: (bytearray, bytearray, int) -> bytearray
203        v = [a + b for a, b in zip(x, y)]
204        while any(x & ~0xFF for x in v):
205            v = [(v[i - nb + 1] >> 8) + (v[i] & 0xFF) for i in range(nb)]
206        return bytearray(x for x in v)
207
208    m = len(s)
209    lcm = n // math.gcd(n, m) * m  # lcm = math.lcm(n, m) on Python>=3.9
210    buf = bytearray()
211    for _ in range(lcm // m):
212        buf += s
213        s = rot13(s, m)
214    out = bytearray(b"\x00" * n)
215    for i in range(0, lcm, n):
216        out = ocadd(out, buf[i : i + n], n)
217    return bytes(out)
218
219
220def _zeropad(s, padsize):
221    # type: (bytes, int) -> bytes
222    """
223    Return s padded with 0 bytes to a multiple of padsize.
224    """
225    return s + b"\x00" * (-len(s) % padsize)
226
227
228def _rfc1964pad(s):
229    # type: (bytes) -> bytes
230    """
231    Return s padded as RFC1964 mandates
232    """
233    pad = (-len(s)) % 8
234    return s + pad * struct.pack("!B", pad)
235
236
237def _xorbytes(b1, b2):
238    # type: (bytearray, bytearray) -> bytearray
239    """
240    xor two strings together and return the resulting string
241    """
242    assert len(b1) == len(b2)
243    return bytearray((x ^ y) for x, y in zip(b1, b2))
244
245
246def _mac_equal(mac1, mac2):
247    # type: (bytes, bytes) -> bool
248    # Constant-time comparison function.  (We can't use HMAC.verify
249    # since we use truncated macs.)
250    return all(x == y for x, y in zip(mac1, mac2))
251
252
# Weak and semi-weak DES keys, parity bits included, per
# https://doi.org/10.6028/NBS.FIPS.74 sect 3.6
# Groups 1-6 are pairs of semi-weak keys, groups 7-10 are the weak keys.

WEAK_DES_KEYS = {
    bytes.fromhex(_hexkey)
    for _hexkey in (
        "e001e001f101f101", "01e001e001f101f1",  # 1
        "fe1ffe1ffe0efe0e", "1ffe1ffe0efe0efe",  # 2
        "e01fe01ff10ef10e", "1fe01fe00ef10ef1",  # 3
        "01fe01fe01fe01fe", "fe01fe01fe01fe01",  # 4
        "011f011f010e010e", "1f011f010e010e01",  # 5
        "e0fee0fef1fef1fe", "fee0fee0fef1fef1",  # 6
        "0101010101010101",  # 7
        "fefefefefefefefe",  # 8
        "e0e0e0e0f1f1f1f1",  # 9
        "1f1f1f1f0e0e0e0e",  # 10
    )
}
285
def _make_crc32_table():
    # type: () -> List[int]
    """
    Build the 256-entry lookup table of the reflected CRC-32 polynomial
    0xEDB88320 (0x04C11DB7 bit-reversed): entry i is the CRC register after
    shifting the byte value i through 8 rounds, starting from 0.
    """
    table = []
    for value in range(256):
        crc = value
        for _ in range(8):
            crc = (crc >> 1) ^ 0xEDB88320 if crc & 1 else crc >> 1
        table.append(crc)
    return table


CRC32_TABLE = _make_crc32_table()
354
355############
356# RFC 3961 #
357############
358
359
360# RFC3961 sect 3
361
362
363class _EncryptionAlgorithmProfile(abc.ABCMeta):
364    """
365    Base class for etype profiles.
366
367    Usable etype classes must define:
368    :attr etype: etype number
369    :attr keysize: protocol size of key in bytes
370    :attr seedsize: random_to_key input size in bytes
371    :attr reqcksum: 'required checksum mechanism' per RFC3961.
372                    this is the default checksum used for this algorithm.
373    :attr random_to_key: (if the keyspace is not dense)
374    :attr string_to_key:
375    :attr encrypt:
376    :attr decrypt:
377    :attr prf:
378    """
379
380    etype = None  # type: EncryptionType
381    keysize = None  # type: int
382    seedsize = None  # type: int
383    reqcksum = None  # type: ChecksumType
384
385    @classmethod
386    @abc.abstractmethod
387    def derive(cls, key, constant):
388        # type: (Key, bytes) -> bytes
389        pass
390
391    @classmethod
392    @abc.abstractmethod
393    def encrypt(cls, key, keyusage, plaintext, confounder):
394        # type: (Key, int, bytes, Optional[bytes]) -> bytes
395        pass
396
397    @classmethod
398    @abc.abstractmethod
399    def decrypt(cls, key, keyusage, ciphertext):
400        # type: (Key, int, bytes) -> bytes
401        pass
402
403    @classmethod
404    @abc.abstractmethod
405    def prf(cls, key, string):
406        # type: (Key, bytes) -> bytes
407        pass
408
409    @classmethod
410    @abc.abstractmethod
411    def string_to_key(cls, string, salt, params):
412        # type: (bytes, bytes, Optional[bytes]) -> Key
413        pass
414
415    @classmethod
416    def random_to_key(cls, seed):
417        # type: (bytes) -> Key
418        if len(seed) != cls.seedsize:
419            raise ValueError("Wrong seed length")
420        return Key(cls.etype, key=seed)
421
422
423# RFC3961 sect 4
424
425
426class _ChecksumProfile(object):
427    """
428    Base class for checksum profiles.
429
430    Usable checksum classes must define:
431    :func checksum:
432    :attr macsize: Size of checksum in bytes
433    :func verify: (if verification is not just checksum-and-compare)
434    """
435
436    macsize = None  # type: int
437
438    @classmethod
439    @abc.abstractmethod
440    def checksum(cls, key, keyusage, text):
441        # type: (Key, int, bytes) -> bytes
442        pass
443
444    @classmethod
445    def verify(cls, key, keyusage, text, cksum):
446        # type: (Key, int, bytes, bytes) -> None
447        expected = cls.checksum(key, keyusage, text)
448        if not _mac_equal(cksum, expected):
449            raise InvalidChecksum("checksum verification failure")
450
451
452# RFC3961 sect 5.3
453
454
class _SimplifiedEncryptionProfile(_EncryptionAlgorithmProfile):
    """
    Base class for etypes using the RFC 3961 simplified profile.
    Defines the derive, encrypt, decrypt, and prf methods.

    Subclasses must define:

    :param blocksize: Underlying cipher block size in bytes
    :param padsize: Underlying cipher padding multiple (1 or blocksize)
    :param macsize: Size of integrity MAC in bytes
    :param hashmod: underlying hash function
    :param basic_encrypt, basic_decrypt: Underlying CBC/CTS cipher

    Subclasses setting ``rfc8009 = True`` get the RFC 8009 variant: derive()
    is called with an extra output length (in bits), and the MAC covers a
    16-byte zero IV followed by the ciphertext rather than the plaintext.
    """

    blocksize = None  # type: int
    padsize = None  # type: int
    macsize = None  # type: int
    hashmod = None  # type: Any

    # Used in RFC 8009. This is not a simplified profile per se but
    # is still pretty close.
    rfc8009 = False

    @classmethod
    @abc.abstractmethod
    def basic_encrypt(cls, key, plaintext):
        # type: (bytes, bytes) -> bytes
        pass

    @classmethod
    @abc.abstractmethod
    def basic_decrypt(cls, key, ciphertext):
        # type: (bytes, bytes) -> bytes
        pass

    @classmethod
    def derive(cls, key, constant):
        # type: (Key, bytes) -> bytes
        """
        Also known as "DK" in RFC3961.

        :param key: the base key
        :param constant: the derivation constant
        :returns: the raw bytes of random-to-key(DR(key, constant))
        """
        # RFC 3961 only says to n-fold the constant only if it is
        # shorter than the cipher block size.  But all Unix
        # implementations n-fold constants if their length is larger
        # than the block size as well, and n-folding when the length
        # is equal to the block size is a no-op.
        plaintext = _n_fold(constant, cls.blocksize)
        rndseed = b""
        # DR: iterate the cipher on its own output until seedsize bytes exist.
        while len(rndseed) < cls.seedsize:
            ciphertext = cls.basic_encrypt(key.key, plaintext)
            rndseed += ciphertext
            plaintext = ciphertext
        # DK(Key, Constant) = random-to-key(DR(Key, Constant))
        return cls.random_to_key(rndseed[0 : cls.seedsize]).key

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder, signtext=None):
        # type: (Key, int, bytes, Optional[bytes], Optional[bytes]) -> bytes
        """
        Encryption function.

        :param key: the key
        :param keyusage: the keyusage
        :param plaintext: the text to encrypt
        :param confounder: (optional) the confounder. If none, will be random
        :param signtext: (optional) make the checksum include different data than what
                         is encrypted. Useful for kerberos GSS_WrapEx. If none, same as
                         plaintext.
        :returns: the ciphertext followed by the truncated MAC
        """
        # Ki (integrity) and Ke (encryption) keys: usage | 0x55 and usage | 0xAA
        if not cls.rfc8009:
            ki = cls.derive(key, struct.pack(">IB", keyusage, 0x55))
            ke = cls.derive(key, struct.pack(">IB", keyusage, 0xAA))
        else:
            ki = cls.derive(key, struct.pack(">IB", keyusage, 0x55), cls.macsize * 8)  # type: ignore  # noqa: E501
            ke = cls.derive(key, struct.pack(">IB", keyusage, 0xAA), cls.keysize * 8)  # type: ignore  # noqa: E501
        if confounder is None:
            confounder = os.urandom(cls.blocksize)
        basic_plaintext = confounder + _zeropad(plaintext, cls.padsize)
        if signtext is None:
            signtext = basic_plaintext
        if not cls.rfc8009:
            # Simplified profile
            hmac = Hmac(ki, cls.hashmod).digest(signtext)
            return cls.basic_encrypt(ke, basic_plaintext) + hmac[: cls.macsize]
        else:
            # RFC 8009
            # NOTE(review): signtext is not used in this branch; the MAC
            # always covers IV | ciphertext.
            C = cls.basic_encrypt(ke, basic_plaintext)
            hmac = Hmac(ki, cls.hashmod).digest(b"\0" * 16 + C)  # XXX IV
            return C + hmac[: cls.macsize]

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext, presignfunc=None):
        # type: (Key, int, bytes, Optional[Callable[[bytes, bytes], bytes]]) -> bytes
        """
        Decryption function.

        :param key: the key
        :param keyusage: the keyusage
        :param ciphertext: the encrypted data followed by its MAC
        :param presignfunc: (optional) called with two slices of the data to
                            build the text that is MACed. Useful for GSS_WrapEx.
        :returns: the plaintext, confounder removed
        :raises ValueError: if the ciphertext is too short or badly padded
        :raises InvalidChecksum: on integrity failure (a ValueError subclass)
        """
        if not cls.rfc8009:
            ki = cls.derive(key, struct.pack(">IB", keyusage, 0x55))
            ke = cls.derive(key, struct.pack(">IB", keyusage, 0xAA))
        else:
            ki = cls.derive(key, struct.pack(">IB", keyusage, 0x55), cls.macsize * 8)  # type: ignore  # noqa: E501
            ke = cls.derive(key, struct.pack(">IB", keyusage, 0xAA), cls.keysize * 8)  # type: ignore  # noqa: E501
        if len(ciphertext) < cls.blocksize + cls.macsize:
            raise ValueError("Ciphertext too short")
        basic_ctext, mac = ciphertext[: -cls.macsize], ciphertext[-cls.macsize :]
        if len(basic_ctext) % cls.padsize != 0:
            raise ValueError("ciphertext does not meet padding requirement")
        if not cls.rfc8009:
            # Simplified profile
            basic_plaintext = cls.basic_decrypt(ke, basic_ctext)
            signtext = basic_plaintext
            if presignfunc:
                # Allow to have additional processing of the data that is to be signed.
                # This is useful for GSS_WrapEx
                signtext = presignfunc(
                    basic_plaintext[: cls.blocksize],
                    basic_plaintext[cls.blocksize :],
                )
            hmac = Hmac(ki, cls.hashmod).digest(signtext)
            expmac = hmac[: cls.macsize]
            if not _mac_equal(mac, expmac):
                raise InvalidChecksum("ciphertext integrity failure")
        else:
            # RFC 8009: the MAC is checked on the ciphertext before decrypting.
            signtext = b"\0" * 16 + basic_ctext  # XXX IV
            if presignfunc:
                # Allow to have additional processing of the data that is to be signed.
                # This is useful for GSS_WrapEx
                # NOTE(review): these slices are taken from the ciphertext at
                # offset 16, not from signtext; confirm the offset is intended.
                signtext = presignfunc(
                    basic_ctext[16 : 16 + cls.blocksize],
                    basic_ctext[16 + cls.blocksize :],
                )
            hmac = Hmac(ki, cls.hashmod).digest(signtext)
            expmac = hmac[: cls.macsize]
            if not _mac_equal(mac, expmac):
                raise InvalidChecksum("ciphertext integrity failure")
            basic_plaintext = cls.basic_decrypt(ke, basic_ctext)
        # Discard the confounder.
        return bytes(basic_plaintext[cls.blocksize :])

    @classmethod
    def prf(cls, key, string):
        # type: (Key, bytes) -> bytes
        """
        pseudo-random function: E(DK(key, "prf"), truncate(H(string)))
        """
        # Hash the input.  RFC 3961 says to truncate to the padding
        # size, but implementations truncate to the block size.
        hashval = cls.hashmod().digest(string)
        if len(hashval) % cls.blocksize:
            hashval = hashval[: -(len(hashval) % cls.blocksize)]
        # Encrypt the hash with a derived key.
        kp = cls.derive(key, b"prf")
        return cls.basic_encrypt(kp, hashval)
609
610
611# RFC3961 sect 5.4
612
613
class _SimplifiedChecksum(_ChecksumProfile):
    """
    Base class for checksums using the RFC 3961 simplified profile.
    Defines the checksum and verify methods.

    The checksum is the HMAC (with ``enc.hashmod``) of the text, keyed with
    Kc = DK(key, usage | 0x99) and truncated to ``macsize`` bytes.

    Subclasses must define:
    :attr enc: Profile of associated etype
    """

    enc = None  # type: Type[_SimplifiedEncryptionProfile]

    # RFC 8009 is not a simplified profile per se, but is close enough; here
    # it only changes how derive() is called.
    rfc8009 = False

    @classmethod
    def checksum(cls, key, keyusage, text):
        # type: (Key, int, bytes) -> bytes
        """Return HMAC(Kc, text) truncated to ``macsize`` bytes."""
        usage_constant = struct.pack(">IB", keyusage, 0x99)
        if cls.rfc8009:
            # RFC 8009: derive() also takes the requested length in bits.
            kc = cls.enc.derive(  # type: ignore
                key, usage_constant, cls.macsize * 8
            )
        else:
            # Simplified profile
            kc = cls.enc.derive(key, usage_constant)
        return Hmac(kc, cls.enc.hashmod).digest(text)[: cls.macsize]

    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # type: (Key, int, bytes, bytes) -> None
        """
        Verify ``cksum`` over ``text``.

        :raises ValueError: if ``key`` is not of the etype of ``enc``
        :raises InvalidChecksum: if the checksum does not match
        """
        if key.etype == cls.enc.etype:
            super().verify(key, keyusage, text, cksum)
            return
        raise ValueError("Wrong key type for checksum")
649
650
651# RFC3961 sect 6.1
652
653
class _CRC32(_ChecksumProfile):
    """
    The "modified" CRC-32 checksum of RFC 3961 sect 6.1.

    Unlike the usual CRC-32, the register starts at 0 and is not
    complemented at the end (another RFC describes it as a buggy version of
    the actual CRC-32). The result is serialized little-endian. ``key`` and
    ``keyusage`` are ignored.
    """

    macsize = 4

    @classmethod
    def checksum(cls, key, keyusage, text):
        # type: (Optional[Key], int, bytes) -> bytes
        crc = 0
        for octet in text:
            crc = (crc >> 8) ^ CRC32_TABLE[(octet ^ crc) & 0xFF]
        return crc.to_bytes(4, "little")
670
671
672# RFC3961 sect 6.2
673
674
class _DESCBC(_SimplifiedEncryptionProfile):
    """
    Common base of the single-DES CBC etypes (RFC 3961 sect 6.2).

    These do not follow the simplified profile: there is no key derivation,
    the key usage is ignored, and integrity comes from an unkeyed hash
    (``hashmod``) stored, encrypted, right after the confounder::

        E(key, confounder | H(confounder | zeros | msg | pad) | msg | pad)

    Subclasses set ``etype``, ``hashmod`` and ``reqcksum``.
    """

    keysize = 8
    seedsize = 8
    blocksize = 8
    padsize = 8
    macsize = 16
    hashmod = Hash_MD5

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder, signtext=None):
        # type: (Key, int, bytes, Optional[bytes], Any) -> bytes
        """
        Encrypt ``plaintext`` with ``key``.

        :param keyusage: ignored by the DES etypes
        :param confounder: (optional) the confounder. If None, will be random
        :param signtext: ignored by the DES etypes
        """
        if confounder is None:
            confounder = os.urandom(cls.blocksize)
        basic_plaintext = (
            confounder + b"\x00" * cls.macsize + _zeropad(plaintext, cls.padsize)
        )
        # The hash is computed with its own slot zeroed, then stored there.
        checksum = cls.hashmod().digest(basic_plaintext)
        basic_plaintext = (
            basic_plaintext[: len(confounder)]
            + checksum
            + basic_plaintext[len(confounder) + len(checksum) :]
        )
        return cls.basic_encrypt(key.key, basic_plaintext)

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext, presignfunc=None):
        # type: (Key, int, bytes, Any) -> bytes
        """
        Decrypt ``ciphertext`` with ``key`` and check the embedded hash.

        :param keyusage: ignored by the DES etypes
        :param presignfunc: ignored by the DES etypes
        :returns: the message (still zero-padded), confounder and hash removed
        :raises ValueError: if the ciphertext is too short or not block-aligned
        :raises InvalidChecksum: if the embedded hash does not match
        """
        if len(ciphertext) < cls.blocksize + cls.macsize:
            raise ValueError("ciphertext too short")
        if len(ciphertext) % cls.padsize != 0:
            raise ValueError("ciphertext does not meet padding requirement")

        complex_plaintext = cls.basic_decrypt(key.key, ciphertext)
        confounder = complex_plaintext[: cls.blocksize]
        mac = complex_plaintext[cls.blocksize : cls.blocksize + cls.macsize]
        message = complex_plaintext[cls.blocksize + cls.macsize :]

        expmac = cls.hashmod().digest(confounder + b"\x00" * cls.macsize + message)
        if not _mac_equal(mac, expmac):
            raise InvalidChecksum("ciphertext integrity failure")
        return bytes(message)

    @classmethod
    def mit_des_string_to_key(cls, string, salt):
        # type: (bytes, bytes) -> Key
        """
        DES string-to-key of RFC 3961 sect 6.2 (mit_des_string_to_key).

        :param string: the password
        :param salt: the salt
        :returns: the DES key, with odd parity and weak-key correction applied
        """

        def fixparity(deskey):
            # type: (bytes) -> bytes
            # Keep the 7 high bits of each byte and set the low bit so that
            # every byte has odd parity.
            temp = b""
            for i in range(len(deskey)):
                t = (bin(orb(deskey[i]))[2:]).rjust(8, "0")
                if t[:7].count("1") % 2 == 0:
                    temp += chb(int(t[:7] + "1", 2))
                else:
                    temp += chb(int(t[:7] + "0", 2))
            return temp

        def addparity(l1):
            # type: (List[int]) -> List[int]
            # Shift each 7-bit value left and append an odd-parity low bit.
            temp = list()
            for byte in l1:
                if (bin(byte).count("1") % 2) == 0:
                    byte = (byte << 1) | 0b00000001
                else:
                    byte = (byte << 1) & 0b11111110
                temp.append(byte)
            return temp

        def XOR(l1, l2):
            # type: (List[int], List[int]) -> List[int]
            # XOR two 56-bit strings stored as lists of 7-bit values.
            temp = list()
            for b1, b2 in zip(l1, l2):
                temp.append((b1 ^ b2) & 0b01111111)

            return temp

        odd = True
        tempstring = [0, 0, 0, 0, 0, 0, 0, 0]
        s = _zeropad(string + salt, cls.padsize)

        for block in [s[i : i + 8] for i in range(0, len(s), 8)]:
            temp56 = list()
            # removeMSBits
            for byte in block:
                temp56.append(orb(byte) & 0b01111111)

            # reverse: every other block has its 56 bits reversed
            if odd is False:
                bintemp = b""
                for byte in temp56:
                    bintemp += bin(byte)[2:].rjust(7, "0").encode()
                bintemp = bintemp[::-1]

                temp56 = list()
                for bits7 in [bintemp[i : i + 7] for i in range(0, len(bintemp), 7)]:
                    temp56.append(int(bits7, 2))

            odd = not odd
            tempstring = XOR(tempstring, temp56)

        tempkey = bytearray(b"".join(chb(byte) for byte in addparity(tempstring)))
        # key_correction: weak keys are fixed by XORing the last byte with 0xF0
        if bytes(tempkey) in WEAK_DES_KEYS:
            tempkey[7] = tempkey[7] ^ 0xF0

        # DES-CBC checksum of the padded input, keyed with tempkey and
        # using tempkey as the IV.
        tempkeyb = bytes(tempkey)
        des = Cipher(DES(tempkeyb), modes.CBC(tempkeyb)).encryptor()
        checksumkey = des.update(s)[-8:]
        checksumkey = bytearray(fixparity(checksumkey))
        if bytes(checksumkey) in WEAK_DES_KEYS:
            checksumkey[7] = checksumkey[7] ^ 0xF0

        return Key(cls.etype, key=bytes(checksumkey))

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # type: (bytes, bytes) -> bytes
        """DES-CBC encryption with a zero IV (input must be block-aligned)."""
        assert len(plaintext) % 8 == 0
        des = Cipher(DES(key), modes.CBC(b"\0" * 8)).encryptor()
        return des.update(bytes(plaintext))

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # type: (bytes, bytes) -> bytes
        """DES-CBC decryption with a zero IV (input must be block-aligned)."""
        assert len(ciphertext) % 8 == 0
        des = Cipher(DES(key), modes.CBC(b"\0" * 8)).decryptor()
        return des.update(bytes(ciphertext))

    @classmethod
    def string_to_key(cls, string, salt, params):
        # type: (bytes, bytes, Optional[bytes]) -> Key
        """
        :raises ValueError: if non-empty string-to-key ``params`` are given
        """
        if params is not None and params != b"":
            raise ValueError("Invalid DES string-to-key parameters")
        key = cls.mit_des_string_to_key(string, salt)
        return key
806
807
808# RFC3961 sect 6.2.1
809
810
class _DESMD5(_DESCBC):
    """des-cbc-md5 (RFC 3961 sect 6.2.1): DES-CBC with an MD5 integrity hash."""

    reqcksum = ChecksumType.RSA_MD5_DES
    etype = EncryptionType.DES_CBC_MD5
    hashmod = Hash_MD5
815
816
817# RFC3961 sect 6.2.2
818
819
class _DESMD4(_DESCBC):
    """des-cbc-md4 (RFC 3961 sect 6.2.2): DES-CBC with an MD4 integrity hash."""

    reqcksum = ChecksumType.RSA_MD4_DES
    etype = EncryptionType.DES_CBC_MD4
    hashmod = Hash_MD4
824
825
826# RFC3961 sect 6.3
827
828
class _DES3CBC(_SimplifiedEncryptionProfile):
    """
    des3-cbc-sha1-kd (RFC 3961 sect 6.3): triple DES in CBC mode with a zero
    IV, using the simplified profile with HMAC-SHA1 integrity.
    """

    etype = EncryptionType.DES3_CBC_SHA1_KD
    keysize = 24
    seedsize = 21
    blocksize = 8
    padsize = 8
    macsize = 20
    hashmod = Hash_SHA
    reqcksum = ChecksumType.HMAC_SHA1_DES3_KD

    @classmethod
    def random_to_key(cls, seed):
        # type: (bytes) -> Key
        """
        Expand a 21-byte seed into a 24-byte DES3 key.

        Each 7-byte chunk becomes an 8-byte DES key: the low bits of the seven
        bytes are gathered into an eighth byte, every byte gets odd parity,
        and weak keys are corrected.

        :raises ValueError: if the seed is not 21 bytes long
        """
        # XXX Maybe reframe as _DESEncryptionType.random_to_key and use that
        # way from DES3 random-to-key when DES is implemented, since
        # MIT does this instead of the RFC 3961 random-to-key.
        def odd_parity(value):
            # type: (int) -> int
            # Clear the low-order bit, then set it if needed for odd parity.
            high = value & 0xFE
            return high if bin(high).count("1") % 2 else high | 1

        def expand(chunk):
            # type: (bytes) -> bytes
            assert len(chunk) == 7
            lowbits = 0
            for pos, octet in enumerate(chunk):
                lowbits |= (octet & 1) << (pos + 1)
            keybytes = bytearray(odd_parity(octet) for octet in chunk)
            keybytes.append(odd_parity(lowbits))
            if bytes(keybytes) in WEAK_DES_KEYS:
                keybytes[7] ^= 0xF0
            return bytes(keybytes)

        if len(seed) != 21:
            raise ValueError("Wrong seed length")
        return Key(
            cls.etype,
            key=b"".join(expand(seed[start : start + 7]) for start in (0, 7, 14)),
        )

    @classmethod
    def string_to_key(cls, string, salt, params):
        # type: (bytes, bytes, Optional[bytes]) -> Key
        """
        DK(random-to-key(168-fold(string | salt)), "kerberos").

        :raises ValueError: if non-empty string-to-key ``params`` are given
        """
        if params not in (None, b""):
            raise ValueError("Invalid DES3 string-to-key parameters")
        seedkey = cls.random_to_key(_n_fold(string + salt, 21))
        return Key(cls.etype, key=cls.derive(seedkey, b"kerberos"))

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # type: (bytes, bytes) -> bytes
        """3DES-CBC encryption with a zero IV (input must be block-aligned)."""
        assert len(plaintext) % 8 == 0
        cipher = Cipher(decrepit_algorithms.TripleDES(key), modes.CBC(bytes(8)))
        return cipher.encryptor().update(bytes(plaintext))

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # type: (bytes, bytes) -> bytes
        """3DES-CBC decryption with a zero IV (input must be block-aligned)."""
        assert len(ciphertext) % 8 == 0
        cipher = Cipher(decrepit_algorithms.TripleDES(key), modes.CBC(bytes(8)))
        return cipher.decryptor().update(bytes(ciphertext))
894
895
class _SHA1DES3(_SimplifiedChecksum):
    """hmac-sha1-des3-kd checksum (RFC 3961 sect 6.3), keyed via _DES3CBC."""

    enc = _DES3CBC
    macsize = 20
899
900
901############
902# RFC 3962 #
903############
904
905
906# RFC3962 sect 6
907
908
class _AESEncryptionType_SHA1_96(_SimplifiedEncryptionProfile, abc.ABCMeta):
    """
    Shared RFC 3962 profile for the aes*-cts-hmac-sha1-96 enctypes.

    Provides string-to-key (PBKDF2-HMAC-SHA1 + derive) and AES in CBC-CS3
    (ciphertext stealing) mode. Subclasses such as _AES128CTS_SHA1_96 and
    _AES256CTS_SHA1_96 supply etype, keysize, seedsize and reqcksum.
    """

    blocksize = 16  # AES block size, in bytes
    # NOTE(review): presumably 1 means "no padding needed" since CTS handles
    # a partial final block -- confirm against _SimplifiedEncryptionProfile.
    padsize = 1
    macsize = 12  # HMAC-SHA1 truncated to 96 bits (12 bytes)
    hashmod = Hash_SHA

    @classmethod
    def string_to_key(cls, string, salt, params):
        # type: (bytes, bytes, Optional[bytes]) -> Key
        """
        RFC 3962 string-to-key.

        tkey = random-to-key(PBKDF2-HMAC-SHA1(string, salt, iterations)),
        then the final key is derive(tkey, "kerberos").

        :param string: the password bytes
        :param salt: the salt bytes
        :param params: None/empty for the default of 4096 iterations, or the
            iteration count as a 4-byte big-endian unsigned integer
        :return: a Key of cls.etype
        """
        # b"\x00\x00\x10\x00" is 0x1000 = 4096 iterations.
        iterations = struct.unpack(">L", params or b"\x00\x00\x10\x00")[0]
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA1(),
            length=cls.seedsize,
            salt=salt,
            iterations=iterations,
        )
        tkey = cls.random_to_key(kdf.derive(string))
        return Key(
            cls.etype,
            key=cls.derive(tkey, b"kerberos"),
        )

    # basic_encrypt and basic_decrypt implement AES in CBC-CS3 mode: CBC with
    # an all-zero IV where, for multi-block messages, the last two ciphertext
    # blocks are always swapped and the final one truncated.

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # type: (bytes, bytes) -> bytes
        """
        AES-CBC-CS3 encryption with a zero IV.

        :param key: the raw AES key bytes
        :param plaintext: at least one 16-byte block; its length need not be
            a multiple of 16
        :return: the ciphertext
        """
        assert len(plaintext) >= 16
        # Plain CBC over the plaintext padded by _zeropad (presumably up to
        # the next 16-byte boundary); CS3 is then applied by slicing.
        aes = Cipher(algorithms.AES(key), modes.CBC(b"\0" * 16)).encryptor()
        ctext = aes.update(_zeropad(bytes(plaintext), 16))
        if len(plaintext) > 16:
            # Swap the last two ciphertext blocks and truncate the
            # final block to match the plaintext length.
            lastlen = len(plaintext) % 16 or 16
            ctext = ctext[:-32] + ctext[-16:] + ctext[-32:-16][:lastlen]
        return ctext

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # type: (bytes, bytes) -> bytes
        """
        AES-CBC-CS3 decryption with a zero IV (inverse of basic_encrypt).

        :param key: the raw AES key bytes
        :param ciphertext: at least one 16-byte block; the last block may be
            partial
        :return: the plaintext, as long as the ciphertext
        """
        assert len(ciphertext) >= 16
        # ECB is used and the CBC chaining is done by hand below, because the
        # last two blocks must be processed out of order.
        aes = Cipher(algorithms.AES(key), modes.ECB()).decryptor()
        if len(ciphertext) == 16:
            # Single block: CBC with a zero IV is just one AES decryption.
            return aes.update(ciphertext)
        # Split the ciphertext into blocks.  The last block may be partial.
        cblocks = [
            bytearray(ciphertext[p : p + 16]) for p in range(0, len(ciphertext), 16)
        ]
        lastlen = len(cblocks[-1])
        # CBC-decrypt all but the last two blocks.
        prev_cblock = bytearray(16)
        plaintext = b""
        for bb in cblocks[:-2]:
            plaintext += _xorbytes(bytearray(aes.update(bytes(bb))), prev_cblock)
            prev_cblock = bb
        # Decrypt the second-to-last cipher block.  The left side of
        # the decrypted block will be the final block of plaintext
        # xor'd with the final partial cipher block; the right side
        # will be the omitted bytes of ciphertext from the final
        # block.
        bb = bytearray(aes.update(bytes(cblocks[-2])))
        lastplaintext = _xorbytes(bb[:lastlen], cblocks[-1])
        omitted = bb[lastlen:]
        # Decrypt the final cipher block plus the omitted bytes to get
        # the second-to-last plaintext block.
        plaintext += _xorbytes(
            bytearray(aes.update(bytes(cblocks[-1]) + bytes(omitted))), prev_cblock
        )
        return plaintext + lastplaintext
978
979
980# RFC3962 sect 7
981
982
class _AES128CTS_SHA1_96(_AESEncryptionType_SHA1_96):
    """aes128-cts-hmac-sha1-96 (RFC 3962): 16-byte keys and seeds."""

    reqcksum = ChecksumType.HMAC_SHA1_96_AES128
    etype = EncryptionType.AES128_CTS_HMAC_SHA1_96
    keysize = seedsize = 16
988
989
class _AES256CTS_SHA1_96(_AESEncryptionType_SHA1_96):
    """aes256-cts-hmac-sha1-96 (RFC 3962): 32-byte keys and seeds."""

    reqcksum = ChecksumType.HMAC_SHA1_96_AES256
    etype = EncryptionType.AES256_CTS_HMAC_SHA1_96
    keysize = seedsize = 32
995
996
class _SHA1_96_AES128(_SimplifiedChecksum):
    """hmac-sha1-96-aes128 checksum (RFC 3962), keyed with an AES128 key."""

    enc = _AES128CTS_SHA1_96
    macsize = 12  # 96 bits
1000
1001
class _SHA1_96_AES256(_SimplifiedChecksum):
    """hmac-sha1-96-aes256 checksum (RFC 3962), keyed with an AES256 key."""

    enc = _AES256CTS_SHA1_96
    macsize = 12  # 96 bits
1005
1006
1007############
1008# RFC 4757 #
1009############
1010
1011# RFC4757 sect 4
1012
1013
class _HMACMD5(_ChecksumProfile):
    """HMAC-MD5 checksum for RC4-HMAC keys (RFC 4757 sect 4)."""

    macsize = 16

    @classmethod
    def checksum(cls, key, keyusage, text):
        # type: (Key, int, bytes) -> bytes
        """
        Compute the checksum:
        Ksign = HMAC-MD5(key, "signaturekey" + NUL) and
        cksum = HMAC-MD5(Ksign, MD5(usage_str(keyusage) || text)).
        """
        signing_key = Hmac_MD5(key.key).digest(b"signaturekey\0")
        usage = _RC4.usage_str(keyusage)
        inner_digest = Hash_MD5().digest(usage + text)
        return Hmac_MD5(signing_key).digest(inner_digest)

    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # type: (Key, int, bytes, bytes) -> None
        """
        Verify a checksum. Only RC4-HMAC / RC4-HMAC-EXP keys are accepted.

        :raises ValueError: if the key is not an RC4 key
        """
        rc4_etypes = (EncryptionType.RC4_HMAC, EncryptionType.RC4_HMAC_EXP)
        if key.etype not in rc4_etypes:
            raise ValueError("Wrong key type for checksum")
        super().verify(key, keyusage, text, cksum)
1030
1031
1032# RFC4757 sect 5
1033
1034
class _RC4(_EncryptionAlgorithmProfile):
    """
    RC4-HMAC encryption type (RFC 4757 sect 5).

    When ``export`` is True (see _RC4_EXPORT), the export-variant key
    derivation ("fortybits" prefix, truncated K1) is used.
    """

    etype = EncryptionType.RC4_HMAC
    keysize = 16
    seedsize = 16
    reqcksum = ChecksumType.HMAC_MD5
    export = False

    @staticmethod
    def usage_str(keyusage):
        # type: (int) -> bytes
        """
        Return the 4-byte little-endian message type for an RFC 3961 key
        usage, using the RFC 4757 sect 3 mapping (3 -> 8, 23 -> 13).
        Per the errata, usage 9 is NOT mapped to 8.
        """
        table = {3: 8, 23: 13}
        msusage = table[keyusage] if keyusage in table else keyusage
        return struct.pack("<I", msusage)

    @classmethod
    def string_to_key(cls, string, salt, params):
        # type: (bytes, bytes, Optional[bytes]) -> Key
        """
        RC4-HMAC string-to-key: MD4 of the UTF-16LE encoded password.
        The salt is not used.

        :raises ValueError: if non-empty string-to-key parameters are given
        """
        if params is not None and params != b"":
            raise ValueError("Invalid RC4 string-to-key parameters")
        utf16string = plain_str(string).encode("UTF-16LE")
        return Key(cls.etype, key=Hash_MD4().digest(utf16string))

    @classmethod
    def _integrity_key(cls, key, msusage):
        # type: (Key, bytes) -> bytes
        """
        K1 = HMAC-MD5(key, msusage), with "fortybits" + NUL prepended to
        msusage in export mode.
        """
        if cls.export:
            return Hmac_MD5(key.key).digest(b"fortybits\x00" + msusage)
        return Hmac_MD5(key.key).digest(msusage)

    @classmethod
    def _encryption_key(cls, ki, cksum):
        # type: (bytes, bytes) -> bytes
        """K3 = HMAC-MD5(K1, cksum), K1 being weakened first in export mode."""
        if cls.export:
            # Export variant: keep the first 7 bytes, set the other 9 to 0xAB.
            ki = ki[:7] + b"\xab" * 9
        return Hmac_MD5(ki).digest(cksum)

    @classmethod
    def _decrypt_with_usage(cls, key, msusage, cksum, basic_ctext):
        # type: (Key, bytes, bytes, bytes) -> Optional[bytes]
        """
        Decrypt ``basic_ctext`` with keys derived for ``msusage``.

        :return: confounder || plaintext if the checksum matches, else None
        """
        ki = cls._integrity_key(key, msusage)
        ke = cls._encryption_key(ki, cksum)
        rc4 = Cipher(decrepit_algorithms.ARC4(ke), mode=None).decryptor()
        basic_plaintext = rc4.update(bytes(basic_ctext))
        if _mac_equal(cksum, Hmac_MD5(ki).digest(basic_plaintext)):
            return basic_plaintext
        return None

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder):
        # type: (Key, int, bytes, Optional[bytes]) -> bytes
        """
        Encrypt ``plaintext``.

        :param key: the RC4 Key
        :param keyusage: the RFC 3961 key usage
        :param plaintext: the data to encrypt
        :param confounder: 8 bytes; random when None
        :return: cksum (16 bytes) || RC4(K3, confounder || plaintext), where
            cksum = HMAC-MD5(K1, confounder || plaintext)
        """
        if confounder is None:
            confounder = os.urandom(8)
        ki = cls._integrity_key(key, cls.usage_str(keyusage))
        cksum = Hmac_MD5(ki).digest(confounder + plaintext)
        ke = cls._encryption_key(ki, cksum)
        # Same ARC4 provider as decrypt(): cryptography moved ARC4 to its
        # "decrepit" module and deprecated the algorithms.ARC4 location.
        rc4 = Cipher(decrepit_algorithms.ARC4(ke), mode=None).encryptor()
        return cksum + rc4.update(bytes(confounder + plaintext))

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext):
        # type: (Key, int, bytes) -> bytes
        """
        Decrypt and verify ``ciphertext`` (cksum || RC4 payload).

        :return: the plaintext, with the 8-byte confounder removed
        :raises ValueError: if the ciphertext is shorter than 24 bytes
        :raises InvalidChecksum: on integrity failure
        """
        # 16 bytes of checksum + at least the 8-byte confounder.
        if len(ciphertext) < 24:
            raise ValueError("ciphertext too short")
        cksum, basic_ctext = ciphertext[:16], ciphertext[16:]
        basic_plaintext = cls._decrypt_with_usage(
            key, cls.usage_str(keyusage), cksum, basic_ctext
        )
        if basic_plaintext is None and keyusage == 9:
            # Try again with usage 8, due to RFC 4757 errata. Both the
            # integrity key and the RC4 key depend on the usage, so the data
            # must be decrypted again, not only re-checksummed.
            basic_plaintext = cls._decrypt_with_usage(
                key, struct.pack("<I", 8), cksum, basic_ctext
            )
        if basic_plaintext is None:
            raise InvalidChecksum("ciphertext integrity failure")
        # Discard the confounder.
        return bytes(basic_plaintext[8:])

    @classmethod
    def prf(cls, key, string):
        # type: (Key, bytes) -> bytes
        """PRF for RC4-HMAC keys: HMAC-SHA1(key, string)."""
        return Hmac_SHA(key.key).digest(string)
1108
1109
class _RC4_EXPORT(_RC4):
    """RC4-HMAC-EXP: _RC4 with the export-variant key derivation enabled."""

    export = True
    etype = EncryptionType.RC4_HMAC_EXP
1113
1114
1115############
1116# RFC 8009 #
1117############
1118
1119
class _AESEncryptionType_SHA256_SHA384(_AESEncryptionType_SHA1_96, abc.ABCMeta):
    """
    Shared RFC 8009 profile for the AES-CTS enctypes with HMAC-SHA2.

    Subclasses set ``enctypename`` (prefixed to the PBKDF2 salt), ``hashmod``
    (the Scapy hash class) and ``_hashmod`` (the cryptography hash class).
    """

    # Turn on RFC 8009 mode
    rfc8009 = True

    enctypename = None  # type: bytes
    hashmod: _GenericHash = None  # Scapy
    _hashmod: hashes.HashAlgorithm = None  # Cryptography

    @classmethod
    def derive(cls, key, label, k, context=b""):  # type: ignore
        # type: (Key, bytes, int, bytes) -> bytes
        """
        Also known as "KDF-HMAC-SHA2" in RFC8009 (sect 3): an SP800-108
        counter-mode KDF.

        :param key: the base Key
        :param label: the KDF label
        :param k: the output length, in bits
        :param context: optional KDF context
        :return: the derived key bytes
        """
        kdf_kwargs = dict(
            K_I=key.key,
            Label=label,
            Context=context,
            L=k,
            hashmod=cls.hashmod,
        )
        return SP800108_KDFCTR(**kdf_kwargs)

    @classmethod
    def string_to_key(cls, string, salt, params):
        # type: (bytes, bytes, Optional[bytes]) -> Key
        """
        RFC 8009 sect 4 string-to-key.

        :param params: None/empty for the default of 32768 iterations, or the
            iteration count as a 4-byte big-endian unsigned integer
        """
        prefixed_salt = cls.enctypename + b"\x00" + salt
        (iterations,) = struct.unpack(">L", params or b"\x00\x00\x80\x00")
        kdf = PBKDF2HMAC(
            algorithm=cls._hashmod(),
            length=cls.seedsize,
            salt=prefixed_salt,
            iterations=iterations,
        )
        tkey = cls.random_to_key(kdf.derive(string))
        final_key = cls.derive(tkey, b"kerberos", cls.keysize * 8)
        return Key(cls.etype, key=final_key)

    @classmethod
    def prf(cls, key, string):
        # type: (Key, bytes) -> bytes
        """PRF: KDF-HMAC-SHA2(key, "prf", string), one hash length long."""
        out_bits = cls.hashmod.hash_len * 8
        return cls.derive(key, b"prf", out_bits, string)
1165
1166
class _AES128CTS_SHA256_128(_AESEncryptionType_SHA256_SHA384):
    """aes128-cts-hmac-sha256-128 (RFC 8009)."""

    etype = EncryptionType.AES128_CTS_HMAC_SHA256_128
    reqcksum = ChecksumType.HMAC_SHA256_128_AES128
    keysize = seedsize = 16
    macsize = 16  # 128-bit MAC
    # _AESEncryptionType_SHA256_SHA384 parameters
    enctypename = b"aes128-cts-hmac-sha256-128"
    hashmod, _hashmod = Hash_SHA256, hashes.SHA256
1177
1178
class _AES256CTS_SHA384_192(_AESEncryptionType_SHA256_SHA384):
    """aes256-cts-hmac-sha384-192 (RFC 8009)."""

    etype = EncryptionType.AES256_CTS_HMAC_SHA384_192
    reqcksum = ChecksumType.HMAC_SHA384_192_AES256
    keysize = seedsize = 32
    macsize = 24  # 192-bit MAC
    # _AESEncryptionType_SHA256_SHA384 parameters
    enctypename = b"aes256-cts-hmac-sha384-192"
    hashmod, _hashmod = Hash_SHA384, hashes.SHA384
1189
1190
class _SHA256_128_AES128(_SimplifiedChecksum):
    """hmac-sha256-128-aes128 checksum (RFC 8009), keyed with an AES128 key."""

    rfc8009 = True
    enc = _AES128CTS_SHA256_128
    macsize = 16  # 128 bits
1195
1196
class _SHA384_182_AES256(_SimplifiedChecksum):
    """
    hmac-sha384-192-aes256 checksum (RFC 8009), keyed with an AES256 key.

    Despite the "182" in the class name, this produces 24-byte (192-bit)
    MACs; the name is left unchanged so existing references keep working.
    """

    rfc8009 = True
    enc = _AES256CTS_SHA384_192
    macsize = 24  # 192 bits
1201
1202
1203##############
1204# Key object #
1205##############
1206
# Registry of implemented encryption types: EncryptionType -> profile class.
# Key() and Key.random_to_key/string_to_key look etypes up here; the entries
# commented out below are not implemented.
_enctypes = {
    # DES_CBC_CRC - UNIMPLEMENTED
    EncryptionType.DES_CBC_MD5: _DESMD5,
    EncryptionType.DES_CBC_MD4: _DESMD4,
    # DES3_CBC_SHA1 - UNIMPLEMENTED
    EncryptionType.DES3_CBC_SHA1_KD: _DES3CBC,
    EncryptionType.AES128_CTS_HMAC_SHA1_96: _AES128CTS_SHA1_96,
    EncryptionType.AES256_CTS_HMAC_SHA1_96: _AES256CTS_SHA1_96,
    EncryptionType.AES128_CTS_HMAC_SHA256_128: _AES128CTS_SHA256_128,
    EncryptionType.AES256_CTS_HMAC_SHA384_192: _AES256CTS_SHA384_192,
    # CAMELLIA128-CTS-CMAC - UNIMPLEMENTED
    # CAMELLIA256-CTS-CMAC - UNIMPLEMENTED
    EncryptionType.RC4_HMAC: _RC4,
    EncryptionType.RC4_HMAC_EXP: _RC4_EXPORT,
}
1222
1223
# Registry of implemented checksum types: ChecksumType (or raw integer value)
# -> checksum profile class. Key() looks its cksumtype up here, including the
# etype's required checksum when no cksumtype is given explicitly.
_checksums = {
    ChecksumType.CRC32: _CRC32,
    # RSA_MD4 - UNIMPLEMENTED
    # RSA_MD4_DES - UNIMPLEMENTED
    # RSA_MD5 - UNIMPLEMENTED
    # RSA_MD5_DES - UNIMPLEMENTED
    # SHA1 - UNIMPLEMENTED
    ChecksumType.HMAC_SHA1_DES3_KD: _SHA1DES3,
    # HMAC_SHA1_DES3 - UNIMPLEMENTED
    ChecksumType.HMAC_SHA1_96_AES128: _SHA1_96_AES128,
    ChecksumType.HMAC_SHA1_96_AES256: _SHA1_96_AES256,
    # CMAC-CAMELLIA128 - UNIMPLEMENTED
    # CMAC-CAMELLIA256 - UNIMPLEMENTED
    ChecksumType.HMAC_SHA256_128_AES128: _SHA256_128_AES128,
    ChecksumType.HMAC_SHA384_192_AES256: _SHA384_182_AES256,
    ChecksumType.HMAC_MD5: _HMACMD5,
    # 0xFFFFFF76 is -138 as a signed 32-bit integer, also mapped to HMAC-MD5.
    # NOTE(review): presumably the Microsoft-specific HMAC-MD5 checksum
    # identifier -- confirm.
    0xFFFFFF76: _HMACMD5,
}
1242
1243
class Key(object):
    """
    A Kerberos key: raw key bytes bound to an encryption type and/or a
    checksum type, exposing the RFC 3961 operations for that type.
    """

    def __init__(
        self,
        etype: Union[EncryptionType, int, None] = None,
        key: bytes = b"",
        cksumtype: Union[ChecksumType, int, None] = None,
    ) -> None:
        """
        Kerberos Key object.

        :param etype: the EncryptionType
        :param key: the bytes containing the key bytes for this Key.
        :param cksumtype: the ChecksumType. When omitted, the etype's required
            checksum type is used if it is implemented.
        :raises ValueError: if the etype/cksumtype is not implemented, or if
            the key length does not match the etype's key size.
        """
        assert etype or cksumtype, "Provide an etype or a cksumtype !"
        assert key, "Provide a key !"
        if isinstance(etype, int):
            etype = EncryptionType(etype)
        if isinstance(cksumtype, int):
            cksumtype = ChecksumType(cksumtype)
        self.etype = etype
        if etype is not None:
            # Dict lookups raise KeyError; report it as the documented
            # ValueError instead.
            try:
                self.ep = _enctypes[etype]
            except KeyError:
                raise ValueError(
                    "UNKNOWN/UNIMPLEMENTED etype '%s'" % etype
                ) from None
            if len(key) != self.ep.keysize:
                raise ValueError(
                    "Wrong key length. Got %s. Expected %s"
                    % (len(key), self.ep.keysize)
                )
            if cksumtype is None and self.ep.reqcksum in _checksums:
                cksumtype = self.ep.reqcksum
        self.cksumtype = cksumtype
        if cksumtype is not None:
            try:
                self.cp = _checksums[cksumtype]
            except KeyError:
                raise ValueError(
                    "UNKNOWN/UNIMPLEMENTED cksumtype '%s'" % cksumtype
                ) from None
            # A checksum-only key inherits the etype of its checksum's cipher.
            if self.etype is None and issubclass(self.cp, _SimplifiedChecksum):
                self.etype = self.cp.enc.etype  # type: ignore
        self.key = key

    def __repr__(self):
        # type: () -> str
        if self.etype:
            name = self.etype.name
        elif self.cksumtype:
            name = self.cksumtype.name
        else:
            return "<Key UNKNOWN>"
        return "<Key %s%s>" % (
            name,
            " (%s octets)" % len(self.key),
        )

    def encrypt(self, keyusage, plaintext, confounder=None, **kwargs):
        # type: (int, bytes, Optional[bytes], **Any) -> bytes
        """
        Encrypt data using the current Key.

        :param keyusage: the key usage
        :param plaintext: the plain text to encrypt
        :param confounder: (optional) choose the confounder. Otherwise random.
        """
        return self.ep.encrypt(self, keyusage, bytes(plaintext), confounder, **kwargs)

    def decrypt(self, keyusage, ciphertext, **kwargs):
        # type: (int, bytes, **Any) -> bytes
        """
        Decrypt data using the current Key.

        :param keyusage: the key usage
        :param ciphertext: the encrypted text to decrypt
        """
        # Throw InvalidChecksum on checksum failure.  Throw ValueError on
        # invalid key enctype or malformed ciphertext.
        return self.ep.decrypt(self, keyusage, ciphertext, **kwargs)

    def prf(self, string):
        # type: (bytes) -> bytes
        """Apply the etype's pseudo-random function to ``string``."""
        return self.ep.prf(self, string)

    def make_checksum(self, keyusage, text, cksumtype=None, **kwargs):
        # type: (int, bytes, Optional[int], **Any) -> bytes
        """
        Create a checksum using the current Key.

        :param keyusage: the key usage
        :param text: the text to create a checksum from
        :param cksumtype: (optional) override the checksum type
        """
        if cksumtype is not None and cksumtype != self.cksumtype:
            # Clone key and use a different cksumtype
            return Key(
                cksumtype=cksumtype,
                key=self.key,
            ).make_checksum(keyusage=keyusage, text=text, **kwargs)
        if self.cksumtype is None:
            raise ValueError("cksumtype not specified !")
        return self.cp.checksum(self, keyusage, text, **kwargs)

    def verify_checksum(self, keyusage, text, cksum, cksumtype=None):
        # type: (int, bytes, bytes, Optional[int]) -> None
        """
        Verify a checksum using the current Key.

        :param keyusage: the key usage
        :param text: the text to verify
        :param cksum: the expected checksum
        :param cksumtype: (optional) override the checksum type
        """
        if cksumtype is not None and cksumtype != self.cksumtype:
            # Clone key and use a different cksumtype
            return Key(
                cksumtype=cksumtype,
                key=self.key,
            ).verify_checksum(keyusage=keyusage, text=text, cksum=cksum)
        # Throw InvalidChecksum exception on checksum failure.  Throw
        # ValueError on invalid cksumtype, invalid key enctype, or
        # malformed checksum.
        if self.cksumtype is None:
            raise ValueError("cksumtype not specified !")
        self.cp.verify(self, keyusage, text, cksum)

    @classmethod
    def random_to_key(cls, etype, seed):
        # type: (EncryptionType, bytes) -> Key
        """
        random-to-key per RFC3961

        This is used to create a random Key from a seed.

        :raises ValueError: on unknown etype or wrong seed length
        """
        try:
            ep = _enctypes[etype]
        except KeyError:
            raise ValueError("Unknown etype '%s'" % etype) from None
        if len(seed) != ep.seedsize:
            raise ValueError("Wrong crypto seed length")
        return ep.random_to_key(seed)

    @classmethod
    def string_to_key(cls, etype, string, salt, params=None):
        # type: (EncryptionType, bytes, bytes, Optional[bytes]) -> Key
        """
        string-to-key per RFC3961

        This is typically used to create a Key object from a password + salt

        :raises ValueError: on unknown etype
        """
        try:
            ep = _enctypes[etype]
        except KeyError:
            raise ValueError("Unknown etype '%s'" % etype) from None
        return ep.string_to_key(string, salt, params)
1399
1400############
1401# RFC 6113 #
1402############
1403
1404
def KRB_FX_CF2(key1, key2, pepper1, pepper2):
    # type: (Key, Key, bytes, bytes) -> Key
    """
    KRB-FX-CF2 RFC6113

    Combine two keys into a new key of key1's encryption type:
    random-to-key(PRF+(key1, pepper1) XOR PRF+(key2, pepper2)).

    :param key1: the first key; its etype is the etype of the result
    :param key2: the second key
    :param pepper1: the pepper string used with key1
    :param pepper2: the pepper string used with key2
    :return: the combined Key
    """
    # RFC 6113 sect 5.1: PRF+ produces as many bytes as random-to-key needs
    # for the RESULTING enctype (key1's), for both input keys.
    seedsize = key1.ep.seedsize

    def prfplus(key, pepper):
        # type: (Key, bytes) -> bytes
        # RFC 6113 PRF+: concatenate key.prf(counter || pepper) for
        # counter = 1, 2, ... (one octet each) until `seedsize` bytes are
        # produced, then drop the unneeded bytes from the tail.
        out = b""
        count = 1
        while len(out) < seedsize:
            out += key.prf(chb(count) + pepper)
            count += 1
        return out[:seedsize]

    seed = bytes(
        _xorbytes(
            bytearray(prfplus(key1, pepper1)), bytearray(prfplus(key2, pepper2))
        )
    )
    # The XOR result is random-to-key INPUT, not key material: e.g. for DES3
    # the 21-byte seed expands to a 24-byte key.
    return Key.random_to_key(key1.etype, seed)
1429