• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19from socket import *  # pylint: disable=wildcard-import
20
21import random
22import itertools
23import struct
24import unittest
25
26from scapy import all as scapy
27from tun_twister import TunTwister
28import csocket
29import iproute
30import multinetwork_base
31import net_test
32import packets
33import util
34import xfrm
35import xfrm_base
36
# Name and XFRM interface ID of the throwaway interface created by the
# capability probe and the add/delete tests; _LOOPBACK_IFINDEX is the
# underlying link it is created on.
_LOOPBACK_IFINDEX = 1
_TEST_XFRM_IF_ID = 42
_TEST_XFRM_IFNAME = "ipsec%d" % _TEST_XFRM_IF_ID
40
def HaveXfrmInterfaces():
  """Probes whether the kernel supports XFRM interfaces.

  Creates an XFRM interface on top of the loopback interface, deletes it
  again and checks that the deletion took effect.

  Returns:
    True if the probe interface could be created and deleted; False if any
    of those netlink operations failed with IOError.

  Raises:
    AssertionError: if the probe interface can still be looked up after it
      was deleted.
  """
  try:
    i = iproute.IPRoute()
    i.CreateXfrmInterface(_TEST_XFRM_IFNAME, _TEST_XFRM_IF_ID,
                          _LOOPBACK_IFINDEX)
    i.DeleteLink(_TEST_XFRM_IFNAME)
  except IOError:
    return False

  # Looking up the deleted interface must fail. The previous
  # `assert "<message>"` asserted a non-empty string and therefore never
  # fired; raise explicitly so the failure is reported (and is not stripped
  # when running with -O).
  try:
    i.GetIfIndex(_TEST_XFRM_IFNAME)
  except IOError:
    return True
  raise AssertionError(
      "Deleted interface %s still exists!" % _TEST_XFRM_IFNAME)
56
# Whether the kernel supports XFRM interfaces; probed once at import time.
HAVE_XFRM_INTERFACES = HaveXfrmInterfaces()

# Parameters to set up tunnels as special networks.
# The offset matches the reserved netid range for IpSecService.
_TUNNEL_NETID_OFFSET = 0xFC00
_BASE_TUNNEL_NETID = {
    4: 40,
    6: 60,
}
_BASE_VTI_OKEY = 2000000100
_BASE_VTI_IKEY = 2000000200

# Inbound and outbound SAs use the same SPI.
_TEST_OUT_SPI = 0x1234
_TEST_IN_SPI = _TEST_OUT_SPI

# Base VTI output/input keys (VtiInterface adds its netid to these).
_TEST_OKEY = 2000000100
_TEST_IKEY = 2000000200

_TEST_REMOTE_PORT = 1234

# scapy IP header class for each IP version.
_SCAPY_IP_TYPE = {
    4: scapy.IP,
    6: scapy.IPv6,
}
74
75
def _GetLocalInnerAddress(version):
  """Returns the fixed local inner (plaintext) address for an IP version.

  Raises:
    KeyError: if version is neither 4 nor 6.
  """
  local_inner_addrs = {
      4: "10.16.5.15",
      6: "2001:db8:1::1",
  }
  return local_inner_addrs[version]
78
79
def _GetRemoteInnerAddress(version):
  """Returns the fixed remote inner (plaintext) address for an IP version.

  Raises:
    KeyError: if version is neither 4 nor 6.
  """
  remote_inner_addrs = {
      4: "10.16.5.20",
      6: "2001:db8:2::1",
  }
  return remote_inner_addrs[version]
82
83
def _GetRemoteOuterAddress(version):
  """Returns the remote outer (tunnel endpoint) address for an IP version.

  Raises:
    KeyError: if version is neither 4 nor 6.
  """
  remote_outer_addrs = {
      4: net_test.IPV4_ADDR,
      6: net_test.IPV6_ADDR,
  }
  return remote_outer_addrs[version]
86
87
def _GetNullAuthCryptTunnelModePkt(inner_version, src_inner, src_outer,
                                   src_port, dst_inner, dst_outer,
                                   dst_port, spi, seq_num, ip_hdr_options=None):
  """Builds a tunnel-mode ESP packet with null encryption and authentication.

  The inner packet is a UDP datagram carrying net_test.UDP_PAYLOAD from
  (src_inner, src_port) to (dst_inner, dst_port). It is encapsulated with
  xfrm_base.EncryptPacketWithNull using spi, seq_num and the outer address
  pair (src_outer, dst_outer).

  Args:
    inner_version: IP version (4 or 6) of the inner packet.
    src_inner, dst_inner: inner source/destination addresses.
    src_outer, dst_outer: outer source/destination addresses.
    src_port, dst_port: UDP source/destination ports.
    spi: SPI to put in the ESP header.
    seq_num: ESP sequence number.
    ip_hdr_options: optional extra keyword arguments for the inner IP header
      (e.g. id/flags for IPv4, fl for IPv6). 'src' and 'dst' are always set
      from src_inner/dst_inner. The caller's dict is not modified.

  Returns:
    The encapsulated scapy packet.
  """
  # Work on a copy: the previous code added src/dst to the caller's dict.
  hdr_options = dict(ip_hdr_options) if ip_hdr_options else {}
  hdr_options.update({'src': src_inner, 'dst': dst_inner})

  # Build and receive an ESP packet destined for the inner socket
  ip_type = _SCAPY_IP_TYPE[inner_version]
  input_pkt = (
      ip_type(**hdr_options) / scapy.UDP(sport=src_port, dport=dst_port) /
      net_test.UDP_PAYLOAD)
  input_pkt = ip_type(str(input_pkt))  # Compute length, checksum.
  return xfrm_base.EncryptPacketWithNull(input_pkt, spi, seq_num,
                                         (src_outer, dst_outer))
106
107
def _CreateReceiveSock(version, port=0):
  """Opens a UDP socket on the wildcard address for receiving test packets.

  Args:
    version: IP version (4 or 6) of the socket.
    port: port to bind to; 0 lets the kernel pick one.

  Returns:
    A (socket, bound_port) tuple.
  """
  family = net_test.GetAddressFamily(version)
  sock = socket(family, SOCK_DGRAM, 0)
  sock.bind((net_test.GetWildcardAddress(version), port))
  # Element 1 of getsockname() is the port for both AF_INET and AF_INET6.
  bound_port = sock.getsockname()[1]
  # Time out (500, presumably milliseconds) instead of blocking forever if
  # the expected packet never arrives.
  csocket.SetSocketTimeout(sock, 500)
  return sock, bound_port
118
119
def _SendPacket(testInstance, netid, version, remote, remote_port):
  """Sends one net_test.UDP_PAYLOAD datagram on the given network.

  Args:
    testInstance: test object whose SelectInterface() routes the socket.
    netid: network to send on; selected via the socket mark.
    version: IP version (4 or 6) of the socket.
    remote: destination address.
    remote_port: destination port.

  Returns:
    The local UDP port the datagram was sent from.
  """
  write_sock = socket(net_test.GetAddressFamily(version), SOCK_DGRAM, 0)
  try:
    testInstance.SelectInterface(write_sock, netid, "mark")
    write_sock.sendto(net_test.UDP_PAYLOAD, (remote, remote_port))
    return write_sock.getsockname()[1]
  finally:
    # Close explicitly instead of relying on garbage collection; a failing
    # test's traceback can otherwise keep the socket open.
    write_sock.close()
129
130
def InjectTests():
  """Injects the IPv4/IPv6 parameterized variants into each tunnel test class."""
  for test_cls in (XfrmTunnelTest, XfrmInterfaceTest, XfrmVtiTest):
    InjectParameterizedTests(test_cls)
135
136
def InjectParameterizedTests(cls):
  """Injects every (inner_version, outer_version) combination into cls.

  Both versions range over {4, 6}. Test names get a suffix such as
  "IPv4_in_IPv6". Presumably this targets the class's ParamTest* methods; see
  util.InjectParameterizedTest.
  """
  versions = (4, 6)
  # Same pairs, in the same order, as itertools.product(versions, versions).
  param_list = itertools.product(versions, repeat=2)

  def NameGenerator(*args):
    # args is the (inner_version, outer_version) tuple.
    return "IPv%d_in_IPv%d" % args

  util.InjectParameterizedTest(cls, param_list, NameGenerator)
145
146
class XfrmTunnelTest(xfrm_base.XfrmLazyTest):
  """Tests unidirectional XFRM tunnels configured with explicit selectors."""

  def _CheckTunnelOutput(self, inner_version, outer_version, underlying_netid,
                         netid, local_inner, remote_inner, local_outer,
                         remote_outer, write_sock):
    """Checks that plaintext sent into the tunnel leaves as ESP.

    Sends net_test.UDP_PAYLOAD to port 53 of remote_inner via write_sock and
    expects an ESP packet with _TEST_OUT_SPI and sequence number 1 from
    local_outer to remote_outer on underlying_netid. The other parameters are
    unused; they keep the signature shared with _CheckTunnelInput.
    """
    write_sock.sendto(net_test.UDP_PAYLOAD, (remote_inner, 53))
    self._ExpectEspPacketOn(underlying_netid, _TEST_OUT_SPI, 1, None,
                            local_outer, remote_outer)

  def _CheckTunnelInput(self, inner_version, outer_version, underlying_netid,
                        netid, local_inner, remote_inner, local_outer,
                        remote_outer, read_sock):
    """Checks that a null-crypt ESP packet is decapsulated to read_sock.

    Injects an ESP packet (SPI _TEST_IN_SPI, sequence number 1) on
    underlying_netid and verifies that read_sock receives the payload from
    (remote_inner, _TEST_REMOTE_PORT).
    """
    # The second parameter of the tuple is the port number regardless of AF.
    local_port = read_sock.getsockname()[1]

    # Build and receive an ESP packet destined for the inner socket
    input_pkt = _GetNullAuthCryptTunnelModePkt(
        inner_version, remote_inner, remote_outer, _TEST_REMOTE_PORT,
        local_inner, local_outer, local_port, _TEST_IN_SPI, 1)
    self.ReceivePacketOn(underlying_netid, input_pkt)

    # Verify that the packet data and src are correct
    data, src = read_sock.recvfrom(4096)
    self.assertEquals(net_test.UDP_PAYLOAD, data)
    self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2])

  def _TestTunnel(self, inner_version, outer_version, func, direction,
                  test_output_mark_unset):
    """Test a unidirectional XFRM Tunnel with explicit selectors.

    Args:
      inner_version: IP version (4 or 6) of the plaintext packets.
      outer_version: IP version (4 or 6) of the ESP packets.
      func: the check to run (_CheckTunnelInput or _CheckTunnelOutput).
      direction: xfrm.XFRM_POLICY_OUT passes func the send socket; any other
        value passes the receive socket.
      test_output_mark_unset: if True, the outbound SA has no output mark and
        the underlying network is made the default network instead.
    """
    # Select the underlying netid, which represents the external
    # interface from/to which to route ESP packets.
    u_netid = self.RandomNetid()
    # Select a random netid that will originate traffic locally and
    # which represents the netid on which the plaintext is sent
    netid = self.RandomNetid(exclude=u_netid)

    local_inner = self.MyAddress(inner_version, netid)
    remote_inner = _GetRemoteInnerAddress(inner_version)
    local_outer = self.MyAddress(outer_version, u_netid)
    remote_outer = _GetRemoteOuterAddress(outer_version)

    output_mark = u_netid
    if test_output_mark_unset:
      output_mark = None
      self.SetDefaultNetwork(u_netid)

    write_sock = read_sock = None
    try:
      # Create input/output SPs, SAs and sockets to simulate a more realistic
      # environment.
      self.xfrm.CreateTunnel(
          xfrm.XFRM_POLICY_IN, xfrm.SrcDstSelector(remote_inner, local_inner),
          remote_outer, local_outer, _TEST_IN_SPI, xfrm_base._ALGO_CRYPT_NULL,
          xfrm_base._ALGO_AUTH_NULL, None, None, None, xfrm.MATCH_METHOD_ALL)

      self.xfrm.CreateTunnel(
          xfrm.XFRM_POLICY_OUT, xfrm.SrcDstSelector(local_inner, remote_inner),
          local_outer, remote_outer, _TEST_OUT_SPI, xfrm_base._ALGO_CBC_AES_256,
          xfrm_base._ALGO_HMAC_SHA1, None, output_mark, None,
          xfrm.MATCH_METHOD_ALL)

      write_sock = socket(net_test.GetAddressFamily(inner_version),
                          SOCK_DGRAM, 0)
      self.SelectInterface(write_sock, netid, "mark")
      read_sock, _ = _CreateReceiveSock(inner_version)

      sock = write_sock if direction == xfrm.XFRM_POLICY_OUT else read_sock
      func(inner_version, outer_version, u_netid, netid, local_inner,
           remote_inner, local_outer, remote_outer, sock)
    finally:
      # Close the sockets explicitly; previously they were left for the
      # garbage collector and stayed open while a failure traceback lived.
      for s in (write_sock, read_sock):
        if s is not None:
          s.close()
      if test_output_mark_unset:
        self.ClearDefaultNetwork()

  def ParamTestTunnelInput(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version, self._CheckTunnelInput,
                     xfrm.XFRM_POLICY_IN, False)

  def ParamTestTunnelOutput(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
                     xfrm.XFRM_POLICY_OUT, False)

  def ParamTestTunnelOutputNoSetMark(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
                     xfrm.XFRM_POLICY_OUT, True)
230
231
@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
class XfrmAddDeleteVtiTest(xfrm_base.XfrmBaseTest):
  """Creates, updates and deletes a Virtual Tunnel Interface (VTI)."""

  def _VerifyVtiInfoData(self, vti_info_data, version, local_addr, remote_addr,
                         ikey, okey):
    """Asserts that a VTI's link info has the expected keys and endpoints.

    Args:
      vti_info_data: IFLA_VTI_* attributes as returned by
        iproute.GetIfinfoData; the endpoint addresses are packed binary.
      version: IP version (4 or 6) of the endpoint addresses.
      local_addr: expected local endpoint, as a string.
      remote_addr: expected remote endpoint, as a string.
      ikey: expected input key.
      okey: expected output key.
    """
    self.assertEquals(vti_info_data["IFLA_VTI_IKEY"], ikey)
    self.assertEquals(vti_info_data["IFLA_VTI_OKEY"], okey)

    # Compare the packed endpoint addresses in their text form.
    family = AF_INET6 if version != 4 else AF_INET
    expected_endpoints = (("IFLA_VTI_LOCAL", local_addr),
                          ("IFLA_VTI_REMOTE", remote_addr))
    for attr_name, expected_addr in expected_endpoints:
      self.assertEquals(inet_ntop(family, vti_info_data[attr_name]),
                        expected_addr)

  def testAddVti(self):
    """Test the creation of a Virtual Tunnel Interface."""
    updated_remote_addrs = {4: net_test.IPV4_ADDR2, 6: net_test.IPV6_ADDR2}
    for version in [4, 6]:
      netid = self.RandomNetid()
      local_addr = self.MyAddress(version, netid)
      remote_addr = _GetRemoteOuterAddress(version)

      # Create the VTI and check the reported link info.
      self.iproute.CreateVirtualTunnelInterface(
          dev_name=_TEST_XFRM_IFNAME,
          local_addr=local_addr,
          remote_addr=remote_addr,
          o_key=_TEST_OKEY,
          i_key=_TEST_IKEY)
      self._VerifyVtiInfoData(
          self.iproute.GetIfinfoData(_TEST_XFRM_IFNAME), version, local_addr,
          remote_addr, _TEST_IKEY, _TEST_OKEY)

      # Update it in place with a different remote endpoint and keys.
      new_remote_addr = updated_remote_addrs[version]
      new_okey = _TEST_OKEY + _TEST_XFRM_IF_ID
      new_ikey = _TEST_IKEY + _TEST_XFRM_IF_ID
      self.iproute.CreateVirtualTunnelInterface(
          dev_name=_TEST_XFRM_IFNAME,
          local_addr=local_addr,
          remote_addr=new_remote_addr,
          o_key=new_okey,
          i_key=new_ikey,
          is_update=True)
      self._VerifyVtiInfoData(
          self.iproute.GetIfinfoData(_TEST_XFRM_IFNAME), version, local_addr,
          new_remote_addr, new_ikey, new_okey)

      # The netlink and ioctl views of the interface index must agree.
      if_index = self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)
      self.assertEquals(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index)

      # Delete the VTI; looking it up afterwards must fail.
      self.iproute.DeleteLink(_TEST_XFRM_IFNAME)
      with self.assertRaises(IOError):
        self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)

  def _QuietDeleteLink(self, ifname):
    """Deletes ifname, ignoring the IOError raised when it is not present."""
    try:
      self.iproute.DeleteLink(ifname)
    except IOError:
      pass

  def tearDown(self):
    super(XfrmAddDeleteVtiTest, self).tearDown()
    # Remove the VTI in case a failed test left it behind.
    self._QuietDeleteLink(_TEST_XFRM_IFNAME)
293
294
class SaInfo(object):
  """Mutable record of a Security Association's SPI and sequence number.

  seq_num is the sequence number to use for the next packet; it starts at 1
  and the tests increment it after each packet sent or received.
  """

  def __init__(self, spi):
    self.spi, self.seq_num = spi, 1
300
301
class IpSecBaseInterface(object):
  """Base class for an IPsec tunnel interface and its XFRM state.

  Holds the interface name and netid, the underlying network used for ESP
  packets, the outer endpoint addresses, the packet counts the tests expect,
  and the inbound/outbound SaInfo records. Subclasses implement
  TeardownXfrm, _SetupXfrmByType and _Rekey.
  """

  def __init__(self, iface, netid, underlying_netid, local, remote, version):
    self.iface = iface
    self.netid = netid
    self.underlying_netid = underlying_netid
    self.local, self.remote = local, remote

    # XFRM interfaces technically do not have a version. This keeps track of
    # the IP version of the local and remote addresses.
    self.version = version
    # Packet counts the tests expect the interface to report.
    self.rx = self.tx = 0
    # Inner address per IP version; filled in by the tunnel network setup.
    self.addrs = {}

    self.iproute = iproute.IPRoute()
    self.xfrm = xfrm.Xfrm()

  def Teardown(self):
    """Removes the XFRM state, then the interface.

    The interface is deleted even if TeardownXfrm raises; that error then
    propagates (unless TeardownInterface raises as well).
    """
    try:
      self.TeardownXfrm()
    finally:
      self.TeardownInterface()

  def TeardownInterface(self):
    """Deletes the tunnel's network interface."""
    self.iproute.DeleteLink(self.iface)

  def SetupXfrm(self, use_null_crypt):
    """Creates SAs/policies for both directions with a fresh random SPI.

    Args:
      use_null_crypt: if True, use null authentication and encryption;
        otherwise HMAC-SHA1 and AES-256-CBC.
    """
    # SPI 0 is reserved and 1-255 are reserved by IANA (RFC 4303, section
    # 2.1), so draw from the unreserved range. Both directions use the same
    # SPI; their SAs have different destination addresses.
    rand_spi = random.randint(0x100, 0x7fffffff)
    self.in_sa = SaInfo(rand_spi)
    self.out_sa = SaInfo(rand_spi)

    # Select algorithms:
    if use_null_crypt:
      auth, crypt = xfrm_base._ALGO_AUTH_NULL, xfrm_base._ALGO_CRYPT_NULL
    else:
      auth, crypt = xfrm_base._ALGO_HMAC_SHA1, xfrm_base._ALGO_CBC_AES_256

    self._SetupXfrmByType(auth, crypt)

  def Rekey(self, outer_family, new_out_sa, new_in_sa):
    """Rekeys the Tunnel Interface

    Creates new SAs and updates the outbound security policy to use new SAs.

    Args:
      outer_family: AF_INET or AF_INET6
      new_out_sa: An SaInfo struct representing the new outbound SA's info
      new_in_sa: An SaInfo struct representing the new inbound SA's info
    """
    self._Rekey(outer_family, new_out_sa, new_in_sa)

    # Update Interface object
    self.out_sa = new_out_sa
    self.in_sa = new_in_sa

  def TeardownXfrm(self):
    """Removes this tunnel's SAs and policies."""
    raise NotImplementedError("Subclasses should implement this")

  def _SetupXfrmByType(self, auth_algo, crypt_algo):
    """Creates this tunnel's SAs and policies with the given algorithms."""
    raise NotImplementedError("Subclasses should implement this")

  def _Rekey(self, outer_family, new_out_sa, new_in_sa):
    """Installs the new SAs and points the outbound policy at them."""
    raise NotImplementedError("Subclasses should implement this")
363
364
class VtiInterface(IpSecBaseInterface):
  """An IPsec tunnel built on a Virtual Tunnel Interface (VTI).

  SAs and policies are selected by mark: okey for outbound traffic and ikey
  for inbound traffic, both derived from the tunnel's netid.
  """

  def __init__(self, iface, netid, underlying_netid, _, local, remote, version):
    # The fourth argument (the underlying ifindex, which XfrmInterface uses)
    # is ignored; it keeps the constructor signatures of the two tunnel
    # classes interchangeable.
    super(VtiInterface, self).__init__(iface, netid, underlying_netid, local,
                                       remote, version)

    self.ikey = _TEST_IKEY + netid
    self.okey = _TEST_OKEY + netid

    self.SetupInterface()
    self.SetupXfrm(False)

  def SetupInterface(self):
    """Creates the VTI device with this tunnel's endpoints and keys."""
    # Pass the keys by name so the input and output keys cannot be
    # transposed by CreateVirtualTunnelInterface's positional order.
    return self.iproute.CreateVirtualTunnelInterface(
        self.iface, self.local, self.remote, i_key=self.ikey, o_key=self.okey)

  def _SetupXfrmByType(self, auth_algo, crypt_algo):
    """Creates mark-selected tunnel-mode SAs/policies in both directions."""
    # For the VTI, the selectors are wildcard since packets will only
    # be selected if they have the appropriate mark, hence the inner
    # addresses are wildcard.
    self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
                           self.out_sa.spi, crypt_algo, auth_algo,
                           xfrm.ExactMatchMark(self.okey),
                           self.underlying_netid, None, xfrm.MATCH_METHOD_ALL)

    self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
                           self.in_sa.spi, crypt_algo, auth_algo,
                           xfrm.ExactMatchMark(self.ikey), None, None,
                           xfrm.MATCH_METHOD_MARK)

  def TeardownXfrm(self):
    """Deletes the outbound and inbound tunnels created by SetupXfrm."""
    self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
                           self.out_sa.spi, self.okey, None)
    self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
                           self.in_sa.spi, self.ikey, None)

  def _Rekey(self, outer_family, new_out_sa, new_in_sa):
    """Adds null-crypt SAs for the new SPIs and updates outbound policies."""
    # TODO: Consider ways to share code with xfrm.CreateTunnel(). It's mostly
    #       the same, but rekeys are asymmetric, and only update the outbound
    #       policy.
    self.xfrm.AddSaInfo(self.local, self.remote, new_out_sa.spi,
                        xfrm.XFRM_MODE_TUNNEL, 0, xfrm_base._ALGO_CRYPT_NULL,
                        xfrm_base._ALGO_AUTH_NULL, None, None,
                        xfrm.ExactMatchMark(self.okey), self.underlying_netid)

    self.xfrm.AddSaInfo(self.remote, self.local, new_in_sa.spi,
                        xfrm.XFRM_MODE_TUNNEL, 0, xfrm_base._ALGO_CRYPT_NULL,
                        xfrm_base._ALGO_AUTH_NULL, None, None,
                        xfrm.ExactMatchMark(self.ikey), None)

    # Create new policies for IPv4 and IPv6.
    for sel in [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]:
      # Add SPI-specific output policy to enforce using new outbound SPI
      policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
      tmpl = xfrm.UserTemplate(outer_family, new_out_sa.spi, 0,
                               (self.local, self.remote))
      self.xfrm.UpdatePolicyInfo(policy, tmpl, xfrm.ExactMatchMark(self.okey),
                                 0)

  def DeleteOldSaInfo(self, outer_family, old_in_spi, old_out_spi):
    """Deletes the pre-rekey ESP SAs; outer_family is unused here."""
    self.xfrm.DeleteSaInfo(self.local, old_in_spi, IPPROTO_ESP,
                           xfrm.ExactMatchMark(self.ikey))
    self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP,
                           xfrm.ExactMatchMark(self.okey))
429
430
@unittest.skipUnless(HAVE_XFRM_INTERFACES, "XFRM interfaces unsupported")
class XfrmAddDeleteXfrmInterfaceTest(xfrm_base.XfrmBaseTest):
  """Test the creation of an XFRM Interface."""

  def testAddXfrmInterface(self):
    self.iproute.CreateXfrmInterface(_TEST_XFRM_IFNAME, _TEST_XFRM_IF_ID,
                                     _LOOPBACK_IFINDEX)
    if_index = self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)
    net_test.SetInterfaceUp(_TEST_XFRM_IFNAME)

    # Validate that the netlink interface matches the ioctl interface.
    self.assertEquals(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index)
    self.iproute.DeleteLink(_TEST_XFRM_IFNAME)
    with self.assertRaises(IOError):
      self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)

  def tearDown(self):
    super(XfrmAddDeleteXfrmInterfaceTest, self).tearDown()
    # If the test failed before deleting the interface, remove it so it does
    # not leak into later tests that reuse the same interface name.
    try:
      self.iproute.DeleteLink(_TEST_XFRM_IFNAME)
    except IOError:
      # The link was not present.
      pass
446
447
class XfrmInterface(IpSecBaseInterface):
  """An IPsec tunnel built on an XFRM interface.

  SAs and policies are bound to the interface through its XFRM interface ID,
  which is the tunnel's netid.
  """

  def __init__(self, iface, netid, underlying_netid, ifindex, local, remote,
               version):
    super(XfrmInterface, self).__init__(iface, netid, underlying_netid, local,
                                        remote, version)

    # ifindex of the underlying link the XFRM interface is created on.
    self.ifindex = ifindex
    self.xfrm_if_id = netid

    self.SetupInterface()
    self.SetupXfrm(False)

  def SetupInterface(self):
    """Create an XFRM interface."""
    # Use xfrm_if_id, like the SA/policy code below; it equals netid today.
    return self.iproute.CreateXfrmInterface(self.iface, self.xfrm_if_id,
                                            self.ifindex)

  def _SetupXfrmByType(self, auth_algo, crypt_algo):
    """Creates interface-ID-bound tunnel SAs/policies in both directions."""
    self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
                           self.out_sa.spi, crypt_algo, auth_algo, None,
                           self.underlying_netid, self.xfrm_if_id,
                           xfrm.MATCH_METHOD_ALL)
    self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
                           self.in_sa.spi, crypt_algo, auth_algo, None, None,
                           self.xfrm_if_id, xfrm.MATCH_METHOD_IFID)

  def TeardownXfrm(self):
    """Deletes the outbound and inbound tunnels created by SetupXfrm."""
    self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
                           self.out_sa.spi, None, self.xfrm_if_id)
    self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
                           self.in_sa.spi, None, self.xfrm_if_id)

  def _Rekey(self, outer_family, new_out_sa, new_in_sa):
    """Adds null-crypt SAs for the new SPIs and updates outbound policies."""
    # TODO: Consider ways to share code with xfrm.CreateTunnel(). It's mostly
    #       the same, but rekeys are asymmetric, and only update the outbound
    #       policy.
    self.xfrm.AddSaInfo(
        self.local, self.remote, new_out_sa.spi, xfrm.XFRM_MODE_TUNNEL, 0,
        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None, None,
        None, self.underlying_netid, xfrm_if_id=self.xfrm_if_id)

    self.xfrm.AddSaInfo(
        self.remote, self.local, new_in_sa.spi, xfrm.XFRM_MODE_TUNNEL, 0,
        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None, None,
        None, None, xfrm_if_id=self.xfrm_if_id)

    # Create new policies for IPv4 and IPv6.
    for sel in [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]:
      # Add SPI-specific output policy to enforce using new outbound SPI
      policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
      tmpl = xfrm.UserTemplate(outer_family, new_out_sa.spi, 0,
                               (self.local, self.remote))
      self.xfrm.UpdatePolicyInfo(policy, tmpl, None, self.xfrm_if_id)

  def DeleteOldSaInfo(self, outer_family, old_in_spi, old_out_spi):
    """Deletes the pre-rekey ESP SAs; outer_family is unused here."""
    self.xfrm.DeleteSaInfo(self.local, old_in_spi, IPPROTO_ESP, None,
                           self.xfrm_if_id)
    self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP, None,
                           self.xfrm_if_id)
507
508
509class XfrmTunnelBase(xfrm_base.XfrmBaseTest):
510
511  @classmethod
512  def setUpClass(cls):
513    xfrm_base.XfrmBaseTest.setUpClass()
514    # Tunnel interfaces use marks extensively, so configure realistic packet
515    # marking rules to make the test representative, make PMTUD work, etc.
516    cls.SetInboundMarks(True)
517    cls.SetMarkReflectSysctls(1)
518
519    # Group by tunnel version to ensure that we test at least one IPv4 and one
520    # IPv6 tunnel
521    cls.tunnelsV4 = {}
522    cls.tunnelsV6 = {}
523    for i, underlying_netid in enumerate(cls.tuns):
524      for version in 4, 6:
525        netid = _BASE_TUNNEL_NETID[version] + _TUNNEL_NETID_OFFSET + i
526        iface = "ipsec%s" % netid
527        local = cls.MyAddress(version, underlying_netid)
528        if version == 4:
529          remote = (net_test.IPV4_ADDR if (i % 2) else net_test.IPV4_ADDR2)
530        else:
531          remote = (net_test.IPV6_ADDR if (i % 2) else net_test.IPV6_ADDR2)
532
533        ifindex = cls.ifindices[underlying_netid]
534        tunnel = cls.INTERFACE_CLASS(iface, netid, underlying_netid, ifindex,
535                                   local, remote, version)
536        cls._SetInboundMarking(netid, iface, True)
537        cls._SetupTunnelNetwork(tunnel, True)
538
539        if version == 4:
540          cls.tunnelsV4[netid] = tunnel
541        else:
542          cls.tunnelsV6[netid] = tunnel
543
544  @classmethod
545  def tearDownClass(cls):
546    # The sysctls are restored by MultinetworkBaseTest.tearDownClass.
547    cls.SetInboundMarks(False)
548    for tunnel in cls.tunnelsV4.values() + cls.tunnelsV6.values():
549      cls._SetInboundMarking(tunnel.netid, tunnel.iface, False)
550      cls._SetupTunnelNetwork(tunnel, False)
551      tunnel.Teardown()
552    xfrm_base.XfrmBaseTest.tearDownClass()
553
554  def randomTunnel(self, outer_version):
555    version_dict = self.tunnelsV4 if outer_version == 4 else self.tunnelsV6
556    return random.choice(version_dict.values())
557
558  def setUp(self):
559    multinetwork_base.MultiNetworkBaseTest.setUp(self)
560    self.iproute = iproute.IPRoute()
561    self.xfrm = xfrm.Xfrm()
562
563  def tearDown(self):
564    multinetwork_base.MultiNetworkBaseTest.tearDown(self)
565
566  def _SwapInterfaceAddress(self, ifname, old_addr, new_addr):
567    """Exchange two addresses on a given interface.
568
569    Args:
570      ifname: Name of the interface
571      old_addr: An address to be removed from the interface
572      new_addr: An address to be added to an interface
573    """
574    version = 6 if ":" in new_addr else 4
575    ifindex = net_test.GetInterfaceIndex(ifname)
576    self.iproute.AddAddress(new_addr,
577                            net_test.AddressLengthBits(version), ifindex)
578    self.iproute.DelAddress(old_addr,
579                            net_test.AddressLengthBits(version), ifindex)
580
581  @classmethod
582  def _GetLocalAddress(cls, version, netid):
583    if version == 4:
584      return cls._MyIPv4Address(netid - _TUNNEL_NETID_OFFSET)
585    else:
586      return cls.OnlinkPrefix(6, netid - _TUNNEL_NETID_OFFSET) + "1"
587
588  @classmethod
589  def _SetupTunnelNetwork(cls, tunnel, is_add):
590    """Setup rules and routes for a tunnel Network.
591
592    Takes an interface and depending on the boolean
593    value of is_add, either adds or removes the rules
594    and routes for a tunnel interface to behave like an
595    Android Network for purposes of testing.
596
597    Args:
598      tunnel: A VtiInterface or XfrmInterface, the tunnel to set up.
599      is_add: Boolean that causes this method to perform setup if True or
600        teardown if False
601    """
602    if is_add:
603      # Disable router solicitations to avoid occasional spurious packets
604      # arriving on the underlying network; there are two possible behaviors
605      # when that occurred: either only the RA packet is read, and when it
606      # is echoed back to the tunnel, it causes the test to fail by not
607      # receiving # the UDP_PAYLOAD; or, two packets may arrive on the
608      # underlying # network which fails the assertion that only one ESP packet
609      # is received.
610      cls.SetSysctl(
611          "/proc/sys/net/ipv6/conf/%s/router_solicitations" % tunnel.iface, 0)
612      net_test.SetInterfaceUp(tunnel.iface)
613
614    for version in [4, 6]:
615      ifindex = net_test.GetInterfaceIndex(tunnel.iface)
616      table = tunnel.netid
617
618      # Set up routing rules.
619      start, end = cls.UidRangeForNetid(tunnel.netid)
620      cls.iproute.UidRangeRule(version, is_add, start, end, table,
621                                cls.PRIORITY_UID)
622      cls.iproute.OifRule(version, is_add, tunnel.iface, table, cls.PRIORITY_OIF)
623      cls.iproute.FwmarkRule(version, is_add, tunnel.netid, cls.NETID_FWMASK,
624                              table, cls.PRIORITY_FWMARK)
625
626      # Configure IP addresses.
627      addr = cls._GetLocalAddress(version, tunnel.netid)
628      prefixlen = net_test.AddressLengthBits(version)
629      tunnel.addrs[version] = addr
630      if is_add:
631        cls.iproute.AddAddress(addr, prefixlen, ifindex)
632        cls.iproute.AddRoute(version, table, "default", 0, None, ifindex)
633      else:
634        cls.iproute.DelRoute(version, table, "default", 0, None, ifindex)
635        cls.iproute.DelAddress(addr, prefixlen, ifindex)
636
637  def assertReceivedPacket(self, tunnel, sa_info):
638    tunnel.rx += 1
639    self.assertEquals((tunnel.rx, tunnel.tx),
640                      self.iproute.GetRxTxPackets(tunnel.iface))
641    sa_info.seq_num += 1
642
643  def assertSentPacket(self, tunnel, sa_info):
644    tunnel.tx += 1
645    self.assertEquals((tunnel.rx, tunnel.tx),
646                      self.iproute.GetRxTxPackets(tunnel.iface))
647    sa_info.seq_num += 1
648
649  def _CheckTunnelInput(self, tunnel, inner_version, local_inner, remote_inner,
650                        sa_info=None, expect_fail=False):
651    """Test null-crypt input path over an IPsec interface."""
652    if sa_info is None:
653      sa_info = tunnel.in_sa
654    read_sock, local_port = _CreateReceiveSock(inner_version)
655
656    input_pkt = _GetNullAuthCryptTunnelModePkt(
657        inner_version, remote_inner, tunnel.remote, _TEST_REMOTE_PORT,
658        local_inner, tunnel.local, local_port, sa_info.spi, sa_info.seq_num)
659    self.ReceivePacketOn(tunnel.underlying_netid, input_pkt)
660
661    if expect_fail:
662      self.assertRaisesErrno(EAGAIN, read_sock.recv, 4096)
663    else:
664      # Verify that the packet data and src are correct
665      data, src = read_sock.recvfrom(4096)
666      self.assertReceivedPacket(tunnel, sa_info)
667      self.assertEquals(net_test.UDP_PAYLOAD, data)
668      self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2])
669
670  def _CheckTunnelOutput(self, tunnel, inner_version, local_inner,
671                         remote_inner, sa_info=None):
672    """Test null-crypt output path over an IPsec interface."""
673    if sa_info is None:
674      sa_info = tunnel.out_sa
675    local_port = _SendPacket(self, tunnel.netid, inner_version, remote_inner,
676                             _TEST_REMOTE_PORT)
677
678    # Read a tunneled IP packet on the underlying (outbound) network
679    # verifying that it is an ESP packet.
680    pkt = self._ExpectEspPacketOn(tunnel.underlying_netid, sa_info.spi,
681                                  sa_info.seq_num, None, tunnel.local,
682                                  tunnel.remote)
683
684    # Get and update the IP headers on the inner payload so that we can do a simple
685    # comparison of byte data. Unfortunately, due to the scapy version this runs on,
686    # we cannot parse past the ESP header to the inner IP header, and thus have to
687    # workaround in this manner
688    if inner_version == 4:
689      ip_hdr_options = {
690        'id': scapy.IP(str(pkt.payload)[8:]).id,
691        'flags': scapy.IP(str(pkt.payload)[8:]).flags
692      }
693    else:
694      ip_hdr_options = {'fl': scapy.IPv6(str(pkt.payload)[8:]).fl}
695
696    expected = _GetNullAuthCryptTunnelModePkt(
697        inner_version, local_inner, tunnel.local, local_port, remote_inner,
698        tunnel.remote, _TEST_REMOTE_PORT, sa_info.spi, sa_info.seq_num,
699        ip_hdr_options)
700
701    # Check outer header manually (Avoids having to overwrite outer header's
702    # id, flags or flow label)
703    self.assertSentPacket(tunnel, sa_info)
704    self.assertEquals(expected.src, pkt.src)
705    self.assertEquals(expected.dst, pkt.dst)
706    self.assertEquals(len(expected), len(pkt))
707
708    # Check everything else
709    self.assertEquals(str(expected.payload), str(pkt.payload))
710
711  def _CheckTunnelEncryption(self, tunnel, inner_version, local_inner,
712                             remote_inner):
713    """Test both input and output paths over an encrypted IPsec interface.
714
715    This tests specifically makes sure that the both encryption and decryption
716    work together, as opposed to the _CheckTunnel(Input|Output) where the
717    input and output paths are tested separately, and using null encryption.
718    """
719    src_port = _SendPacket(self, tunnel.netid, inner_version, remote_inner,
720                           _TEST_REMOTE_PORT)
721
722    # Make sure it appeared on the underlying interface
723    pkt = self._ExpectEspPacketOn(tunnel.underlying_netid, tunnel.out_sa.spi,
724                                  tunnel.out_sa.seq_num, None, tunnel.local,
725                                  tunnel.remote)
726
727    # Check that packet is not sent in plaintext
728    self.assertTrue(str(net_test.UDP_PAYLOAD) not in str(pkt))
729
730    # Check src/dst
731    self.assertEquals(tunnel.local, pkt.src)
732    self.assertEquals(tunnel.remote, pkt.dst)
733
734    # Check that the interface statistics recorded the outbound packet
735    self.assertSentPacket(tunnel, tunnel.out_sa)
736
737    try:
738      # Swap the interface addresses to pretend we are the remote
739      self._SwapInterfaceAddress(
740          tunnel.iface, new_addr=remote_inner, old_addr=local_inner)
741
742      # Swap the packet's IP headers and write it back to the underlying
743      # network.
744      pkt = TunTwister.TwistPacket(pkt)
745      read_sock, local_port = _CreateReceiveSock(inner_version,
746                                                 _TEST_REMOTE_PORT)
747      self.ReceivePacketOn(tunnel.underlying_netid, pkt)
748
749      # Verify that the packet data and src are correct
750      data, src = read_sock.recvfrom(4096)
751      self.assertEquals(net_test.UDP_PAYLOAD, data)
752      self.assertEquals((local_inner, src_port), src[:2])
753
754      # Check that the interface statistics recorded the inbound packet
755      self.assertReceivedPacket(tunnel, tunnel.in_sa)
756    finally:
757      # Swap the interface addresses to pretend we are the remote
758      self._SwapInterfaceAddress(
759          tunnel.iface, new_addr=local_inner, old_addr=remote_inner)
760
761  def _CheckTunnelIcmp(self, tunnel, inner_version, local_inner, remote_inner,
762                       sa_info=None):
763    """Test ICMP error path over an IPsec interface."""
764    if sa_info is None:
765      sa_info = tunnel.out_sa
766    # Now attempt to provoke an ICMP error.
767    # TODO: deduplicate with multinetwork_test.py.
768    dst_prefix, intermediate = {
769        4: ("172.19.", "172.16.9.12"),
770        6: ("2001:db8::", "2001:db8::1")
771    }[tunnel.version]
772
773    local_port = _SendPacket(self, tunnel.netid, inner_version, remote_inner,
774                             _TEST_REMOTE_PORT)
775    pkt = self._ExpectEspPacketOn(tunnel.underlying_netid, sa_info.spi,
776                                  sa_info.seq_num, None, tunnel.local,
777                                  tunnel.remote)
778    self.assertSentPacket(tunnel, sa_info)
779
780    myaddr = self.MyAddress(tunnel.version, tunnel.underlying_netid)
781    _, toobig = packets.ICMPPacketTooBig(tunnel.version, intermediate, myaddr,
782                                         pkt)
783    self.ReceivePacketOn(tunnel.underlying_netid, toobig)
784
785    # Check that the packet too big reduced the MTU.
786    routes = self.iproute.GetRoutes(tunnel.remote, 0, tunnel.underlying_netid, None)
787    self.assertEquals(1, len(routes))
788    rtmsg, attributes = routes[0]
789    self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
790    self.assertEquals(packets.PTB_MTU, attributes["RTA_METRICS"]["RTAX_MTU"])
791
792    # Clear PMTU information so that future tests don't have to worry about it.
793    self.InvalidateDstCache(tunnel.version, tunnel.underlying_netid)
794
795  def _CheckTunnelEncryptionWithIcmp(self, tunnel, inner_version, local_inner,
796                                     remote_inner):
797    """Test combined encryption path with ICMP errors over an IPsec tunnel"""
798    self._CheckTunnelEncryption(tunnel, inner_version, local_inner,
799                                remote_inner)
800    self._CheckTunnelIcmp(tunnel, inner_version, local_inner, remote_inner)
801    self._CheckTunnelEncryption(tunnel, inner_version, local_inner,
802                                remote_inner)
803
804  def _TestTunnel(self, inner_version, outer_version, func, use_null_crypt):
805    """Bootstrap method to setup and run tests for the given parameters."""
806    tunnel = self.randomTunnel(outer_version)
807
808    try:
809      # Some tests require that the out_seq_num and in_seq_num are the same
810      # (Specifically encrypted tests), rebuild SAs to ensure seq_num is 1
811      #
812      # Until we get better scapy support, the only way we can build an
813      # encrypted packet is to send it out, and read the packet from the wire.
814      # We then generally use this as the "inbound" encrypted packet, injecting
815      # it into the interface for which it is expected on.
816      #
817      # As such, this is required to ensure that encrypted packets (which we
818      # currently have no way to easily modify) are not considered replay
819      # attacks by the inbound SA.  (eg: received 3 packets, seq_num_in = 3,
820      # sent only 1, # seq_num_out = 1, inbound SA would consider this a replay
821      # attack)
822      tunnel.TeardownXfrm()
823      tunnel.SetupXfrm(use_null_crypt)
824
825      local_inner = tunnel.addrs[inner_version]
826      remote_inner = _GetRemoteInnerAddress(inner_version)
827
828      for i in range(2):
829        func(tunnel, inner_version, local_inner, remote_inner)
830    finally:
831      if use_null_crypt:
832        tunnel.TeardownXfrm()
833        tunnel.SetupXfrm(False)
834
835  def _CheckTunnelRekey(self, tunnel, inner_version, local_inner, remote_inner):
836    old_out_sa = tunnel.out_sa
837    old_in_sa = tunnel.in_sa
838
839    # Check to make sure that both directions work before rekey
840    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
841                           old_in_sa)
842    self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
843                            old_out_sa)
844
845    # Rekey
846    outer_family = net_test.GetAddressFamily(tunnel.version)
847
848    # Create new SA
849    # Distinguish the new SAs with new SPIs.
850    new_out_sa = SaInfo(old_out_sa.spi + 1)
851    new_in_sa = SaInfo(old_in_sa.spi + 1)
852
853    # Perform Rekey
854    tunnel.Rekey(outer_family, new_out_sa, new_in_sa)
855
856    # Expect that the old SPI still works for inbound packets
857    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
858                           old_in_sa)
859
860    # Test both paths with new SPIs, expect outbound to use new SPI
861    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
862                           new_in_sa)
863    self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
864                            new_out_sa)
865
866    # Delete old SAs
867    tunnel.DeleteOldSaInfo(outer_family, old_in_sa.spi, old_out_sa.spi)
868
869    # Test both paths with new SPIs; should still work
870    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
871                           new_in_sa)
872    self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
873                            new_out_sa)
874
875    # Expect failure upon trying to receive a packet with the deleted SPI
876    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
877                           old_in_sa, True)
878
879  def _TestTunnelRekey(self, inner_version, outer_version):
880    """Test packet input and output over a Virtual Tunnel Interface."""
881    tunnel = self.randomTunnel(outer_version)
882
883    try:
884      # Always use null_crypt, so we can check input and output separately
885      tunnel.TeardownXfrm()
886      tunnel.SetupXfrm(True)
887
888      local_inner = tunnel.addrs[inner_version]
889      remote_inner = _GetRemoteInnerAddress(inner_version)
890
891      self._CheckTunnelRekey(tunnel, inner_version, local_inner, remote_inner)
892    finally:
893      tunnel.TeardownXfrm()
894      tunnel.SetupXfrm(False)
895
896
@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
class XfrmVtiTest(XfrmTunnelBase):
  """Tunnel tests run over VTI (Virtual Tunnel Interface) devices.

  Skipped on kernels older than 3.18. Each ParamTest* method takes the inner
  and outer IP versions and delegates to a shared XfrmTunnelBase helper.
  """

  INTERFACE_CLASS = VtiInterface

  def ParamTestVtiInput(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelInput, use_null_crypt=True)

  def ParamTestVtiOutput(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelOutput, use_null_crypt=True)

  def ParamTestVtiInOutEncrypted(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelEncryption, use_null_crypt=False)

  def ParamTestVtiIcmp(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelIcmp, use_null_crypt=False)

  def ParamTestVtiEncryptionWithIcmp(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelEncryptionWithIcmp,
                     use_null_crypt=False)

  def ParamTestVtiRekey(self, inner_version, outer_version):
    self._TestTunnelRekey(inner_version, outer_version)
922
923
@unittest.skipUnless(HAVE_XFRM_INTERFACES, "XFRM interfaces unsupported")
class XfrmInterfaceTest(XfrmTunnelBase):
  """Tunnel tests run over XFRM interfaces.

  Skipped when HAVE_XFRM_INTERFACES is false. Each ParamTest* method takes
  the inner and outer IP versions and delegates to a shared XfrmTunnelBase
  helper.
  """

  INTERFACE_CLASS = XfrmInterface

  def ParamTestXfrmIntfInput(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelInput, use_null_crypt=True)

  def ParamTestXfrmIntfOutput(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelOutput, use_null_crypt=True)

  def ParamTestXfrmIntfInOutEncrypted(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelEncryption, use_null_crypt=False)

  def ParamTestXfrmIntfIcmp(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelIcmp, use_null_crypt=False)

  def ParamTestXfrmIntfEncryptionWithIcmp(self, inner_version, outer_version):
    self._TestTunnel(inner_version, outer_version,
                     func=self._CheckTunnelEncryptionWithIcmp,
                     use_null_crypt=False)

  def ParamTestXfrmIntfRekey(self, inner_version, outer_version):
    self._TestTunnelRekey(inner_version, outer_version)
949
950
if __name__ == "__main__":
  # NOTE(review): InjectTests is defined elsewhere in this file; presumably it
  # expands the ParamTest* methods above into concrete per-version test
  # methods, so it must run before unittest.main() collects tests -- confirm.
  InjectTests()
  unittest.main()
954