#!/usr/bin/python3
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base module for multinetwork tests."""

import errno
import fcntl
import os
import posix
import random
import re
from socket import *  # pylint: disable=wildcard-import
import struct
import time

from scapy import all as scapy

import csocket
import iproute
import net_test


IFF_TUN = 1
IFF_TAP = 2
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca

SO_BINDTODEVICE = 25

# Setsockopt values.
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Cmsg values.
IP_TTL = 2
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
IPV6_HOPLIMIT = 52  # Different from IPV6_UNICAST_HOPS, this is cmsg only.


AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"

# True if the running kernel exposes the per-interface autoconf table sysctl.
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)


class ConfigurationError(AssertionError):
  """Raised when a command needed to set up the test environment fails."""


class UnexpectedPacketError(AssertionError):
  """Raised when packets were received but none matched the expected one."""


def MakePktInfo(version, addr, ifindex):
  """Builds a packed pktinfo structure for use as ancillary data.

  Args:
    version: The IP version, 4 or 6.
    addr: The address as a string. If falsy, the unspecified address
      ("0.0.0.0" or "::") is used.
    ifindex: The interface index to place in the structure.

  Returns:
    The packed csocket.InPktinfo (IPv4) or csocket.In6Pktinfo (IPv6).

  Raises:
    KeyError: if version is neither 4 nor 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((addr, ifindex)).Pack()
  else:
    return csocket.InPktinfo((ifindex, addr, b"\x00" * 4)).Pack()


class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setupClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since we put them in MAC addresses and IIDs
  # (formatted as a single byte with %02x) and in IPv4 address octets.
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes.
  # NOTE: this dict lives on the base class and is shared by all subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges.
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 300  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (first, last) UID range assigned to netid."""
    return (
        cls.UID_RANGE_SIZE * netid,
        cls.UID_RANGE_SIZE * (netid + 1) - 1
    )

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID in netid's range, or 0 if netid is falsy."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    With an autoconf table offset configured and a known interface, the table
    is the interface index minus the offset; otherwise it is netid itself.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    else:
      return netid

  @staticmethod
  def GetInterfaceName(netid):
    """Returns the name of the test interface for netid."""
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    """Returns the MAC address of the simulated router on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    """Returns the MAC address assigned to our interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the simulated router's address on netid.

    Raises:
      ValueError: if version is neither 4 nor 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    """Returns our IPv4 address on netid."""
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns our non-link-local IPv6 address on netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our address on netid; version 5 maps to the IPv4 address."""
    return {4: cls._MyIPv4Address(netid),
            5: cls._MyIPv4Address(netid),
            6: cls._MyIPv6Address(netid)}[version]

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but version 5 yields an IPv4-mapped IPv6 address."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    """Returns our link-local IPv6 address on netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    """Returns the on-link prefix length for the given IP version."""
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    """Returns the on-link prefix (without length) for netid."""
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Appends two random components to prefix to form an address.

    A prefix containing "." is treated as IPv4 and gets two decimal octets;
    anything else gets two hex groups separated by ":".
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    """Returns AF_INET for version 4 and AF_INET6 for version 6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates, configures and brings up the tap interface for netid.

    Returns:
      The open, non-blocking file object for the tun device. Closing it is the
      caller's responsibility.
    """
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b", buffering=0)
    except IOError:
      f = open("/dev/tun", "r+b", buffering=0)
    try:
      ifr = struct.pack("16sH", iface.encode(), IFF_TAP | IFF_NO_PI)
      ifr += b"\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
      # Set accept_ra to 2, because that's what we use.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except Exception:
      # Don't leak the tun file descriptor if configuration fails.
      f.close()
      raise
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0, options=()):
    """Writes a router advertisement to netid's tun interface.

    Args:
      netid: The netid whose interface should receive the RA.
      retranstimer: Retransmit timer in ms. Defaults to the RA validity.
      reachabletime: Reachable time to advertise.
      options: Additional scapy layers to append to the RA.
    """
    validity = cls.RA_VALIDITY  # seconds
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = validity * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlifetime) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1,
                                      validlifetime=validity,
                                      preferredlifetime=validity))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), bytes(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds or removes the rules, routes and addresses for netid.

    Args:
      netid: The netid to configure.
      is_add: True to add the configuration, False to remove it.
    """
    # Find out how to configure things. None of these depend on IP version.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)
    start, end = cls.UidRangeForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or removes the mangle rule that marks packets arriving on iface.

    Raises:
      ConfigurationError: if the iptables command reports failure.
    """
    add_del = "-A" if is_add else "-D"
    # The rule text is identical for IPv4 and IPv6.
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      # Run iptables to set up incoming packet marking.
      if net_test.RunIptablesCommand(version, args):
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    """Adds or removes inbound marking rules for every test interface."""
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Points the default rule at netid's table; a falsy netid removes it."""
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    """Removes the default network rule."""
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents of the given sysctl file."""
    with open(sysctl, "r") as sysctl_file:
      return sysctl_file.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to the sysctl file, remembering the original contents."""
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as sysctl_file:
      sysctl_file.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    """Sets the named per-interface IPv6 sysctl on every test interface."""
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back every sysctl value saved by SetSysctl, best-effort."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as sysctl_file:
          sysctl_file.write(value)
      except IOError:
        # Deliberately ignored. NOTE(review): presumably this covers
        # per-interface sysctls whose interface is already gone -- confirm.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    """Returns the path of the ICMP rate limit sysctl for the IP version."""
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    """Sets the ICMP rate limit sysctl for the IP version."""
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    # Use pop() so that a missing variable cannot abort the rest of cleanup.
    os.environ.pop("ANDROID_NO_USE_FWMARK_CLIENT", None)

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    for netid in cls.tuns:
      cls._RunSetupCommands(netid, False)
      cls.tuns[netid].close()

    cls.iproute.close()
    cls._RestoreSysctls()
    cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on socket s to netid (None is treated as 0)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    """Returns the SO_MARK value of socket s."""
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    """Resets the SO_MARK value of socket s to 0."""
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Sets SO_BINDTODEVICE on s; a falsy iface passes the empty string."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface.encode())

  def SetUnicastInterface(self, s, ifindex):
    """Sets IP_UNICAST_IF (and IPV6_UNICAST_IF on IPv6 sockets) to ifindex."""
    # Otherwise, Python thinks it's a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    """Returns the remote test address; version 5 maps to the IPv4 one."""
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    """Like GetRemoteAddress, but version 5 yields an IPv4-mapped address."""
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    """Returns the second remote test address, IPv4-mapped for version 5."""
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Makes socket s select netid using the given mode.

    Args:
      s: The socket to configure.
      netid: The netid to select. A falsy netid clears the selection.
      mode: One of "uid", "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: if mode is not one of the supported modes.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket and applies the requested interface selection.

    Args:
      version: The IP version. 5 (IPv4-mapped) creates an IPv6 socket.
      constructor: Callable that takes an address family and returns a socket.
      netid: The netid to select.
      routing_mode: A SelectInterface mode, or None for no selection.

    Returns:
      The new socket.
    """
    if version == 5:
      version = 6
    s = constructor(self.GetProtocolFamily(version))

    # SelectInterface handles every mode, including "uid".
    if routing_mode is not None:
      self.SelectInterface(s, netid, routing_mode)

    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids

    Args:
      exclude: a netid or list of netids that should not be chosen
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on socket s, optionally pinned to netid via pktinfo.

    Args:
      version: The IP version, 4 or 6.
      s: The socket to send on.
      dstaddr: Destination address.
      dstport: Destination port.
      netid: The netid to send on, or None to add no pktinfo cmsg.
      payload: The data to send.
      cmsgs: List of (level, name, data) control messages. NOTE: when netid is
        not None, the pktinfo cmsg is appended to this list in place.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      cmsgs.append((cmsg_level, cmsg_name, pktinfo))
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Injects a complete Ethernet frame on netid's tun interface."""
    posix.write(self.tuns[netid].fileno(), bytes(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Injects ip_packet on netid, framed as sent from the router to us."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Args:
      netid: The netid from which to read packets
      include_multicast: A boolean, whether to include multicast packets
        (default=False, i.e., multicast frames are dropped)

    Returns:
      A list of the Ethernet payloads of the frames that were read.
    """
    packets = []
    retries = 0
    max_retries = 1
    fd = self.tuns[netid].fileno()
    while True:
      try:
        packet = posix.read(fd, 4096)
      except OSError as e:
        # Anything other than EAGAIN is unexpected.
        if e.errno != errno.EAGAIN:
          raise
        # EAGAIN means there are no more packets waiting.
        # If we didn't see any packets, try again for good luck.
        if not packets and retries < max_retries:
          time.sleep(0.01)
          retries += 1
          continue
        break
      if not packet:
        break
      ether = scapy.Ether(packet)
      # Multicast frames are frames where the first byte of the destination
      # MAC address has 1 in the least-significant bit.
      if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
        packets.append(ether.payload)
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.

    Raises:
      KeyError: if netid is not one of the configured networks.
    """
    # Fail fast on a netid that was never set up, instead of silently touching
    # the table numbered netid.
    if netid not in self.ifindices:
      raise KeyError(netid)
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    """Drains all queued packets from every test interface."""
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    """Asserts that actual matches expected, ignoring unpredictable fields."""
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(bytes(expected))
    actual_real = actual.__class__(bytes(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns True if assertPacketMatches passes for the two packets."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Asserts that no (non-multicast) packets are queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Asserts that a packet matching expected is queued on netid.

    Returns:
      The first received packet that matches expected.

    Raises:
      UnexpectedPacketError: if packets were received but none matched.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    include_multicast = bool(ipv6 and ipv6.dst.startswith("ff"))

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e)))

  def Combinations(self, version):
    """Produces a list of combinations to test."""
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    """Builds the description of a receive-and-expect test step."""
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid and checks for the reply, or for silence.

    Returns:
      The matching reply packet if reply is truthy, otherwise None.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None


class InboundMarkingTest(MultiNetworkBaseTest):
  """Class that automatically sets up inbound marking."""

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    cls.SetInboundMarks(True)

  @classmethod
  def tearDownClass(cls):
    cls.SetInboundMarks(False)
    super(InboundMarkingTest, cls).tearDownClass()