1#!/usr/bin/python 2# 3# Copyright 2014 The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Base module for multinetwork tests.""" 18 19import errno 20import fcntl 21import os 22import posix 23import random 24import re 25from socket import * # pylint: disable=wildcard-import 26import struct 27import time 28 29from scapy import all as scapy 30 31import csocket 32import iproute 33import net_test 34 35 36IFF_TUN = 1 37IFF_TAP = 2 38IFF_NO_PI = 0x1000 39TUNSETIFF = 0x400454ca 40 41SO_BINDTODEVICE = 25 42 43# Setsockopt values. 44IP_UNICAST_IF = 50 45IPV6_MULTICAST_IF = 17 46IPV6_UNICAST_IF = 76 47 48# Cmsg values. 49IP_TTL = 2 50IPV6_2292PKTOPTIONS = 6 51IPV6_FLOWINFO = 11 52IPV6_HOPLIMIT = 52 # Different from IPV6_UNICAST_HOPS, this is cmsg only. 
# Sysctls that the multinetwork environment reads or modifies.
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"

# Whether the kernel can put RA routes into per-interface routing tables,
# detected by the presence of the accept_ra_rt_table sysctl.
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)


class ConfigurationError(AssertionError):
  """Raised when a command that configures the test environment fails."""
  pass


class UnexpectedPacketError(AssertionError):
  """Raised when none of the received packets matches the expected one."""
  pass


def MakePktInfo(version, addr, ifindex):
  """Builds a packed pktinfo structure for use as an IP(V6)_PKTINFO cmsg.

  Args:
    version: The IP version, 4 or 6.
    addr: The address to store, as a string. If empty or None, the unspecified
      address of the family ("0.0.0.0" or "::") is used.
    ifindex: The interface index to store.

  Returns:
    The packed csocket.In6Pktinfo (version 6) or csocket.InPktinfo (version 4).
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((addr, ifindex)).Pack()
  else:
    return csocket.InPktinfo((ifindex, addr, "\x00" * 4)).Pack()


class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since we put them in MAC addresses, IIDs and
  # IPv4 addresses (one byte each).
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes.
  # NOTE: this dict is a class attribute shared by all subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges.
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 300  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range routed to netid."""
    return (
        cls.UID_RANGE_SIZE * netid,
        cls.UID_RANGE_SIZE * (netid + 1) - 1
    )

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID in netid's UID range, or 0 if netid is falsy."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    When AUTOCONF_TABLE_OFFSET is set (setUpClass sets it to -1000 together
    with the accept_ra_rt_table sysctl), the table is the interface index minus
    that offset; presumably this matches the table the kernel chooses for RA
    routes. Otherwise the netid itself is the table number.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] + (-cls.AUTOCONF_TABLE_OFFSET)
    else:
      return netid

  @staticmethod
  def GetInterfaceName(netid):
    """Returns the name of the tun interface for netid."""
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    """Returns the MAC address of the simulated router on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    """Returns the MAC address we assign to our own interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's address on netid (link-local for IPv6).

    Raises:
      ValueError: version is neither 4 nor 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    """Returns the IPv4 address we configure on netid's interface."""
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns the non-link-local IPv6 address of netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our address on netid.

    Only the address for the requested version is looked up, so IPv4 queries
    do not depend on the interface's IPv6 address being available.

    Args:
      version: 4, 6, or 5 (IPv4 used via a dual-stack socket). Version 5
        returns the plain IPv4 address; see MySocketAddress for the mapped
        form.
      netid: The netid whose address to return.

    Raises:
      KeyError: version is not 4, 5 or 6.
    """
    if version in (4, 5):
      return cls._MyIPv4Address(netid)
    if version == 6:
      return cls._MyIPv6Address(netid)
    raise KeyError(version)

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but version 5 returns the ::ffff:-mapped form."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    """Returns the link-local IPv6 address of netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    """Returns the prefix length of the directly-connected subnet."""
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    """Returns the directly-connected subnet prefix for netid."""
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Appends two random components to prefix to form an address.

    If prefix contains a ".", two random decimal octets are appended;
    otherwise two random 16-bit hex groups are appended.
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    """Returns AF_INET or AF_INET6 for IP version 4 or 6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates and brings up the tap interface for netid.

    The interface gets a predictable MAC address, DAD disabled and
    accept_ra set to 2, and its file descriptor is made non-blocking.

    Returns:
      The open file object for the tap device.
    """
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b")
    except IOError:
      f = open("/dev/tun", "r+b")
    ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
    ifr += "\x00" * (40 - len(ifr))
    fcntl.ioctl(f, TUNSETIFF, ifr)
    # Give ourselves a predictable MAC address.
    net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
    # Disable DAD so we don't have to wait for it.
    cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
    # Set accept_ra to 2, because that's what we use.
    cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
    net_test.SetInterfaceUp(iface)
    net_test.SetNonBlocking(f)
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0, options=()):
    """Injects a router advertisement from netid's router into its tun.

    Args:
      netid: The netid whose interface receives the RA.
      retranstimer: Retrans timer in ms. Defaults to the RA validity.
      reachabletime: Reachable time to advertise.
      options: Extra scapy layers appended to the RA.
    """
    validity = cls.RA_VALIDITY  # seconds
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = validity * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlifetime) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1,
                                      validlifetime=validity,
                                      preferredlifetime=validity))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), str(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds (is_add=True) or deletes the rules, routes and addresses for netid."""
    # These do not depend on the IP version.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)
    start, end = cls.UidRangeForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or removes mangle INPUT rules marking packets from iface with netid.

    Raises:
      ConfigurationError: an iptables/ip6tables command failed.
    """
    add_del = "-A" if is_add else "-D"
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      # Run iptables to set up incoming packet marking.
      if net_test.RunIptablesCommand(version, args):
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    """Adds or removes inbound marking rules on every tun interface."""
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Installs a default rule pointing at netid's table, or removes it if None."""
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    """Removes the default-network rule."""
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents of the sysctl file at path sysctl."""
    with open(sysctl, "r") as f:
      return f.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to the sysctl file, saving its first-seen value.

    The file is closed explicitly so that write errors surface as exceptions
    instead of being lost when the file object is garbage-collected.
    """
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as f:
      f.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    """Sets the per-interface IPv6 sysctl on every tun interface."""
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back every saved sysctl value, ignoring write failures."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as f:
          f.write(value)
      except IOError:
        # Best effort. Presumably some per-interface sysctls no longer exist
        # because tearDownClass closes the tun interfaces first.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    """Returns the path of the ICMP rate limit sysctl for the IP version."""
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    """Sets the ICMP rate limit sysctl; restored by _RestoreSysctls."""
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    os.environ.pop("ANDROID_NO_USE_FWMARK_CLIENT", None)

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    for netid in cls.tuns:
      cls._RunSetupCommands(netid, False)
      cls.tuns[netid].close()

    cls._RestoreSysctls()
    cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on s to netid (0 if netid is None)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    """Returns the SO_MARK value of s."""
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    """Sets the SO_MARK of s to 0."""
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Sets SO_BINDTODEVICE on s; a falsy iface passes an empty name."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)

  def SetUnicastInterface(self, s, ifindex):
    """Sets IP_UNICAST_IF (and IPV6_UNICAST_IF on IPv6 sockets) to ifindex."""
    # Otherwise, Python thinks it's a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    """Returns the remote test address; version 5 returns the IPv4 address."""
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    """Like GetRemoteAddress, but version 5 returns the ::ffff:-mapped form."""
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    """Returns the second remote test address, mapped for version 5."""
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Routes socket s via netid using the given selection mode.

    Args:
      s: The socket.
      netid: The netid to select (falsy for "none" where supported).
      mode: One of "uid", "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: mode is not recognized.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket with constructor and applies routing_mode.

    Version 5 creates an IPv6 socket (for use with IPv4-mapped addresses).
    A routing_mode of None leaves the socket's routing untouched.
    """
    if version == 5: version = 6
    s = constructor(self.GetProtocolFamily(version))

    if routing_mode not in [None, "uid"]:
      self.SelectInterface(s, netid, routing_mode)
    elif routing_mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)

    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids

    Args:
      exclude: a netid or list of netids that should not be chosen
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on s, forcing the egress interface via pktinfo if netid.

    The caller's cmsgs list is not modified; a pktinfo cmsg is appended to a
    copy when netid is not None.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      cmsgs = list(cmsgs or []) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Writes an Ethernet frame into netid's tun, as if received from the wire."""
    posix.write(self.tuns[netid].fileno(), str(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Wraps ip_packet in an Ethernet header from the router and injects it."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Args:
      netid: The netid from which to read packets
      include_multicast: A boolean, whether to include multicast packets
        (default=False, i.e. multicast packets are dropped)

    Returns:
      A list of the payloads of the Ethernet frames read.

    Raises:
      OSError: reading the tun failed with an error other than EAGAIN.
    """
    packets = []
    retries = 0
    max_retries = 1
    while True:
      try:
        packet = posix.read(self.tuns[netid].fileno(), 4096)
      except OSError as e:
        # Anything other than EAGAIN is unexpected. Check errno rather than the
        # message: in Python 2, OSError(errno, strerror).message is "".
        if e.errno != errno.EAGAIN:
          raise
        # EAGAIN means there are no more packets waiting.
        # If we didn't see any packets, try again for good luck.
        if not packets and retries < max_retries:
          time.sleep(0.01)
          retries += 1
          continue
        break
      if not packet:
        break
      ether = scapy.Ether(packet)
      # Multicast frames are frames where the first byte of the destination
      # MAC address has 1 in the least-significant bit.
      if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
        packets.append(ether.payload)
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.
    """
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    """Drains all tun interfaces until none has packets queued."""
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(str(expected))
    actual_real = actual.__class__(str(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns True if assertPacketMatches(expected, actual) would pass."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Fails with msg if any (non-multicast) packet is queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Returns the first queued packet on netid that matches expected.

    Raises:
      AssertionError: no packets were received.
      UnexpectedPacketError: packets were received but none matched.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    if ipv6 and ipv6.dst.startswith("ff"):
      include_multicast = True
    else:
      include_multicast = False

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      # str(e) instead of the deprecated (and Python 3-removed) e.message.
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e)))

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples: one for every
      pair of (interface whose address is the destination, receiving
      interface).
    """
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    """Builds the assertion message describing a receive/expect step."""
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid and expects reply (or no packets if falsy).

    Returns:
      The matching reply packet, or None if no reply was expected.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None


class InboundMarkingTest(MultiNetworkBaseTest):
  """Class that automatically sets up inbound marking."""

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    cls.SetInboundMarks(True)

  @classmethod
  def tearDownClass(cls):
    cls.SetInboundMarks(False)
    super(InboundMarkingTest, cls).tearDownClass()