• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) 2014 6WIND
5
6r"""
7IPsec layer
8===========
9
10Example of use:
11
12>>> sa = SecurityAssociation(ESP, spi=0xdeadbeef, crypt_algo='AES-CBC',
13...                          crypt_key=b'sixteenbytes key')
14>>> p = IP(src='1.1.1.1', dst='2.2.2.2')
15>>> p /= TCP(sport=45012, dport=80)
16>>> p /= Raw('testdata')
17>>> p = IP(raw(p))
18>>> p
19<IP  version=4L ihl=5L tos=0x0 len=48 id=1 flags= frag=0L ttl=64 proto=tcp chksum=0x74c2 src=1.1.1.1 dst=2.2.2.2 options=[] |<TCP  sport=45012 dport=http seq=0 ack=0 dataofs=5L reserved=0L flags=S window=8192 chksum=0x1914 urgptr=0 options=[] |<Raw  load='testdata' |>>>  # noqa: E501
20>>>
21>>> e = sa.encrypt(p)
22>>> e
23<IP  version=4L ihl=5L tos=0x0 len=76 id=1 flags= frag=0L ttl=64 proto=esp chksum=0x747a src=1.1.1.1 dst=2.2.2.2 |<ESP  spi=0xdeadbeef seq=1 data=b'\xf8\xdb\x1e\x83[T\xab\\\xd2\x1b\xed\xd1\xe5\xc8Y\xc2\xa5d\x92\xc1\x05\x17\xa6\x92\x831\xe6\xc1]\x9a\xd6K}W\x8bFfd\xa5B*+\xde\xc8\x89\xbf{\xa9' |>>  # noqa: E501
24>>>
25>>> d = sa.decrypt(e)
26>>> d
27<IP  version=4L ihl=5L tos=0x0 len=48 id=1 flags= frag=0L ttl=64 proto=tcp chksum=0x74c2 src=1.1.1.1 dst=2.2.2.2 |<TCP  sport=45012 dport=http seq=0 ack=0 dataofs=5L reserved=0L flags=S window=8192 chksum=0x1914 urgptr=0 options=[] |<Raw  load='testdata' |>>>  # noqa: E501
28>>>
29>>> d == p
30True
31"""
32
33try:
34    from math import gcd
35except ImportError:
36    from fractions import gcd
37import os
38import socket
39import struct
40import warnings
41
42from scapy.config import conf, crypto_validator
43from scapy.compat import orb, raw
44from scapy.data import IP_PROTOS
45from scapy.error import log_loading
46from scapy.fields import (
47    ByteEnumField,
48    ByteField,
49    IntField,
50    PacketField,
51    ShortField,
52    StrField,
53    XByteField,
54    XIntField,
55    XStrField,
56    XStrLenField,
57)
58from scapy.packet import (
59    Packet,
60    Raw,
61    bind_bottom_up,
62    bind_layers,
63    bind_top_down,
64)
65from scapy.layers.inet import IP, UDP
66from scapy.layers.inet6 import IPv6, IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt, \
67    IPv6ExtHdrRouting
68
69
70###############################################################################
class AH(Packet):
    """
    Authentication Header

    See https://tools.ietf.org/rfc/rfc4302.txt
    """

    name = 'AH'

    def __get_icv_len(self):
        """
        Compute the size of the ICV from the payloadlen field.

        The returned length includes any ICV padding: the padding size can
        only be known from the authentication algorithm of the Security
        Association, which is not available while dissecting.

        :returns: the length in bytes of ICV + padding, never negative
        """
        # payloadlen = length of AH in 32-bit words (4-byte units), minus "2"
        # payloadlen = 3 32-bit word fixed fields + ICV + padding - 2
        # ICV + padding = (payloadlen + 2 - 3) 32-bit words
        # A malformed payloadlen of 0 would yield -4, i.e. a negative slice
        # length; clamp it so the ICV is simply empty instead.
        return max(0, (self.payloadlen - 1) * 4)

    fields_desc = [
        ByteEnumField('nh', None, IP_PROTOS),
        ByteField('payloadlen', None),
        ShortField('reserved', None),
        XIntField('spi', 0x00000001),
        IntField('seq', 0),
        XStrLenField('icv', None, length_from=__get_icv_len),
        # Padding len can only be known with the SecurityAssociation.auth_algo
        XStrLenField('padding', None, length_from=lambda x: 0),
    ]

    overload_fields = {
        IP: {'proto': socket.IPPROTO_AH},
        IPv6: {'nh': socket.IPPROTO_AH},
        IPv6ExtHdrHopByHop: {'nh': socket.IPPROTO_AH},
        IPv6ExtHdrDestOpt: {'nh': socket.IPPROTO_AH},
        IPv6ExtHdrRouting: {'nh': socket.IPPROTO_AH},
    }
109
110
# Stack AH directly on top of IPv4 (protocol field) and IPv6 (next header).
bind_layers(IP, AH, proto=socket.IPPROTO_AH)
bind_layers(IPv6, AH, nh=socket.IPPROTO_AH)
# Inner IP headers carried after AH, selected by the AH.nh field.
# NOTE(review): socket.IPPROTO_IP is 0, whereas IPv4-in-IP is usually
# protocol 4 (socket.IPPROTO_IPIP); confirm which value is intended here.
bind_layers(AH, IP, nh=socket.IPPROTO_IP)
bind_layers(AH, IPv6, nh=socket.IPPROTO_IPV6)
115
116###############################################################################
117
118
class ESP(Packet):
    """
    Encapsulated Security Payload

    See https://tools.ietf.org/rfc/rfc4303.txt
    """
    name = 'ESP'

    fields_desc = [
        XIntField('spi', 0x00000001),
        IntField('seq', 0),
        # Everything following the header (IV, ciphertext, ICV) is kept as an
        # opaque byte string; CryptAlgo.decrypt splits it apart.
        XStrField('data', None),
    ]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        """
        Choose the class used to dissect ``_pkt``.

        On the NAT-Traversal port, a payload that starts with four zero bytes
        is a NON_ESP marker and a lone 0xFF byte is a NAT_KEEPALIVE.
        Any other non-empty payload is ESP. Without data, ``cls`` is returned.
        """
        if not _pkt:
            return cls
        if len(_pkt) >= 4:
            (marker,) = struct.unpack("!I", _pkt[0:4])
            if marker == 0x00:
                return NON_ESP
        elif len(_pkt) == 1:
            (marker,) = struct.unpack("!B", _pkt)
            if marker == 0xff:
                return NAT_KEEPALIVE
        return ESP

    overload_fields = {
        IP: {'proto': socket.IPPROTO_ESP},
        IPv6: {'nh': socket.IPPROTO_ESP},
        IPv6ExtHdrHopByHop: {'nh': socket.IPPROTO_ESP},
        IPv6ExtHdrDestOpt: {'nh': socket.IPPROTO_ESP},
        IPv6ExtHdrRouting: {'nh': socket.IPPROTO_ESP},
    }
151
152
class NON_ESP(Packet):  # RFC 3948, section 2.2
    """
    Non-ESP marker: four zero bytes that distinguish non-ESP traffic from
    ESP on the NAT-Traversal UDP port (selected by ESP.dispatch_hook).
    """
    fields_desc = [XIntField("non_esp", 0x0)]
157
158
class NAT_KEEPALIVE(Packet):  # RFC 3948, section 2.3
    """
    NAT-keepalive packet: a single 0xFF byte sent on the NAT-Traversal UDP
    port (selected by ESP.dispatch_hook).
    """
    fields_desc = [XByteField("nat_keepalive", 0xFF)]
163
164
# Stack ESP directly on top of IPv4 (protocol field) and IPv6 (next header).
bind_layers(IP, ESP, proto=socket.IPPROTO_ESP)
bind_layers(IPv6, ESP, nh=socket.IPPROTO_ESP)

# NAT-Traversal encapsulation
# Dissection: a UDP datagram to or from port 4500 is handed to ESP, whose
# dispatch_hook then picks ESP, NON_ESP or NAT_KEEPALIVE from the payload.
bind_bottom_up(UDP, ESP, dport=4500)
bind_bottom_up(UDP, ESP, sport=4500)
# Building: placing any of these layers over UDP defaults both ports to 4500.
bind_top_down(UDP, ESP, dport=4500, sport=4500)
bind_top_down(UDP, NON_ESP, dport=4500, sport=4500)
bind_top_down(UDP, NAT_KEEPALIVE, dport=4500, sport=4500)
174
175###############################################################################
176
177
class _ESPPlain(Packet):
    """
    Internal class to represent unencrypted ESP packets.

    It holds the cleartext view of an ESP packet: IV, inner data, padding,
    pad length, next header and ICV. This is the view before encryption or
    after decryption.
    """
    name = 'ESP'

    fields_desc = [
        # ESP header
        XIntField('spi', 0x0),
        IntField('seq', 0),

        # IV, then the protected data and the padding appended to it
        StrField('iv', ''),
        PacketField('data', '', Raw),
        StrField('padding', ''),

        # ESP trailer, followed by the integrity check value
        ByteField('padlen', 0),
        ByteEnumField('nh', 0, IP_PROTOS),
        StrField('icv', ''),
    ]

    def data_for_encryption(self):
        """
        Build the byte string handed to the cipher.

        :returns: serialized data, then padding, then one byte each for
                  padlen and nh
        """
        inner = raw(self.data)
        pad = self.padding
        trailer = struct.pack("BB", self.padlen, self.nh)
        return b"".join((inner, pad, trailer))
199
200
201###############################################################################
202if conf.crypto_valid:
203    from cryptography.exceptions import InvalidTag
204    from cryptography.hazmat.backends import default_backend
205    from cryptography.hazmat.primitives.ciphers import (
206        aead,
207        Cipher,
208        algorithms,
209        modes,
210    )
211    try:
212        # cryptography > 43.0
213        from cryptography.hazmat.decrepit.ciphers import (
214            algorithms as decrepit_algorithms
215        )
216    except ImportError:
217        decrepit_algorithms = algorithms
218else:
219    log_loading.info("Can't import python-cryptography v1.7+. "
220                     "Disabled IPsec encryption/authentication.")
221    default_backend = None
222    InvalidTag = Exception
223    Cipher = algorithms = modes = None
224
225###############################################################################
226
227
228def _lcm(a, b):
229    """
230    Least Common Multiple between 2 integers.
231    """
232    if a == 0 or b == 0:
233        return 0
234    else:
235        return abs(a * b) // gcd(a, b)
236
237
class CryptAlgo(object):
    """
    IPsec encryption algorithm
    """

    def __init__(self, name, cipher, mode, block_size=None, iv_size=None,
                 key_size=None, icv_size=None, salt_size=None, format_mode_iv=None):  # noqa: E501
        """
        :param name: the name of this encryption algorithm
        :param cipher: a Cipher module
        :param mode: the mode used with the cipher module
        :param block_size: the length of a block for this algo, in bytes.
                           Defaults to the `block_size` of the cipher (given
                           in bits) divided by 8, or 1 when there is no
                           cipher.
        :param iv_size: the length of the initialization vector of this algo.
                        Defaults to the block size.
        :param key_size: an integer or list/tuple of integers. If specified,
                         force the secret keys length to one of the values.
                         Defaults to the `key_sizes` of the cipher (in bytes).
        :param icv_size: the length of the Integrity Check Value of this algo.
                         Used by Combined Mode Algorithms e.g. GCM
        :param salt_size: the length of the salt to use as the IV prefix.
                          Usually used by Counter modes e.g. CTR
        :param format_mode_iv: function to format the Initialization Vector
                               e.g. handle the salt value. It is called with
                               keyword arguments (``algo``, ``sa``, ``iv``)
                               and must accept extra keywords. Defaults to
                               returning ``iv`` unchanged.
        """
        self.name = name
        self.cipher = cipher
        self.mode = mode
        self.icv_size = icv_size

        self.is_aead = False
        # True when using the cryptography.hazmat.primitives.ciphers.aead API
        self.ciphers_aead_api = False

        if modes:
            if self.mode is not None:
                self.is_aead = issubclass(self.mode,
                                          modes.ModeWithAuthenticationTag)
            elif self.cipher in (aead.AESGCM, aead.AESCCM,
                                 aead.ChaCha20Poly1305):
                self.is_aead = True
                self.ciphers_aead_api = True

        if block_size is not None:
            self.block_size = block_size
        elif cipher is not None:
            self.block_size = cipher.block_size // 8
        else:
            self.block_size = 1

        if iv_size is None:
            self.iv_size = self.block_size
        else:
            self.iv_size = iv_size

        if key_size is not None:
            self.key_size = key_size
        elif cipher is not None:
            self.key_size = tuple(i // 8 for i in cipher.key_sizes)
        else:
            self.key_size = None

        if salt_size is None:
            self.salt_size = 0
        else:
            self.salt_size = salt_size

        if format_mode_iv is None:
            self._format_mode_iv = lambda iv, **kw: iv
        else:
            self._format_mode_iv = format_mode_iv

    def check_key(self, key):
        """
        Check that the key length is valid.

        :param key:    a byte string
        :raise TypeError: if the key length is not one of the allowed sizes
        """
        if not self.key_size:
            return
        # key_size may be a single integer or a collection of integers; a
        # bare int must not reach the `in` test below (it is not iterable).
        if isinstance(self.key_size, int):
            allowed = (self.key_size,)
        else:
            allowed = self.key_size
        if len(key) not in allowed:
            raise TypeError('invalid key size %s, must be %s' %
                            (len(key), self.key_size))

    def generate_iv(self):
        """
        Generate a random initialization vector.
        """
        # XXX: Handle counter modes with real counters? RFCs allow the use of
        # XXX: random bytes for counters, so it is not wrong to do it that way
        return os.urandom(self.iv_size)

    @crypto_validator
    def new_cipher(self, key, mode_iv, digest=None):
        """
        :param key:     the secret key, a byte string
        :param mode_iv: the initialization vector or nonce, a byte string.
                        Formatted by `format_mode_iv`.
        :param digest:  also known as tag or icv. A byte string containing the
                        digest of the encrypted data. Only use this during
                        decryption!

        :returns:    an initialized cipher object for this algo
        """
        if self.is_aead and digest is not None:
            # With AEAD, the mode needs the digest during decryption.
            return Cipher(
                self.cipher(key),
                self.mode(mode_iv, digest, len(digest)),
                default_backend(),
            )
        else:
            return Cipher(
                self.cipher(key),
                self.mode(mode_iv),
                default_backend(),
            )

    def pad(self, esp):
        """
        Add the correct amount of padding so that the data to encrypt is
        exactly a multiple of the algorithm's block size.

        Also, make sure that the total ESP packet length is a multiple of 4
        bytes.

        :param esp:    an unencrypted _ESPPlain packet

        :returns:    an unencrypted _ESPPlain packet with valid padding
        :raise ValueError: if the padded payload is not 32-bit aligned
        """
        # 2 extra bytes for padlen and nh
        data_len = len(esp.data) + 2

        # according to the RFC4303, section 2.4. Padding (for Encryption)
        # the size of the ESP payload must be a multiple of 32 bits
        align = _lcm(self.block_size, 4)

        # pad for block size
        esp.padlen = -data_len % align

        # Still according to the RFC, the default value for padding *MUST* be
        # an array of bytes starting from 1 to padlen
        # TODO: Handle padding function according to the encryption algo
        esp.padding = struct.pack("B" * esp.padlen, *range(1, esp.padlen + 1))

        # If the following test fails, it means that this algo does not comply
        # with the RFC
        payload_len = len(esp.iv) + len(esp.data) + len(esp.padding) + 2
        if payload_len % 4 != 0:
            raise ValueError('The size of the ESP data is not aligned to 32 bits after padding.')  # noqa: E501

        return esp

    def encrypt(self, sa, esp, key, icv_size=None, esn_en=False, esn=0):
        """
        Encrypt an ESP packet

        :param sa:       the SecurityAssociation associated with the ESP packet.
        :param esp:      an unencrypted _ESPPlain packet with valid padding
        :param key:      the secret key used for encryption
        :param icv_size: the length of the icv used for integrity check.
                         Defaults to this algo's icv_size for AEAD algos and
                         to 0 otherwise.
        :param esn_en:   extended sequence number enable which allows to use
                         64-bit sequence number instead of 32-bit when using
                         an AEAD algorithm
        :param esn:      extended sequence number (32 MSB)
        :returns:        a valid ESP packet encrypted with this algorithm
        """
        if icv_size is None:
            icv_size = self.icv_size if self.is_aead else 0
        data = esp.data_for_encryption()

        if self.cipher:
            mode_iv = self._format_mode_iv(algo=self, sa=sa, iv=esp.iv)
            aad = None
            if self.is_aead:
                if esn_en:
                    aad = struct.pack('!LLL', esp.spi, esn, esp.seq)
                else:
                    aad = struct.pack('!LL', esp.spi, esp.seq)
            if self.ciphers_aead_api:
                # New API
                if self.cipher == aead.AESCCM:
                    cipher = self.cipher(key, tag_length=icv_size)
                else:
                    cipher = self.cipher(key)
                if self.name == 'AES-NULL-GMAC':
                    # Special case for GMAC (rfc 4543 sect 3): the payload is
                    # authenticated only, so it is part of the AAD and stays
                    # in clear text in front of the tag.
                    data = data + cipher.encrypt(mode_iv, b"", aad + esp.iv + data)  # noqa: E501
                else:
                    data = cipher.encrypt(mode_iv, data, aad)
            else:
                cipher = self.new_cipher(key, mode_iv)
                encryptor = cipher.encryptor()

                if self.is_aead:
                    encryptor.authenticate_additional_data(aad)
                    data = encryptor.update(data) + encryptor.finalize()
                    data += encryptor.tag[:icv_size]
                else:
                    data = encryptor.update(data) + encryptor.finalize()

        return ESP(spi=esp.spi, seq=esp.seq, data=esp.iv + data)

    def decrypt(self, sa, esp, key, icv_size=None, esn_en=False, esn=0):
        """
        Decrypt an ESP packet

        :param sa: the SecurityAssociation associated with the ESP packet.
        :param esp: an encrypted ESP packet
        :param key: the secret key used for encryption
        :param icv_size: the length of the icv used for integrity check.
                         Defaults to this algo's icv_size for AEAD algos and
                         to 0 otherwise.
        :param esn_en: extended sequence number enable which allows to use
                       64-bit sequence number instead of 32-bit when using an
                       AEAD algorithm
        :param esn: extended sequence number (32 MSB)
        :returns: an _ESPPlain packet holding the decrypted payload
        :raise scapy.layers.ipsec.IPSecIntegrityError: if the integrity check
            fails with an AEAD algorithm
        """
        if icv_size is None:
            icv_size = self.icv_size if self.is_aead else 0

        # ESP data layout: IV | encrypted payload | ICV
        iv = esp.data[:self.iv_size]
        data = esp.data[self.iv_size:len(esp.data) - icv_size]
        icv = esp.data[len(esp.data) - icv_size:]

        if self.cipher:
            # Pass the same keyword arguments as encrypt() so that a custom
            # format_mode_iv relying on `algo` works in both directions.
            mode_iv = self._format_mode_iv(algo=self, sa=sa, iv=iv)
            aad = None
            if self.is_aead:
                if esn_en:
                    aad = struct.pack('!LLL', esp.spi, esn, esp.seq)
                else:
                    aad = struct.pack('!LL', esp.spi, esp.seq)
            if self.ciphers_aead_api:
                # New API
                if self.cipher == aead.AESCCM:
                    cipher = self.cipher(key, tag_length=icv_size)
                else:
                    cipher = self.cipher(key)
                try:
                    if self.name == 'AES-NULL-GMAC':
                        # Special case for GMAC (rfc 4543 sect 3)
                        data = data + cipher.decrypt(mode_iv, icv, aad + iv + data)  # noqa: E501
                    else:
                        data = cipher.decrypt(mode_iv, data + icv, aad)
                except InvalidTag as err:
                    raise IPSecIntegrityError(err)
            else:
                cipher = self.new_cipher(key, mode_iv, icv)
                decryptor = cipher.decryptor()

                if self.is_aead:
                    # Tag value check is done during the finalize method
                    decryptor.authenticate_additional_data(aad)
                try:
                    data = decryptor.update(data) + decryptor.finalize()
                except InvalidTag as err:
                    raise IPSecIntegrityError(err)

        # extract padlen and nh from the ESP trailer
        padlen = orb(data[-2])
        nh = orb(data[-1])

        # then use padlen to determine data and padding. Clamp the split
        # point so that a corrupt padlen larger than the payload cannot make
        # the slice start negative (which would wrap around from the end).
        pad_start = max(len(data) - padlen - 2, 0)
        padding = data[pad_start:len(data) - 2]
        data = data[:pad_start]

        return _ESPPlain(spi=esp.spi,
                         seq=esp.seq,
                         iv=iv,
                         data=data,
                         padding=padding,
                         padlen=padlen,
                         nh=nh,
                         icv=icv)
513
514###############################################################################
# The names of the encryption algorithms are the same as in scapy.contrib.ikev2
# see http://www.iana.org/assignments/ikev2-parameters/ikev2-parameters.xhtml


CRYPT_ALGOS = {
    'NULL': CryptAlgo('NULL', cipher=None, mode=None, iv_size=0),
}

if algorithms:
    CRYPT_ALGOS['AES-CBC'] = CryptAlgo('AES-CBC',
                                       cipher=algorithms.AES,
                                       mode=modes.CBC)

    def _aes_ctr_format_mode_iv(sa, iv, **kw):
        """
        Build the AES-CTR counter block: SA salt, per-packet IV, then a
        32-bit block counter starting at 1.
        """
        return sa.crypt_salt + iv + b'\x00\x00\x00\x01'

    CRYPT_ALGOS['AES-CTR'] = CryptAlgo('AES-CTR',
                                       cipher=algorithms.AES,
                                       mode=modes.CTR,
                                       block_size=1,
                                       iv_size=8,
                                       salt_size=4,
                                       format_mode_iv=_aes_ctr_format_mode_iv)

    def _salt_format_mode_iv(sa, iv, **kw):
        """Build the nonce: the SA's crypt_salt followed by the per-packet IV."""
        return sa.crypt_salt + iv

    CRYPT_ALGOS['AES-GCM'] = CryptAlgo('AES-GCM',
                                       cipher=aead.AESGCM,
                                       key_size=(16, 24, 32),
                                       mode=None,
                                       salt_size=4,
                                       block_size=1,
                                       iv_size=8,
                                       icv_size=16,
                                       format_mode_iv=_salt_format_mode_iv)
    # GMAC: rfc 4543, "companion to the AES Galois/Counter Mode ESP"
    # This is defined as a crypt_algo by rfc, but has the role of an auth_algo
    CRYPT_ALGOS['AES-NULL-GMAC'] = CryptAlgo('AES-NULL-GMAC',
                                             cipher=aead.AESGCM,
                                             key_size=(16, 24, 32),
                                             mode=None,
                                             salt_size=4,
                                             block_size=1,
                                             iv_size=8,
                                             icv_size=16,
                                             format_mode_iv=_salt_format_mode_iv)  # noqa: E501
    CRYPT_ALGOS['AES-CCM'] = CryptAlgo('AES-CCM',
                                       cipher=aead.AESCCM,
                                       mode=None,
                                       key_size=(16, 24, 32),
                                       block_size=1,
                                       iv_size=8,
                                       salt_size=3,
                                       icv_size=16,
                                       format_mode_iv=_salt_format_mode_iv)
    CRYPT_ALGOS['CHACHA20-POLY1305'] = CryptAlgo('CHACHA20-POLY1305',
                                                 cipher=aead.ChaCha20Poly1305,
                                                 mode=None,
                                                 key_size=32,
                                                 block_size=1,
                                                 iv_size=8,
                                                 salt_size=4,
                                                 icv_size=16,
                                                 format_mode_iv=_salt_format_mode_iv)  # noqa: E501

    # Using a TripleDES cipher algorithm for DES is done by using the same 64
    # bits key 3 times (done by cryptography when given a 64 bits key)
    CRYPT_ALGOS['DES'] = CryptAlgo('DES',
                                   cipher=decrepit_algorithms.TripleDES,
                                   mode=modes.CBC,
                                   key_size=(8,))
    CRYPT_ALGOS['3DES'] = CryptAlgo('3DES',
                                    cipher=decrepit_algorithms.TripleDES,
                                    mode=modes.CBC)

    # Register the legacy ciphers once. The warning filter is only installed
    # when the decrepit module is unavailable (cryptography < 43), and it is
    # restored when the `with` block exits.
    with warnings.catch_warnings():
        if decrepit_algorithms is algorithms:
            # cryptography < 43 raises a DeprecationWarning
            from cryptography.utils import CryptographyDeprecationWarning
            # Hide deprecation warnings
            warnings.filterwarnings("ignore",
                                    category=CryptographyDeprecationWarning)
        CRYPT_ALGOS['CAST'] = CryptAlgo('CAST',
                                        cipher=decrepit_algorithms.CAST5,
                                        mode=modes.CBC)
        CRYPT_ALGOS['Blowfish'] = CryptAlgo('Blowfish',
                                            cipher=decrepit_algorithms.Blowfish,
                                            mode=modes.CBC)
604
605
606###############################################################################
607if conf.crypto_valid:
608    from cryptography.hazmat.primitives.hmac import HMAC
609    from cryptography.hazmat.primitives.cmac import CMAC
610    from cryptography.hazmat.primitives import hashes
611else:
612    # no error if cryptography is not available but authentication won't be supported  # noqa: E501
613    HMAC = CMAC = hashes = None
614
615###############################################################################
616
617
class IPSecIntegrityError(Exception):
    """
    Error raised when an integrity check fails.

    This covers both a mismatching ICV and an AEAD tag that does not verify.
    """
623
624
class AuthAlgo(object):
    """
    IPsec integrity algorithm
    """

    def __init__(self, name, mac, digestmod, icv_size, key_size=None):
        """
        :param name: the name of this integrity algorithm
        :param mac: a Message Authentication Code module
        :param digestmod: a Hash or Cipher module
        :param icv_size: the length of the integrity check value of this algo
        :param key_size: an integer or list/tuple of integers. If specified,
                         force the secret keys length to one of the values.
                         When None (the default), check_key performs no
                         length check.
        """
        self.name = name
        self.mac = mac
        self.digestmod = digestmod
        self.icv_size = icv_size
        self.key_size = key_size

    def check_key(self, key):
        """
        Check that the key length is valid.

        :param key:    a byte string
        :raise TypeError: if the key length is not one of the allowed sizes
        """
        if not self.key_size:
            return
        # key_size may be a single integer or a collection of integers; a
        # bare int must not reach the `in` test below (it is not iterable).
        if isinstance(self.key_size, int):
            allowed = (self.key_size,)
        else:
            allowed = self.key_size
        if len(key) not in allowed:
            raise TypeError('invalid key size %s, must be one of %s' %
                            (len(key), self.key_size))

    @crypto_validator
    def new_mac(self, key):
        """
        :param key:    a byte string
        :returns:       an initialized mac object for this algo
        """
        if self.mac is CMAC:
            # CMAC is keyed through the block cipher itself
            return self.mac(self.digestmod(key), default_backend())
        else:
            # HMAC takes the raw key plus a hash algorithm instance
            return self.mac(key, self.digestmod(), default_backend())

    def sign(self, pkt, key, esn_en=False, esn=0):
        """
        Sign an IPsec (ESP or AH) packet with this algo.

        :param pkt:    a packet that contains a valid encrypted ESP or AH layer
        :param key:    the authentication key, a byte string
        :param esn_en: extended sequence number enable which allows to use
                       64-bit sequence number instead of 32-bit
        :param esn: extended sequence number (32 MSB)

        :returns: the signed packet (``pkt`` itself, modified in place)
        """
        if not self.mac:
            return pkt

        mac = self.new_mac(key)

        if pkt.haslayer(ESP):
            mac.update(bytes(pkt[ESP]))
            if esn_en:
                # RFC4303 sect 2.2.1
                mac.update(struct.pack('!L', esn))
            # The (truncated) ICV is appended to the ESP data
            pkt[ESP].data += mac.finalize()[:self.icv_size]

        elif pkt.haslayer(AH):
            mac.update(bytes(zero_mutable_fields(pkt.copy(), sending=True)))
            if esn_en:
                # RFC4302 sect 2.5.1
                mac.update(struct.pack('!L', esn))
            pkt[AH].icv = mac.finalize()[:self.icv_size]

        return pkt

    def verify(self, pkt, key, esn_en=False, esn=0):
        """
        Check that the integrity check value (icv) of a packet is valid.

        :param pkt:    an ESP layer, or a packet that contains an AH layer
        :param key:    the authentication key, a byte string
        :param esn_en: extended sequence number enable which allows to use
                       64-bit sequence number instead of 32-bit
        :param esn: extended sequence number (32 MSB)

        :raise scapy.layers.ipsec.IPSecIntegrityError: if the integrity check
            fails
        """
        if not self.mac or self.icv_size == 0:
            return

        # Local import: only needed for the constant-time comparison below
        import hmac

        mac = self.new_mac(key)

        pkt_icv = 'not found'

        if isinstance(pkt, ESP):
            # The ICV is the trailing icv_size bytes of the ESP data
            pkt_icv = pkt.data[len(pkt.data) - self.icv_size:]
            clone = pkt.copy()
            clone.data = clone.data[:len(clone.data) - self.icv_size]
            mac.update(bytes(clone))
            if esn_en:
                # RFC4303 sect 2.2.1
                mac.update(struct.pack('!L', esn))

        elif pkt.haslayer(AH):
            if len(pkt[AH].icv) != self.icv_size:
                # Fill padding since we know the actual icv_size
                pkt[AH].padding = pkt[AH].icv[self.icv_size:]
                pkt[AH].icv = pkt[AH].icv[:self.icv_size]
            pkt_icv = pkt[AH].icv
            clone = zero_mutable_fields(pkt.copy(), sending=False)
            mac.update(bytes(clone))
            if esn_en:
                # RFC4302 sect 2.5.1
                mac.update(struct.pack('!L', esn))

        computed_icv = mac.finalize()[:self.icv_size]

        # XXX: Cannot use mac.verify because the ICV can be truncated.
        # The received ICV is attacker-controlled: compare in constant time
        # so a mismatch does not leak how many leading bytes were correct.
        icv_ok = (isinstance(pkt_icv, (bytes, bytearray)) and
                  hmac.compare_digest(pkt_icv, computed_icv))
        if not icv_ok:
            raise IPSecIntegrityError('pkt_icv=%r, computed_icv=%r' %
                                      (pkt_icv, computed_icv))
747
748###############################################################################
# The names of the integrity algorithms are the same as in scapy.contrib.ikev2
# see http://www.iana.org/assignments/ikev2-parameters/ikev2-parameters.xhtml


AUTH_ALGOS = {
    'NULL': AuthAlgo('NULL', mac=None, digestmod=None, icv_size=0),
}

if HMAC and hashes:
    # HMAC-based algorithms, registered in this order as
    # (name, hash algorithm, truncated ICV length in bytes).
    # XXX: NIST has deprecated SHA1 but is required by RFC7321
    # XXX: MD5 is flagged as deprecated by 'cryptography'. Kept for backward
    #      compat
    AUTH_ALGOS.update(
        (algo_name, AuthAlgo(algo_name,
                             mac=HMAC,
                             digestmod=digestmod,
                             icv_size=icv_size))
        for algo_name, digestmod, icv_size in (
            ('HMAC-SHA1-96', hashes.SHA1, 12),
            ('SHA2-256-128', hashes.SHA256, 16),
            ('SHA2-384-192', hashes.SHA384, 24),
            ('SHA2-512-256', hashes.SHA512, 32),
            ('HMAC-MD5-96', hashes.MD5, 12),
        )
    )
if CMAC and algorithms:
    # CMAC is keyed through AES, which accepts only 128-bit keys here
    AUTH_ALGOS['AES-CMAC-96'] = AuthAlgo('AES-CMAC-96',
                                         mac=CMAC,
                                         digestmod=algorithms.AES,
                                         icv_size=12,
                                         key_size=(16,))
786
787###############################################################################
788
789
def split_for_transport(orig_pkt, transport_proto):
    """
    Split an IP(v6) packet at the position where an ESP or AH header must be
    inserted.

    :param orig_pkt: the packet to split. Must be an IP or IPv6 packet
    :param transport_proto: the IPsec protocol number that will be inserted
                            at the split position; it is written into the
                            next-protocol field of the last header kept.
    :returns: a tuple (header, nh, payload) where header is a rebuilt copy of
             the IP(v6) header (with, for IPv6, the extension headers that
             must precede the IPsec header), nh is the protocol number of
             payload, and payload is everything found after the split point.
    """
    # Rebuilding from raw bytes forces resolution of default fields, which
    # avoids padding errors later on.
    header = orig_pkt.__class__(raw(orig_pkt))
    rest = header.payload

    if header.version == 4:
        nh = header.proto
        header.proto = transport_proto
        header.remove_payload()
        # let these be recomputed once a new payload is attached
        del header.chksum
        del header.len
        return header, nh, rest

    # Since RFC 4302 is vague about where the ESP/AH headers should be
    # inserted in IPv6, follow the Linux implementation: step over Hop-by-Hop,
    # Routing and Destination Options headers, but stop at a Destination
    # Options header that comes after a Routing header.
    last_kept = header
    seen_routing = False
    skippable = (IPv6ExtHdrHopByHop, IPv6ExtHdrRouting, IPv6ExtHdrDestOpt)
    while isinstance(rest, skippable):
        if isinstance(rest, IPv6ExtHdrRouting):
            seen_routing = True
        elif seen_routing and isinstance(rest, IPv6ExtHdrDestOpt):
            break
        last_kept = rest
        rest = rest.payload

    nh = last_kept.nh
    last_kept.nh = transport_proto
    last_kept.remove_payload()
    del header.plen
    return header, nh, rest
837
838
839###############################################################################
# IPv4 option numbers considered immutable in transit, so they are kept as-is
# when computing the AH ICV. Any other option is replaced by zero bytes.
# Source: RFC 4302, Appendix A (Mutability of IP Options/Extension Headers).
IMMUTABLE_IPV4_OPTIONS = (
    0,   # End of Option List
    1,   # No Operation
    2,   # Security
    5,   # Extended Security
    6,   # Commercial Security
    20,  # Router Alert
    21,  # Sender Directed Multi-Destination Delivery
)
850
851
def zero_mutable_fields(pkt, sending=False):
    """
    When using AH, all "mutable" fields must be "zeroed" before calculating
    the ICV. See RFC 4302, Section 3.3.3.1. Handling Mutable Fields.

    :param pkt: an IP(v6) packet containing an AH layer.
                NOTE: The packet will be modified in place.
    :param sending: if true, IPv6 routing headers are rearranged so that they
                    appear as they will at the final destination; when false
                    (receiving), they are left untouched.
    :returns: ``pkt`` itself, with its mutable fields zeroed
    :raises TypeError: if ``pkt`` has no AH layer
    """
    if not pkt.haslayer(AH):
        raise TypeError('no AH layer found')
    # the ICV is computed with the ICV field itself set to zeroes
    pkt[AH].icv = b"\x00" * len(pkt[AH].icv)

    if pkt.version == 4:
        # the tos field has been replaced by DSCP and ECN
        # Routers may rewrite the DS field as needed to provide a
        # desired local or end-to-end service
        pkt.tos = 0
        # an intermediate router might set the DF bit, even if the source
        # did not select it.
        pkt.flags = 0
        # AH is only applied to non-fragmented packets, so the fragment
        # offset is excluded from the ICV as well (RFC 4302, 3.3.3.1).
        pkt.frag = 0
        # changed en route as a normal course of processing by routers
        pkt.ttl = 0
        # will change if any of these other fields change
        pkt.chksum = 0
        # mutable options are replaced by as many zero bytes as they occupy
        pkt.options = [
            opt if opt.option in IMMUTABLE_IPV4_OPTIONS
            else Raw(b"\x00" * len(opt))
            for opt in pkt.options
        ]
        return pkt

    # holds DSCP and ECN
    pkt.tc = 0
    # The flow label described in AHv1 was mutable, and in RFC 2460 [DH98]
    # was potentially mutable. To retain compatibility with existing AH
    # implementations, the flow label is not included in the ICV in AHv2.
    pkt.fl = 0
    # same as ttl
    pkt.hlim = 0

    ext_hdr = pkt.payload
    ext_types = (IPv6ExtHdrHopByHop, IPv6ExtHdrRouting, IPv6ExtHdrDestOpt)
    # Both directions walk every extension header preceding AH (AH itself
    # is not in ext_types, so the walk stops there).
    while isinstance(ext_hdr, ext_types):
        if isinstance(ext_hdr, IPv6ExtHdrRouting):
            if sending:
                # The sender must order the field so that it appears as it
                # will at the receiver, prior to performing the ICV
                # computation. As in Linux's ipv6_rearrange_rthdr(), only
                # the addresses not yet visited (the last ``segleft`` ones)
                # are rotated, and nothing moves when segleft is already 0.
                addrs = list(ext_hdr.addresses)
                left = ext_hdr.segleft
                if left is None:
                    # unset field: treat it as "no segment visited yet"
                    left = len(addrs)
                left = min(left, len(addrs))
                if left:
                    start = len(addrs) - left
                    final = addrs[-1]
                    ext_hdr.addresses = (addrs[:start] + [pkt.dst] +
                                         addrs[start:-1])
                    pkt.dst = final
                ext_hdr.segleft = 0
        else:
            # Hop-by-Hop / Destination Options
            for opt in ext_hdr.options:
                if opt.otype & 0x20:
                    # option data can change en-route and must be zeroed
                    # NOTE(review): assumes every option class carrying this
                    # bit exposes optdata/optlen fields - TODO confirm.
                    opt.optdata = b"\x00" * opt.optlen
        ext_hdr = ext_hdr.payload

    return pkt
920
921###############################################################################
922
923
class SecurityAssociation(object):
    """
    This class is responsible for "encryption" and "decryption" of IPsec
    packets.
    """

    SUPPORTED_PROTOS = (IP, IPv6)

    def __init__(self, proto, spi, seq_num=1, crypt_algo=None, crypt_key=None,
                 crypt_icv_size=None,
                 auth_algo=None, auth_key=None,
                 tunnel_header=None, nat_t_header=None, esn_en=False, esn=0):
        """
        :param proto: the IPsec proto to use (ESP or AH), given either as the
                      layer class or as its name
        :param spi: the Security Parameters Index of this SA
        :param seq_num: the initial value for the sequence number on encrypted
                        packets
        :param crypt_algo: the encryption algorithm name (only used with ESP)
        :param crypt_key: the encryption key (only used with ESP); its last
                          ``salt_size`` bytes are used as the salt
        :param crypt_icv_size: change the default size of the crypt_algo
                               (only used with ESP)
        :param auth_algo: the integrity algorithm name
        :param auth_key: the integrity key
        :param tunnel_header: an instance of a IP(v6) header that will be used
                              to encapsulate the encrypted packets.
        :param nat_t_header: an instance of a UDP header that will be used
                             for NAT-Traversal (ESP only).
        :param esn_en: extended sequence number enable which allows to use
                       64-bit sequence number instead of 32-bit when using an
                       AEAD algorithm
        :param esn: extended sequence number (32 MSB)
        :raises ValueError: if proto is neither ESP nor AH
        :raises TypeError: on an unsupported algorithm or a header of the
                           wrong type
        """

        if proto not in {ESP, AH, ESP.name, AH.name}:
            raise ValueError("proto must be either ESP or AH")
        if isinstance(proto, str):
            self.proto = {ESP.name: ESP, AH.name: AH}[proto]
        else:
            self.proto = proto

        self.spi = spi
        self.seq_num = seq_num
        self.esn_en = esn_en
        # Get Extended Sequence (32 MSB)
        self.esn = esn

        if crypt_algo:
            if crypt_algo not in CRYPT_ALGOS:
                raise TypeError('unsupported encryption algo %r, try %r' %
                                (crypt_algo, list(CRYPT_ALGOS)))
            self.crypt_algo = CRYPT_ALGOS[crypt_algo]

            if crypt_key:
                # the trailing salt_size bytes of the key material are the salt
                salt_size = self.crypt_algo.salt_size
                self.crypt_key = crypt_key[:len(crypt_key) - salt_size]
                self.crypt_salt = crypt_key[len(crypt_key) - salt_size:]
            else:
                self.crypt_key = None
                self.crypt_salt = None

        else:
            self.crypt_algo = CRYPT_ALGOS['NULL']
            self.crypt_key = None
            self.crypt_salt = None
        self.crypt_icv_size = crypt_icv_size

        if auth_algo:
            if auth_algo not in AUTH_ALGOS:
                raise TypeError('unsupported integrity algo %r, try %r' %
                                (auth_algo, list(AUTH_ALGOS)))
            self.auth_algo = AUTH_ALGOS[auth_algo]
            self.auth_key = auth_key
        else:
            self.auth_algo = AUTH_ALGOS['NULL']
            self.auth_key = None

        if tunnel_header and not isinstance(tunnel_header, (IP, IPv6)):
            raise TypeError('tunnel_header must be %s or %s' %
                            (IP.name, IPv6.name))
        self.tunnel_header = tunnel_header

        if nat_t_header:
            # Check the resolved protocol: ``proto`` may have been given by
            # name ('ESP'), which is never identical to the ESP class.
            if self.proto is not ESP:
                raise TypeError('nat_t_header is only allowed with ESP')
            if not isinstance(nat_t_header, UDP):
                raise TypeError('nat_t_header must be %s' % UDP.name)
        self.nat_t_header = nat_t_header

    def check_spi(self, pkt):
        """
        Ensure the SPI carried by ``pkt`` is the one of this SA.

        :raises TypeError: if the SPIs differ
        """
        if pkt.spi != self.spi:
            raise TypeError('packet spi=0x%x does not match the SA spi=0x%x' %
                            (pkt.spi, self.spi))

    def _esn_params(self, esn_en, esn):
        """
        Resolve per-call ESN overrides against the SA defaults.

        Only ``None`` falls back to the SA value, so explicit ``False``/``0``
        overrides are honored.

        :returns: a tuple (esn_en, esn)
        """
        return (self.esn_en if esn_en is None else esn_en,
                self.esn if esn is None else esn)

    def _tunnel_encapsulate(self, pkt):
        """
        Prepend a copy of ``self.tunnel_header`` to ``pkt`` (tunnel mode).

        The length, checksum and next-protocol fields of the outer header are
        reset so that they are derived from the new payload; the result is
        rebuilt from raw bytes to freeze them.
        """
        tunnel = self.tunnel_header.copy()

        if tunnel.version == 4:
            del tunnel.proto
            del tunnel.len
            del tunnel.chksum
        else:
            del tunnel.nh
            del tunnel.plen

        return tunnel.__class__(raw(tunnel / pkt))

    def _encrypt_esp(self, pkt, seq_num=None, iv=None, esn_en=None, esn=None):
        """
        Build the ESP-protected version of ``pkt``. See :meth:`encrypt`.

        :raises TypeError: if ``iv`` does not have the algorithm's IV size
        """
        esn_en, esn = self._esn_params(esn_en, esn)

        if iv is None:
            iv = self.crypt_algo.generate_iv()
        elif len(iv) != self.crypt_algo.iv_size:
            raise TypeError('iv length must be %s' % self.crypt_algo.iv_size)

        # an explicit seq_num (even 0) wins over the SA counter
        seq = self.seq_num if seq_num is None else seq_num
        esp = _ESPPlain(spi=self.spi, seq=seq, iv=iv)

        if self.tunnel_header:
            pkt = self._tunnel_encapsulate(pkt)

        ip_header, nh, payload = split_for_transport(pkt, socket.IPPROTO_ESP)
        esp.data = payload
        esp.nh = nh

        esp = self.crypt_algo.pad(esp)
        esp = self.crypt_algo.encrypt(self, esp, self.crypt_key,
                                      self.crypt_icv_size,
                                      esn_en=esn_en, esn=esn)

        self.auth_algo.sign(esp, self.auth_key, esn_en=esn_en, esn=esn)

        if self.nat_t_header:
            nat_t_header = self.nat_t_header.copy()
            nat_t_header.chksum = 0
            del nat_t_header.len
            if ip_header.version == 4:
                del ip_header.proto
            else:
                # split_for_transport() pointed the next-header field of the
                # last kept header (the IPv6 header or its last extension
                # header) at ESP; the UDP header now follows it instead.
                ip_header.lastlayer().nh = socket.IPPROTO_UDP
            ip_header /= nat_t_header

        if ip_header.version == 4:
            del ip_header.len
            del ip_header.chksum
        else:
            del ip_header.plen

        # sequence number must always change, unless specified by the user
        if seq_num is None:
            self.seq_num += 1

        return ip_header.__class__(raw(ip_header / esp))

    def _encrypt_ah(self, pkt, seq_num=None, esn_en=None, esn=None):
        """
        Insert an AH header into ``pkt`` and compute its ICV.
        See :meth:`encrypt`.
        """
        esn_en, esn = self._esn_params(esn_en, esn)

        # an explicit seq_num (even 0) wins over the SA counter
        ah = AH(spi=self.spi,
                seq=self.seq_num if seq_num is None else seq_num,
                icv=b"\x00" * self.auth_algo.icv_size)

        if self.tunnel_header:
            pkt = self._tunnel_encapsulate(pkt)

        ip_header, nh, payload = split_for_transport(pkt, socket.IPPROTO_AH)
        ah.nh = nh

        if ip_header.version == 6 and len(ah) % 8 != 0:
            # For IPv6, the total length of the header must be a multiple of
            # 8-octet units.
            ah.padding = b"\x00" * (-len(ah) % 8)
        elif len(ah) % 4 != 0:
            # For IPv4, the total length of the header must be a multiple of
            # 4-octet units.
            ah.padding = b"\x00" * (-len(ah) % 4)

        # RFC 4302 - Section 2.2. Payload Length
        # This 8-bit field specifies the length of AH in 32-bit words (4-byte
        # units), minus "2".
        ah.payloadlen = len(ah) // 4 - 2

        if ip_header.version == 4:
            ip_header.len = len(ip_header) + len(ah) + len(payload)
            del ip_header.chksum
            ip_header = ip_header.__class__(raw(ip_header))
        else:
            ip_header.plen = len(ip_header.payload) + len(ah) + len(payload)

        signed_pkt = self.auth_algo.sign(ip_header / ah / payload,
                                         self.auth_key,
                                         esn_en=esn_en, esn=esn)

        # sequence number must always change, unless specified by the user
        if seq_num is None:
            self.seq_num += 1

        return signed_pkt

    def encrypt(self, pkt, seq_num=None, iv=None, esn_en=None, esn=None):
        """
        Encrypt (and encapsulate) an IP(v6) packet with ESP or AH according
        to this SecurityAssociation.

        :param pkt:     the packet to encrypt
        :param seq_num: if not None, use this sequence number instead of the
                        SA counter (which is then left unchanged)
        :param esn_en:  extended sequence number enable which allows to
                        use 64-bit sequence number instead of 32-bit when
                        using an AEAD algorithm; None uses the SA setting
        :param esn:     extended sequence number (32 MSB); None uses the SA
                        value
        :param iv:      if specified, use this initialization vector for
                        encryption instead of a random one (ESP only).

        :returns: the encrypted/encapsulated packet
        :raises TypeError: if pkt is not an IP or IPv6 packet
        """
        if not isinstance(pkt, self.SUPPORTED_PROTOS):
            raise TypeError('cannot encrypt %s, supported protos are %s'
                            % (pkt.__class__, self.SUPPORTED_PROTOS))
        if self.proto is ESP:
            return self._encrypt_esp(pkt, seq_num=seq_num,
                                     iv=iv, esn_en=esn_en,
                                     esn=esn)
        else:
            return self._encrypt_ah(pkt, seq_num=seq_num,
                                    esn_en=esn_en, esn=esn)

    def _decrypt_esp(self, pkt, verify=True, esn_en=None, esn=None):
        """
        Verify (optionally) and decrypt the ESP layer of ``pkt``.
        See :meth:`decrypt`. NOTE: ``pkt`` is modified in place.
        """
        esn_en, esn = self._esn_params(esn_en, esn)

        encrypted = pkt[ESP]

        if verify:
            self.check_spi(pkt)
            self.auth_algo.verify(encrypted, self.auth_key,
                                  esn_en=esn_en, esn=esn)

        esp = self.crypt_algo.decrypt(self, encrypted, self.crypt_key,
                                      self.crypt_icv_size or
                                      self.crypt_algo.icv_size or
                                      self.auth_algo.icv_size,
                                      esn_en=esn_en, esn=esn)

        if self.tunnel_header:
            # drop the tunnel header and return the payload untouched

            pkt.remove_payload()
            # the next-protocol field drives guess_payload_class() below
            if pkt.version == 4:
                pkt.proto = esp.nh
            else:
                pkt.nh = esp.nh
            cls = pkt.guess_payload_class(esp.data)

            return cls(esp.data)
        else:
            ip_header = pkt

            if ip_header.version == 4:
                ip_header.proto = esp.nh
                del ip_header.chksum
                ip_header.remove_payload()
                ip_header.len = len(ip_header) + len(esp.data)
                # recompute checksum
                ip_header = ip_header.__class__(raw(ip_header))
            else:
                if self.nat_t_header:
                    # drop the UDP header and return the payload untouched
                    # NOTE(review): this also drops any IPv6 extension
                    # headers placed before the UDP header.
                    ip_header.nh = esp.nh
                    ip_header.remove_payload()
                else:
                    encrypted.underlayer.nh = esp.nh
                    encrypted.underlayer.remove_payload()
                ip_header.plen = len(ip_header.payload) + len(esp.data)

            cls = ip_header.guess_payload_class(esp.data)

            # reassemble the ip_header with the ESP payload
            return ip_header / cls(esp.data)

    def _decrypt_ah(self, pkt, verify=True, esn_en=None, esn=None):
        """
        Verify (optionally) and strip the AH layer of ``pkt``.
        See :meth:`decrypt`. NOTE: ``pkt`` is modified in place.
        """
        esn_en, esn = self._esn_params(esn_en, esn)

        if verify:
            self.check_spi(pkt)
            self.auth_algo.verify(pkt, self.auth_key,
                                  esn_en=esn_en, esn=esn)

        ah = pkt[AH]
        payload = ah.payload
        payload.remove_underlayer(None)  # useless argument...

        if self.tunnel_header:
            return payload
        else:
            ip_header = pkt

            if ip_header.version == 4:
                ip_header.proto = ah.nh
                del ip_header.chksum
                ip_header.remove_payload()
                ip_header.len = len(ip_header) + len(payload)
                # recompute checksum
                ip_header = ip_header.__class__(raw(ip_header))
            else:
                ah.underlayer.nh = ah.nh
                ah.underlayer.remove_payload()
                ip_header.plen = len(ip_header.payload) + len(payload)

            # reassemble the ip_header with the AH payload
            return ip_header / payload

    def decrypt(self, pkt, verify=True, esn_en=None, esn=None):
        """
        Decrypt (and decapsulate) an IP(v6) packet containing ESP or AH.

        The given packet is not modified: decapsulation works on a copy.

        :param pkt:     the packet to decrypt
        :param verify:  if False, do not perform the integrity check
        :param esn_en:  extended sequence number enable which allows to use
                        64-bit sequence number instead of 32-bit when using an
                        AEAD algorithm; None uses the SA setting
        :param esn:     extended sequence number (32 MSB); None uses the SA
                        value
        :returns: the decrypted/decapsulated packet
        :raise scapy.layers.ipsec.IPSecIntegrityError: if the integrity check
            fails
        :raise TypeError: if pkt is not IP/IPv6 or lacks the SA's layer
        """
        if not isinstance(pkt, self.SUPPORTED_PROTOS):
            raise TypeError('cannot decrypt %s, supported protos are %s'
                            % (pkt.__class__, self.SUPPORTED_PROTOS))

        # The _decrypt_* helpers strip layers in place; work on a copy so the
        # caller's encrypted packet stays intact and reusable.
        if self.proto is ESP and pkt.haslayer(ESP):
            return self._decrypt_esp(pkt.copy(), verify=verify,
                                     esn_en=esn_en, esn=esn)
        elif self.proto is AH and pkt.haslayer(AH):
            return self._decrypt_ah(pkt.copy(), verify=verify,
                                    esn_en=esn_en, esn=esn)
        else:
            raise TypeError('%s has no %s layer' % (pkt, self.proto.name))
1265