• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#############################################################################
2## ipsec.py --- IPsec support for Scapy                                    ##
3##                                                                         ##
4## Copyright (C) 2014  6WIND                                               ##
5##                                                                         ##
6## This program is free software; you can redistribute it and/or modify it ##
7## under the terms of the GNU General Public License version 2 as          ##
8## published by the Free Software Foundation.                              ##
9##                                                                         ##
10## This program is distributed in the hope that it will be useful, but     ##
11## WITHOUT ANY WARRANTY; without even the implied warranty of              ##
12## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU       ##
13## General Public License for more details.                                ##
14#############################################################################
15"""
16IPsec layer
17===========
18
19Example of use:
20
21>>> sa = SecurityAssociation(ESP, spi=0xdeadbeef, crypt_algo='AES-CBC',
22...                          crypt_key='sixteenbytes key')
23>>> p = IP(src='1.1.1.1', dst='2.2.2.2')
24>>> p /= TCP(sport=45012, dport=80)
25>>> p /= Raw('testdata')
26>>> p = IP(raw(p))
27>>> p
28<IP  version=4L ihl=5L tos=0x0 len=48 id=1 flags= frag=0L ttl=64 proto=tcp chksum=0x74c2 src=1.1.1.1 dst=2.2.2.2 options=[] |<TCP  sport=45012 dport=http seq=0 ack=0 dataofs=5L reserved=0L flags=S window=8192 chksum=0x1914 urgptr=0 options=[] |<Raw  load='testdata' |>>>
29>>>
30>>> e = sa.encrypt(p)
31>>> e
32<IP  version=4L ihl=5L tos=0x0 len=76 id=1 flags= frag=0L ttl=64 proto=esp chksum=0x747a src=1.1.1.1 dst=2.2.2.2 |<ESP  spi=0xdeadbeef seq=1 data=b'\xf8\xdb\x1e\x83[T\xab\\\xd2\x1b\xed\xd1\xe5\xc8Y\xc2\xa5d\x92\xc1\x05\x17\xa6\x92\x831\xe6\xc1]\x9a\xd6K}W\x8bFfd\xa5B*+\xde\xc8\x89\xbf{\xa9' |>>
33>>>
34>>> d = sa.decrypt(e)
35>>> d
36<IP  version=4L ihl=5L tos=0x0 len=48 id=1 flags= frag=0L ttl=64 proto=tcp chksum=0x74c2 src=1.1.1.1 dst=2.2.2.2 |<TCP  sport=45012 dport=http seq=0 ack=0 dataofs=5L reserved=0L flags=S window=8192 chksum=0x1914 urgptr=0 options=[] |<Raw  load='testdata' |>>>
37>>>
38>>> d == p
39True
40"""
41
from __future__ import absolute_import
import os
import socket
import struct
try:
    # math.gcd exists since Python 3.5 and always returns a non-negative value
    from math import gcd
except ImportError:
    # Python 2 fallback (fractions.gcd was removed in Python 3.9)
    from fractions import gcd

from scapy.config import conf, crypto_validator
from scapy.compat import orb, raw
from scapy.data import IP_PROTOS
from scapy.compat import *
from scapy.error import log_loading
from scapy.fields import ByteEnumField, ByteField, IntField, PacketField, \
    ShortField, StrField, XIntField, XStrField, XStrLenField
from scapy.packet import Packet, bind_layers, Raw
from scapy.layers.inet import IP, UDP
import scapy.modules.six as six
from scapy.modules.six.moves import range
from scapy.layers.inet6 import IPv6, IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt, \
    IPv6ExtHdrRouting
61
62
63#------------------------------------------------------------------------------
class AH(Packet):
    """
    IPsec Authentication Header layer.

    Reference: https://tools.ietf.org/rfc/rfc4302.txt
    """

    name = 'AH'

    def __get_icv_len(self):
        """
        Return the number of bytes occupied by the ICV, as derived from the
        `payloadlen` field.

        The result includes any ICV padding: the padding size depends on the
        authentication algorithm of the Security Association, which is not
        known when the header is dissected.
        """
        # `payloadlen` is the AH length in 32-bit words minus 2. The fixed
        # fields take 3 words, so ICV + padding span
        # (payloadlen + 2 - 3) words.
        icv_words = self.payloadlen - 1
        return icv_words * 4

    fields_desc = [
        ByteEnumField('nh', None, IP_PROTOS),
        ByteField('payloadlen', None),
        ShortField('reserved', None),
        XIntField('spi', 0x0),
        IntField('seq', 0),
        XStrLenField('icv', None, length_from=__get_icv_len),
        # The padding length is only known once the
        # SecurityAssociation.auth_algo is, so dissection leaves it empty.
        XStrLenField('padding', None, length_from=lambda x: 0),
    ]

    # Protocol number forced on the enclosing header when AH is stacked on it
    overload_fields = {IP: {'proto': socket.IPPROTO_AH}}
    overload_fields.update({
        v6_hdr: {'nh': socket.IPPROTO_AH}
        for v6_hdr in (IPv6, IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt,
                       IPv6ExtHdrRouting)
    })
102
# Layer bindings for AH: AH carried directly in IPv4 / IPv6, and an inner
# IPv4 / IPv6 packet carried directly after AH (selected by AH's `nh` field).
bind_layers(IP, AH, proto=socket.IPPROTO_AH)
bind_layers(IPv6, AH, nh=socket.IPPROTO_AH)
bind_layers(AH, IP, nh=socket.IPPROTO_IP)
bind_layers(AH, IPv6, nh=socket.IPPROTO_IPV6)
107
108#------------------------------------------------------------------------------
class ESP(Packet):
    """
    IPsec Encapsulating Security Payload layer.

    Everything following the sequence number is kept as a single opaque
    `data` string.

    Reference: https://tools.ietf.org/rfc/rfc4303.txt
    """
    name = 'ESP'

    fields_desc = [
        XIntField('spi', 0x0),
        IntField('seq', 0),
        XStrField('data', None),
    ]

    # Protocol number forced on the enclosing header when ESP is stacked on it
    overload_fields = {IP: {'proto': socket.IPPROTO_ESP}}
    overload_fields.update({
        v6_hdr: {'nh': socket.IPPROTO_ESP}
        for v6_hdr in (IPv6, IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt,
                       IPv6ExtHdrRouting)
    })
130
# Layer bindings for ESP: ESP carried directly in IPv4 / IPv6, and ESP
# carried in UDP when either port is 4500.
bind_layers(IP, ESP, proto=socket.IPPROTO_ESP)
bind_layers(IPv6, ESP, nh=socket.IPPROTO_ESP)
bind_layers(UDP, ESP, dport=4500)  # NAT-Traversal encapsulation
bind_layers(UDP, ESP, sport=4500)  # NAT-Traversal encapsulation
135
136#------------------------------------------------------------------------------
class _ESPPlain(Packet):
    """
    Internal representation of an ESP packet whose content is not encrypted:
    the IV, payload, padding, trailer and ICV are exposed as separate fields.
    """
    name = 'ESP'

    fields_desc = [
        # header
        XIntField('spi', 0x0),
        IntField('seq', 0),

        # payload
        StrField('iv', ''),
        PacketField('data', '', Raw),
        StrField('padding', ''),

        # trailer
        ByteField('padlen', 0),
        ByteEnumField('nh', 0, IP_PROTOS),
        StrField('icv', ''),
    ]

    def data_for_encryption(self):
        """
        Return the byte string handed to the cipher: the payload, its
        padding, then the two trailer bytes (`padlen` and `nh`).
        """
        body = raw(self.data) + self.padding
        trailer = struct.pack("BB", self.padlen, self.nh)
        return body + trailer
158
159#------------------------------------------------------------------------------
# The `cryptography` cipher primitives are optional: they are only imported
# when scapy's configuration reports a usable crypto backend.
if conf.crypto_valid:
    from cryptography.exceptions import InvalidTag
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives.ciphers import (
        Cipher,
        algorithms,
        modes,
    )
else:
    log_loading.info("Can't import python-cryptography v1.7+. "
                     "Disabled IPsec encryption/authentication.")
    # Placeholders keep this module importable; code below tests `algorithms`
    # and `modes` for truthiness before using them.
    InvalidTag = default_backend = None
    Cipher = algorithms = modes = None
173
174#------------------------------------------------------------------------------
175def _lcm(a, b):
176    """
177    Least Common Multiple between 2 integers.
178    """
179    if a == 0 or b == 0:
180        return 0
181    else:
182        return abs(a * b) // gcd(a, b)
183
class CryptAlgo(object):
    """
    IPsec encryption algorithm

    Wraps a cipher/mode pair and implements the padding, encryption and
    decryption of ESP payloads for that pair.
    """

    def __init__(self, name, cipher, mode, block_size=None, iv_size=None,
                 key_size=None, icv_size=None, salt_size=None, format_mode_iv=None):
        """
        @param name: the name of this encryption algorithm
        @param cipher: a Cipher module
        @param mode: the mode used with the cipher module
        @param block_size: the length a block for this algo. Defaults to the
                           `block_size` of the cipher.
        @param iv_size: the length of the initialization vector of this algo.
                        Defaults to the `block_size` of the cipher.
        @param key_size: an integer or list/tuple of integers. If specified,
                         force the secret keys length to one of the values.
                         Defaults to the `key_size` of the cipher.
        @param icv_size: the length of the Integrity Check Value of this algo.
                         Used by Combined Mode Algorithms e.g. GCM
        @param salt_size: the length of the salt to use as the IV prefix.
                          Usually used by Counter modes e.g. CTR
        @param format_mode_iv: function to format the Initialization Vector
                               e.g. handle the salt value
                               Default is the random buffer from `generate_iv`
        """
        self.name = name
        self.cipher = cipher
        self.mode = mode
        self.icv_size = icv_size

        # AEAD modes are recognised by their authentication-tag base class;
        # without the crypto backend (`modes` is None) nothing is AEAD.
        if modes and self.mode is not None:
            self.is_aead = issubclass(self.mode,
                                      modes.ModeWithAuthenticationTag)
        else:
            self.is_aead = False

        if block_size is not None:
            self.block_size = block_size
        elif cipher is not None:
            # the cipher exposes its block size in bits
            self.block_size = cipher.block_size // 8
        else:
            self.block_size = 1

        if iv_size is None:
            self.iv_size = self.block_size
        else:
            self.iv_size = iv_size

        if key_size is not None:
            self.key_size = key_size
        elif cipher is not None:
            # the cipher exposes its key sizes in bits
            self.key_size = tuple(i // 8 for i in cipher.key_sizes)
        else:
            self.key_size = None

        if salt_size is None:
            self.salt_size = 0
        else:
            self.salt_size = salt_size

        if format_mode_iv is None:
            self._format_mode_iv = lambda iv, **kw: iv
        else:
            self._format_mode_iv = format_mode_iv

    def check_key(self, key):
        """
        Check that the key length is valid.

        @param key:    a byte string

        @raise TypeError: if the length of `key` is not allowed by this algo
        """
        if not self.key_size:
            return

        allowed = self.key_size
        if isinstance(allowed, six.integer_types):
            # a single integer is accepted as well as a collection of sizes
            allowed = (allowed,)

        if len(key) not in allowed:
            raise TypeError('invalid key size %s, must be %s' %
                            (len(key), self.key_size))

    def generate_iv(self):
        """
        Generate a random initialization vector.

        @return: `iv_size` random bytes
        """
        # XXX: Handle counter modes with real counters? RFCs allow the use of
        # XXX: random bytes for counters, so it is not wrong to do it that way
        return os.urandom(self.iv_size)

    @crypto_validator
    def new_cipher(self, key, mode_iv, digest=None):
        """
        @param key:     the secret key, a byte string
        @param mode_iv: the initialization vector or nonce, a byte string.
                        Formatted by `format_mode_iv`.
        @param digest:  also known as tag or icv. A byte string containing the
                        digest of the encrypted data. Only use this during
                        decryption!

        @return:    an initialized cipher object for this algo
        """
        if self.is_aead and digest is not None:
            # With AEAD, the mode needs the digest during decryption.
            return Cipher(
                self.cipher(key),
                self.mode(mode_iv, digest, len(digest)),
                default_backend(),
            )
        else:
            return Cipher(
                self.cipher(key),
                self.mode(mode_iv),
                default_backend(),
            )

    def pad(self, esp):
        """
        Add the correct amount of padding so that the data to encrypt is
        exactly a multiple of the algorithm's block size.

        Also, make sure that the total ESP packet length is a multiple of 4
        bytes.

        @param esp:    an unencrypted _ESPPlain packet

        @return:    an unencrypted _ESPPlain packet with valid padding
        @raise ValueError: if the padded payload is still not 32-bit aligned
        """
        # 2 extra bytes for padlen and nh
        data_len = len(esp.data) + 2

        # according to the RFC4303, section 2.4. Padding (for Encryption)
        # the size of the ESP payload must be a multiple of 32 bits
        align = _lcm(self.block_size, 4)

        # pad for block size
        esp.padlen = -data_len % align

        # Still according to the RFC, the default value for padding *MUST* be an
        # array of bytes starting from 1 to padlen
        # TODO: Handle padding function according to the encryption algo
        esp.padding = struct.pack("B" * esp.padlen, *range(1, esp.padlen + 1))

        # If the following test fails, it means that this algo does not comply
        # with the RFC
        payload_len = len(esp.iv) + len(esp.data) + len(esp.padding) + 2
        if payload_len % 4 != 0:
            raise ValueError('The size of the ESP data is not aligned to 32 bits after padding.')

        return esp

    def encrypt(self, sa, esp, key):
        """
        Encrypt an ESP packet

        @param sa:   the SecurityAssociation associated with the ESP packet.
        @param esp:  an unencrypted _ESPPlain packet with valid padding
        @param key:  the secret key used for encryption

        @return:    a valid ESP packet encrypted with this algorithm
        """
        data = esp.data_for_encryption()

        if self.cipher:
            mode_iv = self._format_mode_iv(algo=self, sa=sa, iv=esp.iv)
            cipher = self.new_cipher(key, mode_iv)
            encryptor = cipher.encryptor()

            if self.is_aead:
                # the SPI and sequence number are authenticated, not encrypted
                aad = struct.pack('!LL', esp.spi, esp.seq)
                encryptor.authenticate_additional_data(aad)
                data = encryptor.update(data) + encryptor.finalize()
                data += encryptor.tag[:self.icv_size]
            else:
                data = encryptor.update(data) + encryptor.finalize()

        return ESP(spi=esp.spi, seq=esp.seq, data=esp.iv + data)

    def decrypt(self, sa, esp, key, icv_size=None):
        """
        Decrypt an ESP packet

        @param sa:         the SecurityAssociation associated with the ESP packet.
        @param esp:        an encrypted ESP packet
        @param key:        the secret key used for encryption
        @param icv_size:   the length of the icv used for integrity check.
                           Defaults to the algo `icv_size` for AEAD
                           algorithms and to 0 otherwise.

        @return:    the decrypted content as an _ESPPlain packet
        @raise IPSecIntegrityError: if the integrity check fails with an AEAD
                                    algorithm
        """
        if icv_size is None:
            icv_size = self.icv_size if self.is_aead else 0

        # wire layout of esp.data: IV | ciphertext | ICV
        iv = esp.data[:self.iv_size]
        data = esp.data[self.iv_size:len(esp.data) - icv_size]
        icv = esp.data[len(esp.data) - icv_size:]

        if self.cipher:
            # same keyword arguments as in `encrypt`
            mode_iv = self._format_mode_iv(algo=self, sa=sa, iv=iv)
            cipher = self.new_cipher(key, mode_iv, icv)
            decryptor = cipher.decryptor()

            if self.is_aead:
                # Tag value check is done during the finalize method
                decryptor.authenticate_additional_data(
                    struct.pack('!LL', esp.spi, esp.seq)
                )

            try:
                data = decryptor.update(data) + decryptor.finalize()
            except InvalidTag as err:
                raise IPSecIntegrityError(err)

        # extract padlen and nh (the last two bytes of the plaintext)
        padlen = orb(data[-2])
        nh = orb(data[-1])

        # then use padlen to determine data and padding. The padding MUST be
        # sliced out before `data` is truncated, otherwise the offsets would
        # be computed against the already shortened buffer.
        trailer_start = len(data) - 2
        payload_end = trailer_start - padlen
        padding = data[payload_end:trailer_start]
        data = data[:payload_end]

        return _ESPPlain(spi=esp.spi,
                         seq=esp.seq,
                         iv=iv,
                         data=data,
                         padding=padding,
                         padlen=padlen,
                         nh=nh,
                         icv=icv)
408
409#------------------------------------------------------------------------------
# The names of the encryption algorithms are the same as in scapy.contrib.ikev2
411# see http://www.iana.org/assignments/ikev2-parameters/ikev2-parameters.xhtml
412
# Registry of the supported encryption algorithms, indexed by name.
CRYPT_ALGOS = {}
CRYPT_ALGOS['NULL'] = CryptAlgo('NULL', cipher=None, mode=None, iv_size=0)

if algorithms:
    def _aes_ctr_format_mode_iv(sa, iv, **kw):
        # SA salt, then the packet IV, then the fixed 4-byte suffix 00 00 00 01
        return sa.crypt_salt + iv + b'\x00\x00\x00\x01'

    def _salt_format_mode_iv(sa, iv, **kw):
        # SA salt followed by the packet IV
        return sa.crypt_salt + iv

    CRYPT_ALGOS['AES-CBC'] = CryptAlgo(
        'AES-CBC',
        cipher=algorithms.AES,
        mode=modes.CBC,
    )
    CRYPT_ALGOS['AES-CTR'] = CryptAlgo(
        'AES-CTR',
        cipher=algorithms.AES,
        mode=modes.CTR,
        iv_size=8,
        salt_size=4,
        format_mode_iv=_aes_ctr_format_mode_iv,
    )
    CRYPT_ALGOS['AES-GCM'] = CryptAlgo(
        'AES-GCM',
        cipher=algorithms.AES,
        mode=modes.GCM,
        salt_size=4,
        iv_size=8,
        icv_size=16,
        format_mode_iv=_salt_format_mode_iv,
    )
    # CCM is only registered when the installed backend exposes that mode
    if hasattr(modes, 'CCM'):
        CRYPT_ALGOS['AES-CCM'] = CryptAlgo(
            'AES-CCM',
            cipher=algorithms.AES,
            mode=modes.CCM,
            iv_size=8,
            salt_size=3,
            icv_size=16,
            format_mode_iv=_salt_format_mode_iv,
        )
    # XXX: Flagged as weak by 'cryptography'. Kept for backward compatibility
    CRYPT_ALGOS['Blowfish'] = CryptAlgo(
        'Blowfish',
        cipher=algorithms.Blowfish,
        mode=modes.CBC,
    )
    # XXX: RFC7321 states that DES *MUST NOT* be implemented.
    # XXX: Keep for backward compatibility?
    # Using a TripleDES cipher algorithm for DES is done by using the same 64
    # bits key 3 times (done by cryptography when given a 64 bits key)
    CRYPT_ALGOS['DES'] = CryptAlgo(
        'DES',
        cipher=algorithms.TripleDES,
        mode=modes.CBC,
        key_size=(8,),
    )
    CRYPT_ALGOS['3DES'] = CryptAlgo(
        '3DES',
        cipher=algorithms.TripleDES,
        mode=modes.CBC,
    )
    CRYPT_ALGOS['CAST'] = CryptAlgo(
        'CAST',
        cipher=algorithms.CAST5,
        mode=modes.CBC,
    )
462
463#------------------------------------------------------------------------------
# The MAC and hash primitives are optional as well: they are only imported
# when scapy's configuration reports a usable crypto backend.
if conf.crypto_valid:
    from cryptography.hazmat.primitives.hmac import HMAC
    from cryptography.hazmat.primitives.cmac import CMAC
    from cryptography.hazmat.primitives import hashes
else:
    # no error if cryptography is not available but authentication won't be supported
    # (code below tests these names for truthiness before using them)
    HMAC = CMAC = hashes = None
471
472#------------------------------------------------------------------------------
class IPSecIntegrityError(Exception):
    """
    Error raised when the integrity check of an IPsec packet fails.
    """
478
class AuthAlgo(object):
    """
    IPsec integrity algorithm

    Wraps a MAC construction and its digest/cipher module, and computes or
    checks the (possibly truncated) ICV of ESP and AH packets.
    """

    def __init__(self, name, mac, digestmod, icv_size, key_size=None):
        """
        @param name: the name of this integrity algorithm
        @param mac: a Message Authentication Code module
        @param digestmod: a Hash or Cipher module
        @param icv_size: the length of the integrity check value of this algo
        @param key_size: an integer or list/tuple of integers. If specified,
                         force the secret keys length to one of the values.
                         Defaults to the `key_size` of the cipher.
        """
        self.name = name
        self.mac = mac
        self.digestmod = digestmod
        self.icv_size = icv_size
        self.key_size = key_size

    def check_key(self, key):
        """
        Check that the key length is valid.

        @param key:    a byte string

        @raise TypeError: if the length of `key` is not allowed by this algo
        """
        if not self.key_size:
            return

        allowed = self.key_size
        if isinstance(allowed, six.integer_types):
            # a single integer is accepted, as documented in __init__
            allowed = (allowed,)

        if len(key) not in allowed:
            raise TypeError('invalid key size %s, must be one of %s' %
                            (len(key), self.key_size))

    @crypto_validator
    def new_mac(self, key):
        """
        @param key:    a byte string
        @return:       an initialized mac object for this algo
        """
        if self.mac is CMAC:
            # CMAC is keyed through its cipher object
            return self.mac(self.digestmod(key), default_backend())
        else:
            return self.mac(key, self.digestmod(), default_backend())

    def sign(self, pkt, key):
        """
        Sign an IPsec (ESP or AH) packet with this algo.

        @param pkt:    a packet that contains a valid encrypted ESP or AH layer
        @param key:    the authentication key, a byte string

        @return: the signed packet (`pkt` itself, modified in place)
        """
        if not self.mac:
            return pkt

        mac = self.new_mac(key)

        if pkt.haslayer(ESP):
            # the ICV is appended to the ESP data
            mac.update(raw(pkt[ESP]))
            pkt[ESP].data += mac.finalize()[:self.icv_size]

        elif pkt.haslayer(AH):
            # the ICV covers the whole packet with its mutable fields zeroed
            clone = zero_mutable_fields(pkt.copy(), sending=True)
            mac.update(raw(clone))
            pkt[AH].icv = mac.finalize()[:self.icv_size]

        return pkt

    def verify(self, pkt, key):
        """
        Check that the integrity check value (icv) of a packet is valid.

        @param pkt:    an ESP layer, or a packet that contains an AH layer
        @param key:    the authentication key, a byte string

        @raise IPSecIntegrityError: if the integrity check fails
        @raise TypeError: if `pkt` is neither an ESP layer nor contains an
                          AH layer
        """
        if not self.mac or self.icv_size == 0:
            return

        mac = self.new_mac(key)

        if isinstance(pkt, ESP):
            # the ICV is the tail of the ESP data; the MAC covers the rest
            pkt_icv = pkt.data[len(pkt.data) - self.icv_size:]
            clone = pkt.copy()
            clone.data = clone.data[:len(clone.data) - self.icv_size]

        elif pkt.haslayer(AH):
            if len(pkt[AH].icv) != self.icv_size:
                # Fill padding since we know the actual icv_size
                pkt[AH].padding = pkt[AH].icv[self.icv_size:]
                pkt[AH].icv = pkt[AH].icv[:self.icv_size]
            pkt_icv = pkt[AH].icv
            clone = zero_mutable_fields(pkt.copy(), sending=False)

        else:
            # fail explicitly instead of hitting an unbound local below
            raise TypeError('pkt must be an ESP layer or contain an AH layer')

        mac.update(raw(clone))
        computed_icv = mac.finalize()[:self.icv_size]

        # XXX: Cannot use mac.verify because the ICV can be truncated
        if pkt_icv != computed_icv:
            raise IPSecIntegrityError('pkt_icv=%r, computed_icv=%r' %
                                      (pkt_icv, computed_icv))
583
584#------------------------------------------------------------------------------
# The names of the integrity algorithms are the same as in scapy.contrib.ikev2
586# see http://www.iana.org/assignments/ikev2-parameters/ikev2-parameters.xhtml
587
# Registry of the supported integrity algorithms, indexed by name.
AUTH_ALGOS = {}
AUTH_ALGOS['NULL'] = AuthAlgo('NULL', mac=None, digestmod=None, icv_size=0)

if HMAC and hashes:
    # (name, hash module, truncated ICV length in bytes), registered in order
    AUTH_ALGOS.update(
        (label, AuthAlgo(label, mac=HMAC, digestmod=digest, icv_size=size))
        for label, digest, size in (
            # XXX: NIST has deprecated SHA1 but is required by RFC7321
            ('HMAC-SHA1-96', hashes.SHA1, 12),
            ('SHA2-256-128', hashes.SHA256, 16),
            ('SHA2-384-192', hashes.SHA384, 24),
            ('SHA2-512-256', hashes.SHA512, 32),
            # XXX:Flagged as deprecated by 'cryptography'. Kept for backward compat
            ('HMAC-MD5-96', hashes.MD5, 12),
        )
    )
if CMAC and algorithms:
    AUTH_ALGOS['AES-CMAC-96'] = AuthAlgo(
        'AES-CMAC-96',
        mac=CMAC,
        digestmod=algorithms.AES,
        icv_size=12,
        key_size=(16,),
    )
621
622#------------------------------------------------------------------------------
def split_for_transport(orig_pkt, transport_proto):
    """
    Cut an IP(v6) packet at the position where an ESP or AH header has to be
    inserted.

    @param orig_pkt: the packet to split. Must be an IP or IPv6 packet
    @param transport_proto: the IPsec protocol number that will be inserted
                            at the split position.
    @return: a tuple (header, nh, payload) where nh is the protocol number of
             payload.
    """
    # Rebuild the packet from its bytes so that every default field is
    # resolved (avoids padding errors).
    header = orig_pkt.__class__(raw(orig_pkt))
    next_hdr = header.payload

    if header.version == 4:
        nh = header.proto
        header.proto = transport_proto
        header.remove_payload()
        # let these be recomputed when the packet is built again
        del header.chksum
        del header.len
        return header, nh, next_hdr

    # Since the RFC 4302 is vague about where the ESP/AH headers should be
    # inserted in IPv6, the linux implementation is followed here.
    ext_hdr_types = (IPv6ExtHdrHopByHop, IPv6ExtHdrRouting, IPv6ExtHdrDestOpt)
    seen_routing = False
    last_hdr = header

    while isinstance(next_hdr, ext_hdr_types):
        if isinstance(next_hdr, IPv6ExtHdrRouting):
            seen_routing = True
        elif seen_routing and isinstance(next_hdr, IPv6ExtHdrDestOpt):
            # a destination options header found after a routing header
            # stays on the payload side of the split
            break

        last_hdr = next_hdr
        next_hdr = next_hdr.payload

    nh = last_hdr.nh
    last_hdr.nh = transport_proto
    last_hdr.remove_payload()
    # let the payload length be recomputed when the packet is built again
    del header.plen

    return header, nh, next_hdr
670
671#------------------------------------------------------------------------------
# IPv4 option numbers kept untouched when computing the AH ICV.
# See RFC 4302 - Appendix A. Mutability of IP Options/Extension Headers
IMMUTABLE_IPV4_OPTIONS = (
    0,   # End Of List
    1,   # No Operation
    2,   # Security
    5,   # Extended Security
    6,   # Commercial Security
    20,  # Router Alert
    21,  # Sender Directed Multi-Destination Delivery
)
def zero_mutable_fields(pkt, sending=False):
    """
    Zero every "mutable" field of an AH-protected packet, as required before
    the ICV is calculated. See RFC 4302, Section 3.3.3.1. Handling Mutable
    Fields.

    @param pkt: an IP(v6) packet containing an AH layer.
                NOTE: The packet will be modified
    @param sending: if true, an IPv6 routing header is rewritten so that it
                    appears as it will at the receiver (segleft zeroed,
                    final destination swapped in). If false, processing
                    stops at the first routing header.

    @return: the modified packet
    @raise TypeError: if the packet has no AH layer
    """
    if not pkt.haslayer(AH):
        raise TypeError('no AH layer found')
    pkt[AH].icv = b"\x00" * len(pkt[AH].icv)

    if pkt.version == 4:
        # the tos field has been replaced by DSCP and ECN
        # Routers may rewrite the DS field as needed to provide a
        # desired local or end-to-end service
        pkt.tos = 0
        # an intermediate router might set the DF bit, even if the source
        # did not select it.
        pkt.flags = 0
        # changed en route as a normal course of processing by routers
        pkt.ttl = 0
        # will change if any of these other fields change
        pkt.chksum = 0

        # mutable options are replaced by zero bytes of the same length
        pkt.options = [
            opt if opt.option in IMMUTABLE_IPV4_OPTIONS
            else Raw(b"\x00" * len(opt))
            for opt in pkt.options
        ]
        return pkt

    # holds DSCP and ECN
    pkt.tc = 0
    # The flow label described in AHv1 was mutable, and in RFC 2460 [DH98]
    # was potentially mutable. To retain compatibility with existing AH
    # implementations, the flow label is not included in the ICV in AHv2.
    pkt.fl = 0
    # same as ttl
    pkt.hlim = 0

    ext_hdr = pkt.payload
    while isinstance(ext_hdr, (IPv6ExtHdrHopByHop, IPv6ExtHdrRouting, IPv6ExtHdrDestOpt)):
        if isinstance(ext_hdr, (IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt)):
            for opt in ext_hdr.options:
                if opt.otype & 0x20:
                    # option data can change en-route and must be zeroed
                    opt.optdata = b"\x00" * opt.optlen
        elif isinstance(ext_hdr, IPv6ExtHdrRouting) and sending:
            # The sender must order the field so that it appears as it
            # will at the receiver, prior to performing the ICV computation.
            ext_hdr.segleft = 0
            if ext_hdr.addresses:
                final_dst = ext_hdr.addresses.pop()
                ext_hdr.addresses.insert(0, pkt.dst)
                pkt.dst = final_dst
        else:
            # only reached for a routing header when `sending` is false
            break

        ext_hdr = ext_hdr.payload

    return pkt
750
751#------------------------------------------------------------------------------
class SecurityAssociation(object):
    """
    This class is responsible for "encryption" and "decryption" of IPsec
    packets.

    An instance holds the parameters of one security association (IPsec
    protocol, SPI, current sequence number, algorithms and keys, optional
    tunnel and NAT-Traversal headers) and exposes encrypt() and decrypt()
    on top of them.
    """

    SUPPORTED_PROTOS = (IP, IPv6)

    def __init__(self, proto, spi, seq_num=1, crypt_algo=None, crypt_key=None,
                 auth_algo=None, auth_key=None, tunnel_header=None, nat_t_header=None):
        """
        @param proto: the IPsec proto to use (ESP or AH), given either as the
                      layer class or as its name
        @param spi: the Security Parameters Index of this SA
        @param seq_num: the initial value for the sequence number on encrypted
                        packets
        @param crypt_algo: the encryption algorithm name (only used with ESP)
        @param crypt_key: the encryption key (only used with ESP). The last
                          `salt_size` bytes (as declared by the algorithm) are
                          split off and stored as the salt.
        @param auth_algo: the integrity algorithm name
        @param auth_key: the integrity key
        @param tunnel_header: an instance of a IP(v6) header that will be used
                              to encapsulate the encrypted packets.
        @param nat_t_header: an instance of a UDP header that will be used
                             for NAT-Traversal.

        @raise ValueError: if proto is neither ESP nor AH
        @raise TypeError: if an algorithm name is unknown, if tunnel_header
                          is not an IP/IPv6 instance, or if nat_t_header is
                          used with AH or is not a UDP instance
        """

        if proto not in (ESP, AH, ESP.name, AH.name):
            raise ValueError("proto must be either ESP or AH")
        if isinstance(proto, six.string_types):
            # proto was validated above, so the lookup cannot fail; an
            # explicit table avoids calling eval() on an argument.
            self.proto = {ESP.name: ESP, AH.name: AH}[proto]
        else:
            self.proto = proto

        self.spi = spi
        self.seq_num = seq_num

        if crypt_algo:
            if crypt_algo not in CRYPT_ALGOS:
                raise TypeError('unsupported encryption algo %r, try %r' %
                                (crypt_algo, list(CRYPT_ALGOS)))
            self.crypt_algo = CRYPT_ALGOS[crypt_algo]

            if crypt_key:
                # the key material carries the salt in its trailing bytes
                salt_size = self.crypt_algo.salt_size
                self.crypt_key = crypt_key[:len(crypt_key) - salt_size]
                self.crypt_salt = crypt_key[len(crypt_key) - salt_size:]
            else:
                self.crypt_key = None
                self.crypt_salt = None

        else:
            self.crypt_algo = CRYPT_ALGOS['NULL']
            self.crypt_key = None
            # keep the attribute defined on every instance
            self.crypt_salt = None

        if auth_algo:
            if auth_algo not in AUTH_ALGOS:
                raise TypeError('unsupported integrity algo %r, try %r' %
                                (auth_algo, list(AUTH_ALGOS)))
            self.auth_algo = AUTH_ALGOS[auth_algo]
            self.auth_key = auth_key
        else:
            self.auth_algo = AUTH_ALGOS['NULL']
            self.auth_key = None

        if tunnel_header and not isinstance(tunnel_header, (IP, IPv6)):
            raise TypeError('tunnel_header must be %s or %s' % (IP.name, IPv6.name))
        self.tunnel_header = tunnel_header

        if nat_t_header:
            # compare the resolved class, so that proto='ESP' given by name
            # is accepted as well
            if self.proto is not ESP:
                raise TypeError('nat_t_header is only allowed with ESP')
            if not isinstance(nat_t_header, UDP):
                raise TypeError('nat_t_header must be %s' % UDP.name)
        self.nat_t_header = nat_t_header

    def check_spi(self, pkt):
        """
        Check that a packet carries the SPI of this SA.

        @param pkt: the packet to check (its `spi` attribute is read)

        @raise TypeError: if the packet SPI differs from the SA SPI
        """
        if pkt.spi != self.spi:
            raise TypeError('packet spi=0x%x does not match the SA spi=0x%x' %
                            (pkt.spi, self.spi))

    def _encapsulate_in_tunnel(self, pkt):
        """
        Prepend a copy of the tunnel header to a packet (tunnel mode).

        @param pkt: the packet to encapsulate

        @return: pkt unchanged if this SA has no tunnel header, otherwise a
                 packet of the tunnel header class rebuilt from the raw bytes
                 of `tunnel_header / pkt`
        """
        if not self.tunnel_header:
            return pkt

        tunnel = self.tunnel_header.copy()

        # drop the fields that depend on the payload, so that they are
        # recomputed when the encapsulated packet is built
        if tunnel.version == 4:
            del tunnel.proto
            del tunnel.len
            del tunnel.chksum
        else:
            del tunnel.nh
            del tunnel.plen

        return tunnel.__class__(raw(tunnel / pkt))

    def _encrypt_esp(self, pkt, seq_num=None, iv=None):
        """
        Encapsulate a packet in ESP, encrypt it and sign it.

        @param pkt:     the IP(v6) packet to protect
        @param seq_num: if not None, use this sequence number and leave the
                        SA counter untouched
        @param iv:      if not None, use this initialization vector instead
                        of a generated one

        @return: the IP(v6) header followed by the ESP layer
        @raise TypeError: if iv does not have the size expected by the
                          encryption algorithm
        """

        if iv is None:
            iv = self.crypt_algo.generate_iv()
        elif len(iv) != self.crypt_algo.iv_size:
            raise TypeError('iv length must be %s' % self.crypt_algo.iv_size)

        # an explicit sequence number of 0 must be honoured as well, hence
        # the test against None rather than truthiness
        seq = self.seq_num if seq_num is None else seq_num
        esp = _ESPPlain(spi=self.spi, seq=seq, iv=iv)

        pkt = self._encapsulate_in_tunnel(pkt)

        ip_header, nh, payload = split_for_transport(pkt, socket.IPPROTO_ESP)
        esp.data = payload
        esp.nh = nh

        esp = self.crypt_algo.pad(esp)
        esp = self.crypt_algo.encrypt(self, esp, self.crypt_key)

        self.auth_algo.sign(esp, self.auth_key)

        if self.nat_t_header:
            # insert a copy of the UDP header between the IP header and ESP
            nat_t_header = self.nat_t_header.copy()
            nat_t_header.chksum = 0
            del nat_t_header.len
            if ip_header.version == 4:
                del ip_header.proto
            else:
                del ip_header.nh
            ip_header /= nat_t_header

        if ip_header.version == 4:
            ip_header.len = len(ip_header) + len(esp)
            del ip_header.chksum
            # rebuild the header from its raw form to recompute the checksum
            ip_header = ip_header.__class__(raw(ip_header))
        else:
            ip_header.plen = len(ip_header.payload) + len(esp)

        # sequence number must always change, unless specified by the user
        if seq_num is None:
            self.seq_num += 1

        return ip_header / esp

    def _encrypt_ah(self, pkt, seq_num=None):
        """
        Insert an AH header in a packet and compute its ICV.

        @param pkt:     the IP(v6) packet to protect
        @param seq_num: if not None, use this sequence number and leave the
                        SA counter untouched

        @return: the signed packet, as returned by the integrity algorithm
        """

        # an explicit sequence number of 0 must be honoured as well, hence
        # the test against None rather than truthiness
        seq = self.seq_num if seq_num is None else seq_num
        # the ICV is a zero-filled placeholder of the right size until the
        # packet is signed
        ah = AH(spi=self.spi, seq=seq,
                icv=b"\x00" * self.auth_algo.icv_size)

        pkt = self._encapsulate_in_tunnel(pkt)

        ip_header, nh, payload = split_for_transport(pkt, socket.IPPROTO_AH)
        ah.nh = nh

        if ip_header.version == 6 and len(ah) % 8 != 0:
            # For IPv6, the total length of the header must be a multiple of
            # 8-octet units.
            ah.padding = b"\x00" * (-len(ah) % 8)
        elif len(ah) % 4 != 0:
            # For IPv4, the total length of the header must be a multiple of
            # 4-octet units.
            ah.padding = b"\x00" * (-len(ah) % 4)

        # RFC 4302 - Section 2.2. Payload Length
        # This 8-bit field specifies the length of AH in 32-bit words (4-byte
        # units), minus "2".
        ah.payloadlen = len(ah) // 4 - 2

        if ip_header.version == 4:
            ip_header.len = len(ip_header) + len(ah) + len(payload)
            del ip_header.chksum
            # rebuild the header from its raw form to recompute the checksum
            ip_header = ip_header.__class__(raw(ip_header))
        else:
            ip_header.plen = len(ip_header.payload) + len(ah) + len(payload)

        signed_pkt = self.auth_algo.sign(ip_header / ah / payload, self.auth_key)

        # sequence number must always change, unless specified by the user
        if seq_num is None:
            self.seq_num += 1

        return signed_pkt

    def encrypt(self, pkt, seq_num=None, iv=None):
        """
        Encrypt (and encapsulate) an IP(v6) packet with ESP or AH according
        to this SecurityAssociation.

        @param pkt:     the packet to encrypt
        @param seq_num: if specified, use this sequence number instead of the
                        generated one
        @param iv:      if specified, use this initialization vector for
                        encryption instead of a random one (ignored with AH).

        @return: the encrypted/encapsulated packet
        @raise TypeError: if pkt is not an IP or IPv6 packet
        """
        if not isinstance(pkt, self.SUPPORTED_PROTOS):
            raise TypeError('cannot encrypt %s, supported protos are %s'
                            % (pkt.__class__, self.SUPPORTED_PROTOS))
        if self.proto is ESP:
            return self._encrypt_esp(pkt, seq_num=seq_num, iv=iv)
        else:
            return self._encrypt_ah(pkt, seq_num=seq_num)

    def _decrypt_esp(self, pkt, verify=True):
        """
        Decrypt the ESP layer of a packet and rebuild the original packet.

        Note that pkt is modified in place (its payload is removed and its
        next-protocol field is rewritten).

        @param pkt:    the IP(v6) packet containing an ESP layer
        @param verify: if False, skip the SPI and integrity checks

        @return: in tunnel mode, the decapsulated inner packet; otherwise
                 the IP(v6) header followed by the decrypted payload
        """

        encrypted = pkt[ESP]

        if verify:
            self.check_spi(pkt)
            self.auth_algo.verify(encrypted, self.auth_key)

        # the ICV size comes from the encryption algorithm when it declares
        # one, otherwise from the integrity algorithm
        esp = self.crypt_algo.decrypt(self, encrypted, self.crypt_key,
                                      self.crypt_algo.icv_size or
                                      self.auth_algo.icv_size)

        if self.tunnel_header:
            # drop the tunnel header and return the payload untouched

            pkt.remove_payload()
            # the next-protocol field is set so that the class of the inner
            # packet can be guessed from the outer header
            if pkt.version == 4:
                pkt.proto = esp.nh
            else:
                pkt.nh = esp.nh
            cls = pkt.guess_payload_class(esp.data)

            return cls(esp.data)
        else:
            ip_header = pkt

            if ip_header.version == 4:
                ip_header.proto = esp.nh
                del ip_header.chksum
                ip_header.remove_payload()
                ip_header.len = len(ip_header) + len(esp.data)
                # recompute checksum
                ip_header = ip_header.__class__(raw(ip_header))
            else:
                # the layer right below ESP may be an extension header
                # rather than the IPv6 header itself
                encrypted.underlayer.nh = esp.nh
                encrypted.underlayer.remove_payload()
                ip_header.plen = len(ip_header.payload) + len(esp.data)

            cls = ip_header.guess_payload_class(esp.data)

            # reassemble the ip_header with the ESP payload
            return ip_header / cls(esp.data)

    def _decrypt_ah(self, pkt, verify=True):
        """
        Strip the AH layer of a packet and rebuild the original packet.

        Note that pkt is modified in place (the AH payload is detached and,
        in transport mode, the header fields are rewritten).

        @param pkt:    the IP(v6) packet containing an AH layer
        @param verify: if False, skip the SPI and integrity checks

        @return: in tunnel mode, the payload that followed AH; otherwise
                 the IP(v6) header followed by that payload
        """

        if verify:
            self.check_spi(pkt)
            self.auth_algo.verify(pkt, self.auth_key)

        ah = pkt[AH]
        payload = ah.payload
        payload.remove_underlayer(None)  # useless argument...

        if self.tunnel_header:
            return payload
        else:
            ip_header = pkt

            if ip_header.version == 4:
                ip_header.proto = ah.nh
                del ip_header.chksum
                ip_header.remove_payload()
                ip_header.len = len(ip_header) + len(payload)
                # recompute checksum
                ip_header = ip_header.__class__(raw(ip_header))
            else:
                # the layer right below AH may be an extension header
                # rather than the IPv6 header itself
                ah.underlayer.nh = ah.nh
                ah.underlayer.remove_payload()
                ip_header.plen = len(ip_header.payload) + len(payload)

            # reassemble the ip_header with the AH payload
            return ip_header / payload

    def decrypt(self, pkt, verify=True):
        """
        Decrypt (and decapsulate) an IP(v6) packet containing ESP or AH.

        Note that pkt is modified in place while it is decapsulated.

        @param pkt:     the packet to decrypt
        @param verify:  if False, do not perform the integrity check

        @return: the decrypted/decapsulated packet
        @raise IPSecIntegrityError: if the integrity check fails
        @raise TypeError: if pkt is not an IP or IPv6 packet, or if it has
                          no layer matching the proto of this SA
        """
        if not isinstance(pkt, self.SUPPORTED_PROTOS):
            raise TypeError('cannot decrypt %s, supported protos are %s'
                            % (pkt.__class__, self.SUPPORTED_PROTOS))

        if self.proto is ESP and pkt.haslayer(ESP):
            return self._decrypt_esp(pkt, verify=verify)
        elif self.proto is AH and pkt.haslayer(AH):
            return self._decrypt_ah(pkt, verify=verify)
        else:
            raise TypeError('%s has no %s layer' % (pkt, self.proto.name))
1049