• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19from scapy import all as scapy
20from socket import *  # pylint: disable=wildcard-import
21import struct
22import subprocess
23import threading
24import unittest
25
26import csocket
27import cstruct
28import multinetwork_base
29import net_test
30import packets
31import xfrm
32import xfrm_base
33
# Hex-encoded encrypted payload.
ENCRYPTED_PAYLOAD = (
    "b1c74998efd6326faebe2061f00f2c750e90e76001664a80c287b150"
    "59e74bf949769cc6af71e51b539e7de3a2a14cb05a231b969e035174"
    "d98c5aa0cef1937db98889ec0d08fa408fecf616"
)

# IPv6 test addresses.
TEST_ADDR1 = "2001:4860:4860::8888"
TEST_ADDR2 = "2001:4860:4860::8844"

# Kernel XFRM statistics file and the name of one of its counters.
XFRM_STATS_PROCFILE = "/proc/net/xfrm_stat"
XFRM_STATS_OUT_NO_STATES = "XfrmOutNoStates"

# Tunnel endpoint addresses, keyed by IP version. For generality, these
# differ from the addresses that test packets are sent to.
TUNNEL_ENDPOINTS = {
    4: "8.8.4.4",
    6: TEST_ADDR2,
}

# SPIs used by the tests; the second one immediately follows the first.
TEST_SPI = 0x1234
TEST_SPI2 = TEST_SPI + 1
50
51
52
class XfrmFunctionalTest(xfrm_base.XfrmLazyTest):
  """Functional tests for XFRM SAs, socket policies and global policies."""

  def assertIsUdpEncapEsp(self, packet, spi, seq, length):
    """Asserts that packet is an IPv4 UDP-encapsulated ESP packet.

    Args:
      packet: A scapy IPv4 packet (its protocol is read from packet.proto).
      spi: The expected ESP SPI.
      seq: The expected ESP sequence number.
      length: The expected length of the UDP datagram, header included.
    """
    self.assertEqual(IPPROTO_UDP, packet.proto)
    udp_hdr = packet[scapy.UDP]
    self.assertEqual(4500, udp_hdr.dport)
    self.assertEqual(length, len(udp_hdr))
    esp_hdr, _ = cstruct.Read(str(udp_hdr.payload), xfrm.EspHdr)
    # Both the expected and the parsed header go through xfrm.EspHdr, so the
    # SPI is compared without any manual byte swapping.
    self.assertEqual(xfrm.EspHdr((spi, seq)), esp_hdr)

  def CreateNewSa(self, localAddr, remoteAddr, spi, reqId, encap_tmpl,
                  null_auth=False):
    """Adds a transport-mode ESP SA encrypted with AES-256-CBC.

    Args:
      localAddr: Source address of the SA.
      remoteAddr: Destination address of the SA.
      spi: The SA's SPI.
      reqId: The SA's reqid.
      encap_tmpl: An xfrm.XfrmEncapTmpl for UDP encapsulation, or None.
      null_auth: If True, use null authentication instead of HMAC-SHA1.
    """
    auth_algo = (
        xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
    self.xfrm.AddSaInfo(localAddr, remoteAddr, spi, xfrm.XFRM_MODE_TRANSPORT,
                        reqId, xfrm_base._ALGO_CBC_AES_256, auth_algo, None,
                        encap_tmpl, None, None)

  def testAddSa(self):
    """Checks that an added SA is reported as expected by "ip xfrm state"."""
    self.CreateNewSa("::", TEST_ADDR1, TEST_SPI, 3320, None)
    # Delete the SA even if running or parsing "ip xfrm state" fails.
    try:
      expected = (
          "src :: dst 2001:4860:4860::8888\n"
          "\tproto esp spi 0x00001234 reqid 3320 mode transport\n"
          "\treplay-window 4 \n"
          "\tauth-trunc hmac(sha1) 0x%s 96\n"
          "\tenc cbc(aes) 0x%s\n"
          "\tsel src ::/0 dst ::/0 \n" % (
              xfrm_base._AUTHENTICATION_KEY_128.encode("hex"),
              xfrm_base._ENCRYPTION_KEY_256.encode("hex")))

      actual = subprocess.check_output(["ip", "xfrm", "state"])
      # Newer versions of iproute2 also show the anti-replay context. Strip it
      # so the comparison works whether or not it is present.
      actual = actual.replace(
          "\tanti-replay context: seq 0x0, oseq 0x0, bitmap 0x00000000\n", "")
      self.assertMultiLineEqual(expected, actual)
    finally:
      self.xfrm.DeleteSaInfo(TEST_ADDR1, TEST_SPI, IPPROTO_ESP)

  def testFlush(self):
    """Checks that FlushSaInfo removes both IPv4 and IPv6 SAs."""
    self.assertEqual(0, len(self.xfrm.DumpSaInfo()))
    self.CreateNewSa("::", "2000::", TEST_SPI, 1234, None)
    self.CreateNewSa("0.0.0.0", "192.0.2.1", TEST_SPI, 4321, None)
    self.assertEqual(2, len(self.xfrm.DumpSaInfo()))
    self.xfrm.FlushSaInfo()
    self.assertEqual(0, len(self.xfrm.DumpSaInfo()))

  def _TestSocketPolicy(self, version):
    """Checks that an output socket policy controls how a socket sends.

    With the policy applied but no matching SA, sends fail with EAGAIN. With a
    matching SA, packets go out ESP-encrypted, sends to a destination with no
    matching SA still fail, and a socket without the policy still sends in
    cleartext. Clearing the policy restores cleartext sends.

    Args:
      version: 4 or 6, or 5 for IPv4 traffic on a dual-stack (IPv4-mapped)
          socket.
    """
    # Open a UDP socket and connect it.
    family = net_test.GetAddressFamily(version)
    s = socket(family, SOCK_DGRAM, 0)
    netid = self.RandomNetid()
    self.SelectInterface(s, netid, "mark")

    remotesockaddr = self.GetRemoteSocketAddress(version)
    s.connect((remotesockaddr, 53))
    saddr, sport = s.getsockname()[:2]
    daddr = s.getpeername()[0]
    if version == 5:
      # Strip the IPv4-mapped prefix so the expected packet uses IPv4 addresses.
      saddr = saddr.replace("::ffff:", "")
      daddr = daddr.replace("::ffff:", "")

    reqid = 0

    desc, pkt = packets.UDP(version, saddr, daddr, sport=sport)
    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    self.ExpectPacketOn(netid, "Send after socket, expected %s" % desc, pkt)

    # Using IPv4 XFRM on a dual-stack socket requires setting an AF_INET policy
    # that's written in terms of IPv4 addresses.
    xfrm_version = 4 if version == 5 else version
    xfrm_family = net_test.GetAddressFamily(xfrm_version)
    xfrm_remoteaddr = self.GetRemoteAddress(xfrm_version)
    xfrm_base.ApplySocketPolicy(s, xfrm_family, xfrm.XFRM_POLICY_OUT,
                                TEST_SPI, reqid, None)

    # Because the policy has level set to "require" (the default), attempting
    # to send a packet results in an error, because there is no SA that
    # matches the socket policy we set.
    self.assertRaisesErrno(
        EAGAIN,
        s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))

    # If there is a user space key manager, calling sendto() after applying the
    # socket policy creates an SA whose state is XFRM_STATE_ACQ, so delete it.
    # If there is no user space key manager, deleting the SA fails with ESRCH.
    try:
      self.xfrm.DeleteSaInfo(xfrm_remoteaddr, TEST_SPI, IPPROTO_ESP)
    except IOError as e:
      self.assertEqual(ESRCH, e.errno, "Unexpected error when deleting ACQ SA")

    # Adding a matching SA causes the packet to go out encrypted. The SA's
    # SPI must match the one in our template, and the destination address must
    # match the packet's destination address (in tunnel mode, it has to match
    # the tunnel destination).
    self.CreateNewSa(net_test.GetWildcardAddress(xfrm_version),
                     xfrm_remoteaddr, TEST_SPI, reqid, None)

    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    expected_length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TRANSPORT,
                                                   version, False,
                                                   net_test.UDP_PAYLOAD,
                                                   xfrm_base._ALGO_HMAC_SHA1,
                                                   xfrm_base._ALGO_CBC_AES_256)
    self._ExpectEspPacketOn(netid, TEST_SPI, 1, expected_length, None, None)

    # Sending to another destination doesn't work: again, no matching SA.
    remoteaddr2 = self.GetOtherRemoteSocketAddress(version)
    self.assertRaisesErrno(
        EAGAIN,
        s.sendto, net_test.UDP_PAYLOAD, (remoteaddr2, 53))

    # Sending on another socket without the policy applied results in an
    # unencrypted packet going out.
    s2 = socket(family, SOCK_DGRAM, 0)
    try:
      self.SelectInterface(s2, netid, "mark")
      s2.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
      pkts = self.ReadAllPacketsOn(netid)
    finally:
      s2.close()
    self.assertEqual(1, len(pkts))
    packet = pkts[0]

    protocol = packet.nh if version == 6 else packet.proto
    self.assertEqual(IPPROTO_UDP, protocol)

    # Deleting the SA causes the first socket to return errors again.
    self.xfrm.DeleteSaInfo(xfrm_remoteaddr, TEST_SPI, IPPROTO_ESP)
    self.assertRaisesErrno(
        EAGAIN,
        s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))

    # Clear the socket policy and expect a cleartext packet.
    xfrm_base.SetPolicySockopt(s, family, None)
    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    self.ExpectPacketOn(netid, "Send after clear, expected %s" % desc, pkt)

    # Clearing the policy twice is safe.
    xfrm_base.SetPolicySockopt(s, family, None)
    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    self.ExpectPacketOn(netid, "Send after clear 2, expected %s" % desc, pkt)
    s.close()

    # Clearing if a policy was never set is safe.
    # NOTE(review): this socket is always AF_INET6, even when family is AF_INET
    # (version 4); presumably intentional, confirm.
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    try:
      xfrm_base.SetPolicySockopt(s, family, None)
    finally:
      s.close()

  def testSocketPolicyIPv4(self):
    self._TestSocketPolicy(4)

  def testSocketPolicyIPv6(self):
    self._TestSocketPolicy(6)

  def testSocketPolicyMapped(self):
    self._TestSocketPolicy(5)

  def _SetupUdpEncapSockets(self):
    """Sets up UDP encapsulation sockets marked for a random netid.

    Returns:
      A tuple (netid, myaddr, remoteaddr, encap_sock, encap_port, s), where
      encap_sock is bound to port encap_port on myaddr with UDP_ENCAP_ESPINUDP
      enabled, and s is a UDP socket marked for netid and connected to
      remoteaddr port 53. The caller must keep encap_sock open while it uses
      the encapsulation port.
    """
    netid = self.RandomNetid()
    myaddr = self.MyAddress(4, netid)
    remoteaddr = self.GetRemoteAddress(4)

    # Reserve a port on which to receive UDP encapsulated packets. Sending
    # packets works without this (and potentially can send packets with a source
    # port belonging to another application), but receiving requires the port to
    # be bound and the encapsulation socket option enabled.
    encap_sock = net_test.Socket(AF_INET, SOCK_DGRAM, 0)
    encap_sock.bind((myaddr, 0))
    encap_port = encap_sock.getsockname()[1]
    encap_sock.setsockopt(IPPROTO_UDP, xfrm.UDP_ENCAP, xfrm.UDP_ENCAP_ESPINUDP)

    # Open a socket to send traffic.
    s = socket(AF_INET, SOCK_DGRAM, 0)
    self.SelectInterface(s, netid, "mark")
    s.connect((remoteaddr, 53))

    return netid, myaddr, remoteaddr, encap_sock, encap_port, s

  def _SetupUdpEncapSaPair(self, myaddr, remoteaddr, in_spi, out_spi,
                           encap_port, s, use_null_auth):
    """Creates UDP-encapsulated SAs in both directions and applies them to s.

    The outbound SA (myaddr -> remoteaddr, out_spi) encapsulates from
    encap_port to port 4500. The inbound SA (remoteaddr -> myaddr, in_spi)
    uses the mirrored ports. Output and input socket policies referencing
    these SAs are then applied to s.
    """
    in_reqid = 123
    out_reqid = 456

    # Create inbound and outbound SAs that specify UDP encapsulation.
    encaptmpl = xfrm.XfrmEncapTmpl((xfrm.UDP_ENCAP_ESPINUDP, htons(encap_port),
                                    htons(4500), 16 * "\x00"))
    self.CreateNewSa(myaddr, remoteaddr, out_spi, out_reqid, encaptmpl,
                     use_null_auth)

    # Add an encap template that's the mirror of the outbound one.
    encaptmpl.sport, encaptmpl.dport = encaptmpl.dport, encaptmpl.sport
    self.CreateNewSa(remoteaddr, myaddr, in_spi, in_reqid, encaptmpl,
                     use_null_auth)

    # Apply socket policies to s.
    xfrm_base.ApplySocketPolicy(s, AF_INET, xfrm.XFRM_POLICY_OUT, out_spi,
                                out_reqid, None)

    # TODO: why does this work without a per-socket policy applied?
    # The received packet obviously matches an SA, but don't inbound packets
    # need to match a policy as well? (b/71541609)
    xfrm_base.ApplySocketPolicy(s, AF_INET, xfrm.XFRM_POLICY_IN, in_spi,
                                in_reqid, None)

  def _VerifyUdpEncapSocket(self, netid, remoteaddr, myaddr, encap_port, sock,
                            in_spi, out_spi, null_auth, seq_num):
    """Checks that packets can be sent and received on a UDP encap socket.

    Sends a packet on sock and checks that it goes out as UDP-encapsulated ESP
    with out_spi and seq_num. Then replays that packet to ourselves with in_spi
    substituted. It must be received when the replay still authenticates
    (null_auth, or in_spi == out_spi), and otherwise dropped with an integrity
    failure.

    Args:
      netid: The netid packets are sent and received on.
      remoteaddr: The IPv4 address packets are sent to.
      myaddr: Our IPv4 address on netid.
      encap_port: The local UDP encapsulation port.
      sock: The socket with the SA socket policies applied.
      in_spi: SPI of the inbound SA; written into the replayed packet.
      out_spi: SPI expected on the outbound packet.
      null_auth: Whether the SAs use null authentication.
      seq_num: ESP sequence number expected on the outbound packet and used in
          the replayed packet.
    """
    # Now send a packet.
    sock.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    srcport = sock.getsockname()[1]

    # Expect to see a UDP encapsulated packet.
    pkts = self.ReadAllPacketsOn(netid)
    self.assertEqual(1, len(pkts))
    packet = pkts[0]

    auth_algo = (
        xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
    expected_len = xfrm_base.GetEspPacketLength(
        xfrm.XFRM_MODE_TRANSPORT, 4, True, net_test.UDP_PAYLOAD, auth_algo,
        xfrm_base._ALGO_CBC_AES_256)
    self.assertIsUdpEncapEsp(packet, out_spi, seq_num, expected_len)

    # Now test the receive path. Because we don't know how to decrypt packets,
    # we just play back the encrypted packet that kernel sent earlier. We swap
    # the addresses in the IP header to make the packet look like it's bound for
    # us, but we can't do that for the port numbers because the UDP header is
    # part of the integrity protected payload, which we can only replay as is.
    # So the inner ports stay unswapped and the packet appears to be sent from
    # srcport to port 53. Open another socket on port 53 to receive it.
    twisted_socket = socket(AF_INET, SOCK_DGRAM, 0)
    try:
      csocket.SetSocketTimeout(twisted_socket, 100)
      twisted_socket.bind(("0.0.0.0", 53))

      # Save the payload of the packet so we can replay it back to ourselves,
      # and replace the SPI with our inbound SPI. The [8:] slice drops the
      # 8-byte outer UDP header, keeping the ESP header onwards.
      payload = str(packet.payload)[8:]
      spi_seq = xfrm.EspHdr((in_spi, seq_num)).Pack()
      payload = spi_seq + payload[len(spi_seq):]

      sainfo = self.xfrm.FindSaInfo(in_spi)
      start_integrity_failures = sainfo.stats.integrity_failed

      # Now play back the valid packet and check that we receive it.
      incoming = (scapy.IP(src=remoteaddr, dst=myaddr) /
                  scapy.UDP(sport=4500, dport=encap_port) / payload)
      incoming = scapy.IP(str(incoming))
      self.ReceivePacketOn(netid, incoming)

      sainfo = self.xfrm.FindSaInfo(in_spi)

      # TODO: break this out into a separate test
      # If our SPIs are different, and we aren't using null authentication,
      # we expect the packet to be dropped. We also expect the integrity
      # failure counter to increase, as SPIs are part of the authenticated or
      # integrity-verified portion of the packet.
      if not null_auth and in_spi != out_spi:
        self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
        self.assertEqual(start_integrity_failures + 1,
                         sainfo.stats.integrity_failed)
      else:
        data, src = twisted_socket.recvfrom(4096)
        self.assertEqual(net_test.UDP_PAYLOAD, data)
        self.assertEqual((remoteaddr, srcport), src)
        self.assertEqual(start_integrity_failures,
                         sainfo.stats.integrity_failed)

      # Nothing further should be queued on twisted_socket.
      # TODO: this does not check that unencrypted packets are rejected. No
      # plaintext packet is injected, and twisted_socket has no inbound policy
      # applied, so such a packet would presumably be delivered. Testing that
      # needs both.
      self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
    finally:
      # Release port 53: this method is called repeatedly within one test.
      twisted_socket.close()

  def _RunEncapSocketPolicyTest(self, in_spi, out_spi, use_null_auth):
    """Sets up a UDP-encap SA pair and checks one round trip (sequence 1)."""
    # encap_sock is unused here, but must stay referenced so the encapsulation
    # port stays bound during the checks.
    netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
        self._SetupUdpEncapSockets()

    self._SetupUdpEncapSaPair(myaddr, remoteaddr, in_spi, out_spi, encap_port,
                              s, use_null_auth)

    # Check that UDP encap sockets work with socket policy and given SAs
    self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s, in_spi,
                               out_spi, use_null_auth, 1)

  # TODO: Add tests for ESP (non-encap) sockets.
  def testUdpEncapSameSpisNullAuth(self):
    """Same SPI in both directions with null authentication: replay works."""
    self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI, True)

  def testUdpEncapSameSpis(self):
    """Same SPI in both directions with HMAC authentication: replay works."""
    # Using the same SPI both inbound and outbound lets us receive encrypted
    # packets by simply replaying the packets the kernel sends, without having
    # to disable authentication.
    self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI, False)

  def testUdpEncapDifferentSpisNullAuth(self):
    """Different SPIs with null authentication: replay is still received."""
    self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI2, True)

  def testUdpEncapDifferentSpis(self):
    """Different SPIs with HMAC authentication: replay fails integrity."""
    self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI2, False)

  def testUdpEncapRekey(self):
    """Rekeys a UDP-encap socket make-before-break and checks traffic."""
    # Select the two SPIs that will be used
    start_spi = TEST_SPI
    rekey_spi = TEST_SPI2

    # Setup sockets. encap_sock must stay referenced for the whole test.
    netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
        self._SetupUdpEncapSockets()

    # The SAs must use null authentication, since we change SPIs on the fly.
    # Without null authentication, this would result in an ESP authentication
    # error since the SPI is part of the authenticated section. The packet
    # would then be dropped.
    self._SetupUdpEncapSaPair(myaddr, remoteaddr, start_spi, start_spi,
                              encap_port, s, True)

    # Check that UDP encap sockets work with socket policy and given SAs
    self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
                               start_spi, start_spi, True, 1)

    # Rekey this socket using the make-before-break paradigm. First we create
    # new SAs, update the per-socket policies, and only then remove the old SAs.
    #
    # This allows us to switch to the new SA without breaking the outbound path.
    self._SetupUdpEncapSaPair(myaddr, remoteaddr, rekey_spi, rekey_spi,
                              encap_port, s, True)

    # Check that UDP encap socket works with updated socket policy, sending
    # using new SA, but receiving on both old and new SAs
    self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
                               rekey_spi, rekey_spi, True, 1)
    self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
                               start_spi, rekey_spi, True, 2)

    # Delete old SAs
    self.xfrm.DeleteSaInfo(remoteaddr, start_spi, IPPROTO_ESP)
    self.xfrm.DeleteSaInfo(myaddr, start_spi, IPPROTO_ESP)

    # Check that UDP encap socket works with updated socket policy and new SAs
    self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
                               rekey_spi, rekey_spi, True, 3)

  def testAllocSpecificSpi(self):
    """Allocating a single-value SPI range returns exactly that SPI."""
    spi = 0xABCD
    new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
    self.assertEqual(spi, new_sa.id.spi)

  def testAllocSpecificSpiUnavailable(self):
    """Attempt to allocate the same SPI twice."""
    spi = 0xABCD
    new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
    self.assertEqual(spi, new_sa.id.spi)
    with self.assertRaisesErrno(ENOENT):
      self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)

  def testAllocRangeSpi(self):
    """Allocating from an SPI range returns an SPI inside the range."""
    start, end = 0xABCD0, 0xABCDF
    new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
    spi = new_sa.id.spi
    self.assertGreaterEqual(spi, start)
    self.assertLessEqual(spi, end)

  def testAllocRangeSpiUnavailable(self):
    """Attempt to allocate N+1 SPIs from a range of size N."""
    start, end = 0xABCD0, 0xABCDF
    range_size = end - start + 1
    spis = set()
    # Assert that allocating SPI fails when none are available.
    with self.assertRaisesErrno(ENOENT):
      # Allocating range_size + 1 SPIs is guaranteed to fail.  Due to the way
      # kernel picks random SPIs, this has a high probability of failing before
      # reaching that limit.
      for _ in range(range_size + 1):
        new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
        spi = new_sa.id.spi
        self.assertNotIn(spi, spis)
        spis.add(spi)

  def testSocketPolicyDstCacheV6(self):
    self._TestSocketPolicyDstCache(6)

  def testSocketPolicyDstCacheV4(self):
    self._TestSocketPolicyDstCache(4)

  def _TestSocketPolicyDstCache(self, version):
    """Test that destination cache is cleared with socket policy.

    This relies on the fact that connect() on a UDP socket populates the
    destination cache.
    """

    # Create UDP socket.
    family = net_test.GetAddressFamily(version)
    netid = self.RandomNetid()
    s = socket(family, SOCK_DGRAM, 0)
    self.SelectInterface(s, netid, "mark")

    # Populate the socket's destination cache.
    remote = self.GetRemoteAddress(version)
    s.connect((remote, 53))

    # Apply a policy to the socket. Should clear dst cache.
    reqid = 123
    xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT,
                                TEST_SPI, reqid, None)

    # Policy with no matching SA should result in EAGAIN. If destination cache
    # failed to clear, then the UDP packet will be sent normally.
    with self.assertRaisesErrno(EAGAIN):
      s.send(net_test.UDP_PAYLOAD)
    self.ExpectNoPacketsOn(netid, "Packet not blocked by policy")

  def _CheckNullEncryptionTunnelMode(self, version):
    """Checks tunnel-mode ESP with null encryption for one IP version.

    Installs tunnel-mode SAs with null encryption and null authentication in
    both directions, and applies matching socket policies to a UDP socket.
    Then checks that a hand-built ESP packet is received on the socket, and
    that a packet sent on the socket goes out as ESP with the outbound SPI.

    Args:
      version: 4 or 6.
    """
    out_spi, out_reqid = 0xABCD, 123
    in_spi, in_reqid = 0x9876, 456
    input_payload = "input hello"
    output_payload = "output hello"

    family = net_test.GetAddressFamily(version)
    netid = self.RandomNetid()
    local_addr = self.MyAddress(version, netid)
    remote_addr = self.GetRemoteAddress(version)

    # Borrow the address of another netId as the source address of the tunnel
    tun_local = self.MyAddress(version, self.RandomNetid(netid))
    # For generality, pick a tunnel endpoint that's not the address we
    # connect the socket to.
    tun_remote = TUNNEL_ENDPOINTS[version]

    # Output SA. The final argument sets the SA's output mark to netid.
    self.xfrm.AddSaInfo(
        tun_local, tun_remote, out_spi, xfrm.XFRM_MODE_TUNNEL, out_reqid,
        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
        None, None, None, netid)
    # Input SA.
    self.xfrm.AddSaInfo(
        tun_remote, tun_local, in_spi, xfrm.XFRM_MODE_TUNNEL, in_reqid,
        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
        None, None, None, None)

    sock = net_test.UDPSocket(family)
    try:
      self.SelectInterface(sock, netid, "mark")
      sock.bind((local_addr, 0))
      local_port = sock.getsockname()[1]
      remote_port = 5555

      xfrm_base.ApplySocketPolicy(
          sock, family, xfrm.XFRM_POLICY_OUT, out_spi, out_reqid,
          (tun_local, tun_remote))
      xfrm_base.ApplySocketPolicy(
          sock, family, xfrm.XFRM_POLICY_IN, in_spi, in_reqid,
          (tun_remote, tun_local))

      # Create and receive an ESP packet.
      IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
      input_pkt = (IpType(src=remote_addr, dst=local_addr) /
                   scapy.UDP(sport=remote_port, dport=local_port) /
                   input_payload)
      input_pkt = IpType(str(input_pkt))  # Compute length, checksum.
      input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, in_spi,
                                                  1, (tun_remote, tun_local))

      self.ReceivePacketOn(netid, input_pkt)
      msg, addr = sock.recvfrom(1024)
      self.assertEqual(input_payload, msg)
      self.assertEqual((remote_addr, remote_port), addr[:2])

      # Send and capture a packet. The list is named pkts so that it does not
      # shadow the imported packets module.
      sock.sendto(output_payload, (remote_addr, remote_port))
      pkts = self.ReadAllPacketsOn(netid)
      self.assertEqual(1, len(pkts))
      output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(pkts[0])
      # UDP length is the payload length plus the 8-byte UDP header.
      self.assertEqual(len(output_payload) + 8, output_pkt[scapy.UDP].len)
      self.assertEqual(remote_addr, output_pkt.dst)
      self.assertEqual(remote_port, output_pkt[scapy.UDP].dport)
      self.assertEqual(output_payload, str(output_pkt[scapy.UDP].payload))
      self.assertEqual(out_spi, esp_hdr.spi)
    finally:
      sock.close()

  def testNullEncryptionTunnelMode(self):
    """Verify null encryption in tunnel mode.

    This test verifies both manual assembly and disassembly of UDP packets
    with ESP in IPsec tunnel mode.
    """
    for version in [4, 6]:
      self._CheckNullEncryptionTunnelMode(version)

  def _CheckNullEncryptionTransportMode(self, version):
    """Checks transport-mode ESP with null encryption for one IP version.

    Installs transport-mode SAs with null encryption and null authentication
    in both directions, and applies matching socket policies to a UDP socket.
    Then checks that a hand-built ESP packet is received on the socket, and
    that a packet sent on the socket goes out as ESP with the outbound SPI.

    Args:
      version: 4 or 6.
    """
    out_spi, out_reqid = 0xABCD, 123
    in_spi, in_reqid = 0x9876, 456
    input_payload = "input hello"
    output_payload = "output hello"

    family = net_test.GetAddressFamily(version)
    netid = self.RandomNetid()
    local_addr = self.MyAddress(version, netid)
    remote_addr = self.GetRemoteAddress(version)

    # Output SA.
    self.xfrm.AddSaInfo(
        local_addr, remote_addr, out_spi, xfrm.XFRM_MODE_TRANSPORT, out_reqid,
        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
        None, None, None, None)
    # Input SA.
    self.xfrm.AddSaInfo(
        remote_addr, local_addr, in_spi, xfrm.XFRM_MODE_TRANSPORT, in_reqid,
        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
        None, None, None, None)

    sock = net_test.UDPSocket(family)
    try:
      self.SelectInterface(sock, netid, "mark")
      sock.bind((local_addr, 0))
      local_port = sock.getsockname()[1]
      remote_port = 5555

      xfrm_base.ApplySocketPolicy(
          sock, family, xfrm.XFRM_POLICY_OUT, out_spi, out_reqid, None)
      xfrm_base.ApplySocketPolicy(
          sock, family, xfrm.XFRM_POLICY_IN, in_spi, in_reqid, None)

      # Create and receive an ESP packet.
      IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
      input_pkt = (IpType(src=remote_addr, dst=local_addr) /
                   scapy.UDP(sport=remote_port, dport=local_port) /
                   input_payload)
      input_pkt = IpType(str(input_pkt))  # Compute length, checksum.
      input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, in_spi, 1, None)

      self.ReceivePacketOn(netid, input_pkt)
      msg, addr = sock.recvfrom(1024)
      self.assertEqual(input_payload, msg)
      self.assertEqual((remote_addr, remote_port), addr[:2])

      # Send and capture a packet. The list is named pkts so that it does not
      # shadow the imported packets module.
      sock.sendto(output_payload, (remote_addr, remote_port))
      pkts = self.ReadAllPacketsOn(netid)
      self.assertEqual(1, len(pkts))
      output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(pkts[0])
      # UDP length is the payload length plus the 8-byte UDP header.
      self.assertEqual(len(output_payload) + 8, output_pkt[scapy.UDP].len)
      self.assertEqual(remote_addr, output_pkt.dst)
      self.assertEqual(remote_port, output_pkt[scapy.UDP].dport)
      self.assertEqual(output_payload, str(output_pkt[scapy.UDP].payload))
      self.assertEqual(out_spi, esp_hdr.spi)
    finally:
      sock.close()

  def testNullEncryptionTransportMode(self):
    """Verify null encryption in transport mode.

    This test verifies both manual assembly and disassembly of UDP packets
    with ESP in IPsec transport mode.
    """
    for version in [4, 6]:
      self._CheckNullEncryptionTransportMode(version)

  def _CheckGlobalPoliciesByMark(self, version):
    """Tests that global policies may differ by only the mark."""
    family = net_test.GetAddressFamily(version)
    sel = xfrm.EmptySelector(family)
    # Pick 2 arbitrary mark values.
    mark1 = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
    mark2 = xfrm.XfrmMark(mark=0xf00d, mask=xfrm_base.MARK_MASK_ALL)
    # Create a global policy.
    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
    tmpl = xfrm.UserTemplate(AF_UNSPEC, 0xfeed, 0, None)
    # Create the policy with the first mark.
    self.xfrm.AddPolicyInfo(policy, tmpl, mark1)
    # Create the same policy but with the second (different) mark.
    self.xfrm.AddPolicyInfo(policy, tmpl, mark2)
    # Delete the policies individually
    self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark1)
    self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark2)

  def testGlobalPoliciesByMarkV4(self):
    self._CheckGlobalPoliciesByMark(4)

  def testGlobalPoliciesByMarkV6(self):
    self._CheckGlobalPoliciesByMark(6)

  def _CheckUpdatePolicy(self, version):
    """Tests that we can update the template on a policy."""
    family = net_test.GetAddressFamily(version)
    tmpl1 = xfrm.UserTemplate(family, 0xdead, 0, None)
    tmpl2 = xfrm.UserTemplate(family, 0xbeef, 0, None)
    sel = xfrm.EmptySelector(family)
    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
    mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)

    def _CheckTemplateMatch(tmpl):
      """Dump the SPD and match a single template on a single policy."""
      dump = self.xfrm.DumpPolicyInfo()
      self.assertEqual(1, len(dump))
      _, attributes = dump[0]
      self.assertEqual(tmpl, attributes['XFRMA_TMPL'])

    # Create a new policy using update.
    self.xfrm.UpdatePolicyInfo(policy, tmpl1, mark, None)
    # NEWPOLICY will not update the existing policy. This checks both that
    # UPDPOLICY created a policy and that NEWPOLICY will not perform updates.
    _CheckTemplateMatch(tmpl1)
    with self.assertRaisesErrno(EEXIST):
      self.xfrm.AddPolicyInfo(policy, tmpl2, mark, None)
    # Update the policy using UPDPOLICY.
    self.xfrm.UpdatePolicyInfo(policy, tmpl2, mark, None)
    # There should only be one policy after update, and it should have the
    # updated template.
    _CheckTemplateMatch(tmpl2)

  def testUpdatePolicyV4(self):
    self._CheckUpdatePolicy(4)

  def testUpdatePolicyV6(self):
    self._CheckUpdatePolicy(6)

  def _CheckPolicyDifferByDirection(self, version):
    """Tests that policies can differ only by direction."""
    family = net_test.GetAddressFamily(version)
    tmpl = xfrm.UserTemplate(family, 0xdead, 0, None)
    sel = xfrm.EmptySelector(family)
    mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
    self.xfrm.AddPolicyInfo(policy, tmpl, mark)
    # Same selector, template and mark, opposite direction: must not collide.
    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_IN, sel)
    self.xfrm.AddPolicyInfo(policy, tmpl, mark)

  def testPolicyDifferByDirectionV4(self):
    self._CheckPolicyDifferByDirection(4)

  def testPolicyDifferByDirectionV6(self):
    self._CheckPolicyDifferByDirection(6)
680
class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest):
  """Tests for SA output marks (XFRMA_OUTPUT_MARK) and SA mark updates."""

  def _CheckTunnelModeOutputMark(self, version, tunsrc, mark, expected_netid):
    """Tests sending UDP packets to tunnel mode SAs with output marks.

    Opens a UDP socket and binds it to a random netid, then sets up a tunnel
    mode SA with an output_mark of mark and sets a socket policy to use the SA.
    Then checks that sending on that SA sends a packet on expected_netid, or,
    if expected_netid is None, checks that sending fails with ENETUNREACH.

    Args:
      version: 4 or 6.
      tunsrc: A string, the source address of the tunnel.
      mark: An integer, the output_mark to set in the SA.
      expected_netid: An integer, the netid to expect the kernel to send the
          packet on. If None, expect that sendto will fail with ENETUNREACH.
    """
    # Open a UDP socket and bind it to a random netid.
    family = net_test.GetAddressFamily(version)
    s = socket(family, SOCK_DGRAM, 0)
    try:
      self.SelectInterface(s, self.RandomNetid(), "mark")

      # For generality, pick a tunnel endpoint that's not the address we
      # connect the socket to.
      tundst = TUNNEL_ENDPOINTS[version]
      tun_addrs = (tunsrc, tundst)

      # Create a tunnel mode SA and use XFRM_OUTPUT_MARK to bind it to netid.
      # Deriving the SPI from the mark gives each distinct nonzero mark its
      # own SPI.
      # NOTE(review): mark == 0 yields SPI 0, which RFC 4303 reserves. The
      # callers that pass 0 only expect ENETUNREACH, so no packet is put on
      # the wire; confirm this is intended.
      spi = TEST_SPI * mark
      reqid = 100 + spi
      self.xfrm.AddSaInfo(tunsrc, tundst, spi, xfrm.XFRM_MODE_TUNNEL, reqid,
                          xfrm_base._ALGO_CBC_AES_256,
                          xfrm_base._ALGO_HMAC_SHA1,
                          None, None, None, mark)

      # Set a socket policy to use it.
      xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT, spi, reqid,
                                  tun_addrs)

      # Send a packet and check that we see it on the wire.
      remoteaddr = self.GetRemoteAddress(version)
      packetlen = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL, version,
                                               False, net_test.UDP_PAYLOAD,
                                               xfrm_base._ALGO_HMAC_SHA1,
                                               xfrm_base._ALGO_CBC_AES_256)

      if expected_netid is not None:
        s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
        self._ExpectEspPacketOn(expected_netid, spi, 1, packetlen, tunsrc,
                                tundst)
      else:
        with self.assertRaisesErrno(ENETUNREACH):
          s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    finally:
      # Release the socket even when an assertion above fails.
      s.close()

  def testTunnelModeOutputMarkIPv4(self):
    for netid in self.NETIDS:
      tunsrc = self.MyAddress(4, netid)
      self._CheckTunnelModeOutputMark(4, tunsrc, netid, netid)

  def testTunnelModeOutputMarkIPv6(self):
    for netid in self.NETIDS:
      tunsrc = self.MyAddress(6, netid)
      self._CheckTunnelModeOutputMark(6, tunsrc, netid, netid)

  def testTunnelModeOutputNoMarkIPv4(self):
    tunsrc = self.MyAddress(4, self.RandomNetid())
    self._CheckTunnelModeOutputMark(4, tunsrc, 0, None)

  def testTunnelModeOutputNoMarkIPv6(self):
    tunsrc = self.MyAddress(6, self.RandomNetid())
    self._CheckTunnelModeOutputMark(6, tunsrc, 0, None)

  def testTunnelModeOutputInvalidMarkIPv4(self):
    tunsrc = self.MyAddress(4, self.RandomNetid())
    self._CheckTunnelModeOutputMark(4, tunsrc, 9999, None)

  def testTunnelModeOutputInvalidMarkIPv6(self):
    tunsrc = self.MyAddress(6, self.RandomNetid())
    self._CheckTunnelModeOutputMark(6, tunsrc, 9999, None)

  def testTunnelModeOutputMarkAttributes(self):
    """Checks that an SA's output mark is dumped as XFRMA_OUTPUT_MARK."""
    mark = 1234567
    self.xfrm.AddSaInfo(TEST_ADDR1, TUNNEL_ENDPOINTS[6], TEST_SPI,
                        xfrm.XFRM_MODE_TUNNEL, 100, xfrm_base._ALGO_CBC_AES_256,
                        xfrm_base._ALGO_HMAC_SHA1, None, None, None, mark)
    dump = self.xfrm.DumpSaInfo()
    self.assertEqual(1, len(dump))
    _, attributes = dump[0]
    self.assertEqual(mark, attributes["XFRMA_OUTPUT_MARK"])

  def testInvalidAlgorithms(self):
    """Checks that adding SAs with unknown algorithms fails with ENOSYS."""
    import binascii  # pylint: disable=g-import-not-at-top
    # binascii.unhexlify works on Python 2 and 3; str.decode("hex") is
    # Python 2 only.
    key = binascii.unhexlify("af442892cdcd0ef650e9c299f9a8436a")
    invalid_auth = (xfrm.XfrmAlgoAuth(("invalid(algo)", 128, 96)), key)
    invalid_crypt = (xfrm.XfrmAlgo(("invalid(algo)", 128)), key)
    with self.assertRaisesErrno(ENOSYS):
      self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, TEST_SPI,
                          xfrm.XFRM_MODE_TRANSPORT, 0,
                          xfrm_base._ALGO_CBC_AES_256, invalid_auth,
                          None, None, None, 0)
    with self.assertRaisesErrno(ENOSYS):
      self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, TEST_SPI,
                          xfrm.XFRM_MODE_TRANSPORT, 0, invalid_crypt,
                          xfrm_base._ALGO_HMAC_SHA1, None, None, None, 0)

  def testUpdateSaAddMark(self):
    """Test that an embryonic SA can be updated to add a mark."""
    for version in [4, 6]:
      spi = 0xABCD
      wildcard = net_test.GetWildcardAddress(version)
      # Test that an SA created with ALLOCSPI can be updated with the mark.
      self.xfrm.AllocSpi(wildcard, IPPROTO_ESP, spi, spi)
      mark = xfrm.ExactMatchMark(0xf00d)
      self.xfrm.AddSaInfo(wildcard, wildcard,
                          spi, xfrm.XFRM_MODE_TUNNEL, 0,
                          xfrm_base._ALGO_CBC_AES_256,
                          xfrm_base._ALGO_HMAC_SHA1,
                          None, None, mark, 0, is_update=True)
      dump = self.xfrm.DumpSaInfo()
      self.assertEqual(1, len(dump))  # check that update updated
      _, attributes = dump[0]
      self.assertEqual(mark, attributes["XFRMA_MARK"])
      self.xfrm.DeleteSaInfo(wildcard, spi, IPPROTO_ESP, mark)

  def getXfrmStat(self, statName):
    """Returns the value of one counter from XFRM_STATS_PROCFILE.

    Args:
      statName: A string, the exact counter name, e.g. "XfrmOutNoStates".

    Returns:
      The integer value of the counter, or 0 if the counter is not listed.
    """
    with open(XFRM_STATS_PROCFILE, "r") as f:
      for line in f:
        fields = line.split()
        # Compare the whole name so a counter whose name merely contains
        # statName cannot be picked by mistake.
        if len(fields) >= 2 and fields[0] == statName:
          return int(fields[1])
    return 0

  def _AddTunnelSa(self, local, remote, mark, output_mark, is_update):
    """Adds or updates the TEST_SPI tunnel mode SA used by the update tests.

    Args:
      local: A string, the SA source address.
      remote: A string, the SA destination address.
      mark: The SA mark, e.g. as returned by xfrm.ExactMatchMark.
      output_mark: An integer, the XFRMA_OUTPUT_MARK to set on the SA.
      is_update: If True, update an existing SA rather than add a new one.
    """
    self.xfrm.AddSaInfo(local, remote, TEST_SPI, xfrm.XFRM_MODE_TUNNEL, 0,
                        xfrm_base._ALGO_CBC_AES_256,
                        xfrm_base._ALGO_HMAC_SHA1,
                        None, None, mark, output_mark, is_update=is_update)

  def testUpdateActiveSaMarks(self):
    """Test that the OUTPUT_MARK can be updated on an ACTIVE SA."""
    for version in [4, 6]:
      family = net_test.GetAddressFamily(version)
      netid = self.RandomNetid()
      remote = self.GetRemoteAddress(version)
      local = self.MyAddress(version, netid)
      s = socket(family, SOCK_DGRAM, 0)
      try:
        self.SelectInterface(s, netid, "mark")
        # Create a mark that we will apply to the policy and later the SA.
        mark = xfrm.ExactMatchMark(netid)

        # Create a global policy that selects using the mark.
        sel = xfrm.EmptySelector(family)
        policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
        tmpl = xfrm.UserTemplate(family, 0, 0, (local, remote))
        self.xfrm.AddPolicyInfo(policy, tmpl, mark)

        # Read /proc/net/xfrm_stat for a baseline.
        outNoStateCount = self.getXfrmStat(XFRM_STATS_OUT_NO_STATES)

        # No SA exists yet, so this send should increment XfrmOutNoStates.
        s.sendto(net_test.UDP_PAYLOAD, (remote, 53))

        # Check that XfrmOutNoStates was incremented by exactly 1.
        self.assertEqual(outNoStateCount + 1,
                         self.getXfrmStat(XFRM_STATS_OUT_NO_STATES))

        length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL,
                                              version, False,
                                              net_test.UDP_PAYLOAD,
                                              xfrm_base._ALGO_HMAC_SHA1,
                                              xfrm_base._ALGO_CBC_AES_256)

        # Add an SA with output mark 0; sending through it should fail with
        # ENETUNREACH. If the SA already exists, update it instead; any other
        # error is a real failure and is re-raised unchanged.
        try:
          self._AddTunnelSa(local, remote, mark, 0, is_update=False)
        except IOError as e:
          if e.errno != EEXIST:
            raise
          self._AddTunnelSa(local, remote, mark, 0, is_update=True)

        with self.assertRaisesErrno(ENETUNREACH):
          s.sendto(net_test.UDP_PAYLOAD, (remote, 53))

        # Update the SA to route to a valid netid.
        self._AddTunnelSa(local, remote, mark, netid, is_update=True)

        # Now the payload routes to the updated netid.
        s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
        self._ExpectEspPacketOn(netid, TEST_SPI, 1, length, None, None)

        # Get a new netid and update the SA's output mark to reroute to it.
        reroute_netid = self.RandomNetid(netid)
        self._AddTunnelSa(local, remote, mark, reroute_netid, is_update=True)

        s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
        self._ExpectEspPacketOn(reroute_netid, TEST_SPI, 2, length, None, None)

        dump = self.xfrm.DumpSaInfo()
        self.assertEqual(1, len(dump))  # check that update updated
        _, attributes = dump[0]
        self.assertEqual(reroute_netid, attributes["XFRMA_OUTPUT_MARK"])

        self.xfrm.DeleteSaInfo(remote, TEST_SPI, IPPROTO_ESP, mark)
        self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark)
      finally:
        # Release the per-iteration socket even when an assertion fails.
        s.close()
if __name__ == "__main__":
  # Run every TestCase in this module; "__main__" is unittest.main's default.
  unittest.main(module="__main__")
905