• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Base module for multinetwork tests."""
18
19import errno
20import fcntl
21import os
22import posix
23import random
24import re
25from socket import *  # pylint: disable=wildcard-import
26import struct
27import time
28
29from scapy import all as scapy
30
31import csocket
32import iproute
33import net_test
34
35
# Tun/tap interface flags and the ioctl used to apply them (see
# CreateTunInterface).
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454CA

# Socket option that binds a socket to a named network device.
SO_BINDTODEVICE = 25

# Option numbers passed to setsockopt().
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Ancillary-data (cmsg) type numbers.
IP_TTL = 2
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
# Unlike IPV6_UNICAST_HOPS, IPV6_HOPLIMIT is only valid as a cmsg.
IPV6_HOPLIMIT = 52


# Sysctl files read and written by the tests.
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"

# True if the running kernel exposes the sysctl that places RA-learned routes
# into per-interface routing tables.
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
60
61
class ConfigurationError(AssertionError):
  """Raised when a command that configures the test environment fails."""
64
65
class UnexpectedPacketError(AssertionError):
  """Raised when no received packet matches the packet a test expected."""
68
69
def MakePktInfo(version, addr, ifindex):
  """Builds the packed payload of a PKTINFO control message.

  SendOnNetid uses this as the data of an IP_PKTINFO (IPv4) or IPV6_PKTINFO
  (IPv6) cmsg.

  Args:
    version: The IP version, 4 or 6.
    addr: The address as a string, or a false value (e.g. None) to use the
      unspecified address ("0.0.0.0" or "::") for the given version.
    ifindex: The interface index to store in the structure.

  Returns:
    The packed structure, as produced by csocket's Pack().

  Raises:
    KeyError: if version is not 4 or 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  # After the defaulting above addr is always a non-empty string, so it can be
  # converted unconditionally.
  packed_addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((packed_addr, ifindex)).Pack()
  # The IPv4 structure takes the address in its second field and zeros in its
  # third.
  return csocket.InPktinfo((ifindex, packed_addr, b"\x00" * 4)).Pack()
80
81
class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.

  Throughout this class, "version" is an IP version: 4 or 6, or (where a method
  accepts it) 5. Version 5 means IPv4 traffic on an IPv6 socket, which uses
  IPv4-mapped ("::ffff:") addresses.
  """

  # Must be between 1 and 255. Each netid is formatted into a single octet of
  # MAC and IPv4 addresses, and into the router's link-local IID.
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes. Only the first
  # value read for each sysctl is kept (see SetSysctl).
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges.
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 300  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range that is routed to netid."""
    return (
        cls.UID_RANGE_SIZE * netid,
        cls.UID_RANGE_SIZE * (netid + 1) - 1
    )

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID routed to netid, or 0 if netid is falsy."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    When the kernel puts RA routes into per-interface tables
    (AUTOCONF_TABLE_OFFSET is set), the table is the interface index minus the
    (negative) offset. Otherwise the netid itself is the table number.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    return netid

  @staticmethod
  def GetInterfaceName(netid):
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's address on netid (link-local for IPv6)."""
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns the non-link-local IPv6 address configured on netid's iface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our address on netid. Version 5 returns the IPv4 address.

    Raises:
      KeyError: if version is not 4, 5 or 6.
    """
    # Dispatch lazily: building a dict of all results would query the IPv6
    # interface address even for IPv4 lookups.
    if version in (4, 5):
      return cls._MyIPv4Address(netid)
    if version == 6:
      return cls._MyIPv6Address(netid)
    raise KeyError(version)

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but returns an IPv4-mapped address for version 5."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Appends two random components to prefix.

    If prefix contains a ".", the components are decimal octets. Otherwise
    they are 16-bit hex groups.
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates and configures the tap interface for netid.

    Returns:
      The open, non-blocking file object of the tap device.
    """
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b", buffering=0)
    except IOError:
      f = open("/dev/tun", "r+b", buffering=0)
    try:
      ifr = struct.pack("16sH", iface.encode(), IFF_TAP | IFF_NO_PI)
      ifr += b"\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
      # Set accept_ra to 2, because that's what we use.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except BaseException:
      # Don't leak the device file descriptor if configuration fails.
      f.close()
      raise
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0, options=()):
    """Injects a router advertisement into netid's tap interface.

    The RA carries a source link-layer address option and a prefix information
    option for netid's onlink prefix (L=1, A=1). Any scapy layers in options
    are appended after those.
    """
    validity = cls.RA_VALIDITY  # seconds
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = validity * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlifetime) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1,
                                      validlifetime=validity,
                                      preferredlifetime=validity))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), bytes(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds (is_add=True) or deletes the rules and routing state for netid."""
    for version in [4, 6]:
      # Find out how to configure things.
      iface = cls.GetInterfaceName(netid)
      ifindex = cls.ifindices[netid]
      macaddr = cls.RouterMacAddress(netid)
      router = cls._RouterAddress(netid, version)
      table = cls._TableForNetid(netid)

      # Set up routing rules.
      start, end = cls.UidRangeForNetid(netid)
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or removes mangle-table rules marking packets received on iface.

    Packets arriving on iface get fwmark netid, for both IPv4 and IPv6.

    Raises:
      ConfigurationError: if an iptables command reports failure.
    """
    add_del = "-A" if is_add else "-D"
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      if net_test.RunIptablesCommand(version, args):
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    """Adds or removes inbound marking rules for every tun interface."""
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Makes netid the default network, or removes the default if netid is falsy."""
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents of the given sysctl file."""
    with open(sysctl, "r") as sysctl_file:
      return sysctl_file.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to sysctl, saving the original for _RestoreSysctls."""
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as sysctl_file:
      sysctl_file.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    """Sets /proc/sys/net/ipv6/conf/<iface>/<sysctl> on every tun interface."""
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back every saved sysctl value, on a best-effort basis."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as sysctl_file:
          sysctl_file.write(value)
      except IOError:
        # Deliberately ignored. Some saved sysctls may no longer be writable,
        # e.g. presumably per-interface ones whose tun interface is gone.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    os.environ.pop("ANDROID_NO_USE_FWMARK_CLIENT", None)

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    # Restore global state (sysctls, console log level) even if removing the
    # per-network configuration fails partway through.
    try:
      for netid, tun in cls.tuns.items():
        try:
          cls._RunSetupCommands(netid, False)
        finally:
          tun.close()
    finally:
      cls.iproute.close()
      cls._RestoreSysctls()
      cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on s to netid (0 if netid is None)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Binds s to iface, or unbinds it if iface is falsy."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface.encode())

  def SetUnicastInterface(self, s, ifindex):
    """Sets the unicast output interface of s to ifindex."""
    # Pack as a 4-byte big-endian integer. Otherwise, Python thinks it's a
    # 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Routes s to netid using mode "uid", "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: if mode is not one of the above.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket and routes it to netid.

    Args:
      version: 4, 5 or 6. Version 5 creates an IPv6 socket.
      constructor: Callable taking an address family and returning a socket.
      netid: The netid to select.
      routing_mode: None to leave routing alone, or a mode accepted by
        SelectInterface.

    Returns:
      The new socket.
    """
    if version == 5:
      version = 6
    s = constructor(self.GetProtocolFamily(version))
    # SelectInterface handles every mode, including "uid".
    if routing_mode is not None:
      self.SelectInterface(s, netid, routing_mode)
    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids

    Args:
      exclude: a netid or list of netids that should not be chosen
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on s with sendmsg.

    If netid is not None, a PKTINFO cmsg selecting netid's interface is sent in
    addition to cmsgs. The caller's cmsgs list is not modified.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      # Copy rather than append, so that reusing a list across calls doesn't
      # accumulate PKTINFO entries.
      cmsgs = list(cmsgs) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Injects an Ethernet frame into netid's tap interface."""
    posix.write(self.tuns[netid].fileno(), bytes(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Injects ip_packet into netid's tap interface, as sent by the router."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Args:
      netid: The netid from which to read packets
      include_multicast: A boolean, whether to also return multicast packets.
        If False (the default), multicast frames are dropped.

    Returns:
      A list of the Ethernet payloads (scapy layers) of the frames read.
    """
    packets = []
    retries = 0
    max_retries = 1
    while True:
      try:
        packet = posix.read(self.tuns[netid].fileno(), 4096)
      except OSError as e:
        # EAGAIN means there are no more packets waiting. Anything else is
        # unexpected.
        if e.errno != errno.EAGAIN:
          raise
        # If we didn't see any packets, try again for good luck.
        if not packets and retries < max_retries:
          time.sleep(0.01)
          retries += 1
          continue
        break
      if not packet:
        break
      ether = scapy.Ether(packet)
      # Multicast frames are frames where the first byte of the destination
      # MAC address has 1 in the least-significant bit.
      if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
        packets.append(ether.payload)
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.
    """
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    """Asserts that actual matches the rough sketch in expected.

    Raises:
      AssertionError: if the packets differ after normalization.
    """
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(bytes(expected))
    actual_real = actual.__class__(bytes(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns True if assertPacketMatches(expected, actual) passes."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Asserts that no (non-multicast) packets are queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Returns the first queued packet on netid that matches expected.

    Raises:
      AssertionError: if no packets are queued.
      UnexpectedPacketError: if packets are queued but none matches.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    if ipv6 and ipv6.dst.startswith("ff"):
      include_multicast = True
    else:
      include_multicast = False

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e))) from e

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples: packets to the
      address of each interface (ip_if, myaddr) arriving on every interface
      (netid, iif), from a random remote address chosen per ip_if.
    """
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid and checks for reply (or for no packets).

    Returns:
      The matching reply packet, or None if reply is falsy.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None
758
759
class InboundMarkingTest(MultiNetworkBaseTest):
  """Multinetwork test base that also marks inbound packets with their netid.

  After the base class environment is set up, iptables mangle rules are added
  so that packets received on each tun interface carry that interface's netid
  as their mark. The rules are removed again before the base class teardown.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    cls.SetInboundMarks(True)

  @classmethod
  def tearDownClass(cls):
    # Drop the marking rules while the interfaces still exist.
    cls.SetInboundMarks(False)
    super().tearDownClass()
772