#!/usr/bin/python3
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common constants, socket helpers and a base TestCase for network tests."""

import binascii
import contextlib
import fcntl
import os
import random
import re
from socket import *  # pylint: disable=wildcard-import
import struct
import sys
import unittest

from scapy import all as scapy

import csocket

# Linux socket levels and option numbers, defined explicitly here because not
# every Python build exports them.
# TODO: Move these to csocket.py.
SOL_IPV6 = 41
IP_RECVERR = 11
IPV6_RECVERR = 25
IP_TRANSPARENT = 19
IPV6_TRANSPARENT = 75
IPV6_TCLASS = 67
IPV6_FLOWLABEL_MGR = 32
IPV6_FLOWINFO_SEND = 33

# SOL_SOCKET-level option numbers.
SO_BINDTODEVICE = 25
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
SO_COOKIE = 57

# Ethernet protocol numbers (EtherTypes).
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd

IPPROTO_GRE = 47

# ioctl number used by SetInterfaceHWAddr to set an interface's MAC address.
SIOCSIFHWADDR = 0x8924

# IPV6_FLOWLABEL_MGR request fields (struct in6_flowlabel_req): actions,
# flags and share modes, mirroring <linux/in6.h>.
IPV6_FL_A_GET = 0
IPV6_FL_A_PUT = 1
IPV6_FL_A_RENEW = 2  # Must differ from IPV6_FL_A_PUT; 2 in <linux/in6.h>.

IPV6_FL_F_CREATE = 1
IPV6_FL_F_EXCL = 2

IPV6_FL_S_NONE = 0
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255

# Size of the interface name field in struct ifreq.
IFNAMSIZ = 16

# ICMP (type 8) and ICMPv6 (type 128) echo request headers: code 0, zero
# checksum, identifier 0x0ace, sequence number 3.
IPV4_PING = b"\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = b"\x80\x00\x00\x00\x0a\xce\x00\x03"

# Well-known public (Google DNS) addresses used as remote endpoints.
IPV4_ADDR = "8.8.8.8"
IPV4_ADDR2 = "8.8.4.4"
IPV6_ADDR = "2001:4860:4860::8888"
IPV6_ADDR2 = "2001:4860:4860::8844"

# Expected first line of /proc/net/{udp6,raw6,icmp6}; ReadProcNetSocket
# compares it verbatim.
# NOTE(review): this must match the kernel's header byte-for-byte, and the
# kernel pads these columns with runs of spaces. Confirm that the single
# spaces below were not introduced by whitespace collapsing.
IPV6_SEQ_DGRAM_HEADER = (" sl "
                         "local_address "
                         "remote_address "
                         "st tx_queue rx_queue tr tm->when retrnsmt"
                         " uid timeout inode ref pointer drops\n")

# UDP header length in bytes.
UDP_HDR_LEN = 8

# Arbitrary packet payload.
UDP_PAYLOAD = bytes(scapy.DNS(rd=1,
                              id=random.randint(0, 65535),
                              qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
                                             qtype="AAAA")))

# Unix group to use if we want to open sockets as non-root.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# Running kernel version from csocket.LinuxVersion(); compared against
# version tuples by KernelAtLeast.
LINUX_VERSION = csocket.LinuxVersion()
LINUX_ANY_VERSION = (0, 0)


def KernelAtLeast(versions):
  """Checks the kernel version matches the specified versions.

  Args:
    versions: a list of versions expressed as tuples,
      e.g., [(5, 10, 108), (5, 15, 31)]. The kernel version matches if it's
      between a specified version and the next higher specified minor version
      with the last digit set to 0. In this example, the kernel version must
      match either:
        >= 5.10.108 and < 5.15.0
        >= 5.15.31
      While this is less flexible than matching exact tuples, it allows the
      caller to pass in fewer arguments, because Android only supports certain
      minor versions (4.19, 5.4, 5.10, ...)

  Returns:
    True if the kernel version matches, False otherwise.

  Raises:
    ValueError: if two entries share the same major.minor version.
  """
  # Walk from newest to oldest. Each entry's exclusive upper bound is the
  # previous (newer) entry's major.minor.0, and the first bound is
  # effectively unlimited.
  maxversion = (1000, 255, 65535)
  for version in sorted(versions, reverse=True):
    if version[:2] == maxversion[:2]:
      # Format the message here. Passing the values as a second argument to
      # ValueError would leave the "%s" placeholders unexpanded.
      raise ValueError(
          "Duplicate minor version: %s %s" % (version, maxversion))
    if LINUX_VERSION >= version and LINUX_VERSION < maxversion:
      return True
    maxversion = (version[0], version[1], 0)
  return False


def ByteToHex(b):
  """Formats one byte (a 1-char str or an int) as two lowercase hex digits."""
  return "%02x" % (ord(b) if isinstance(b, str) else b)


def GetWildcardAddress(version):
  """Returns the wildcard ("any") address for an address version.

  Version 5 (IPv4-mapped IPv6, see GetAddressVersion) maps to AF_INET6 (see
  GetAddressFamily), so its wildcard is the IPv6 one.
  """
  return {4: "0.0.0.0", 5: "::", 6: "::"}[version]


def GetIpHdrLength(version):
  """Returns the fixed IP header length in bytes for IP version 4 or 6."""
  return {4: 20, 6: 40}[version]


def GetAddressFamily(version):
  """Maps an address version (4, 5 = IPv4-mapped IPv6, 6) to a family."""
  return {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]


def AddressLengthBits(version):
  """Returns the address length in bits for IP version 4 or 6."""
  return {4: 32, 6: 128}[version]


def GetAddressVersion(address):
  """Classifies a textual IP address.

  Returns:
    4 if the address contains no ":", 5 if it starts with "::ffff"
    (IPv4-mapped IPv6), and 6 otherwise.
  """
  if ":" not in address:
    return 4
  if address.startswith("::ffff"):
    return 5
  return 6


def SetSocketTos(s, tos):
  """Sets the IPv4 TOS or IPv6 traffic class of socket s to tos."""
  level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
  option = {AF_INET: IP_TOS, AF_INET6: IPV6_TCLASS}[s.family]
  s.setsockopt(level, option, tos)


def SetNonBlocking(fd):
  """Sets O_NONBLOCK on file descriptor fd, preserving its other flags."""
  flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
  fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)


# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
  """Creates a socket and sets a timeout on it via csocket.SetSocketTimeout.

  NOTE(review): the 5000 argument is presumably in milliseconds; confirm
  against csocket.SetSocketTimeout.
  """
  s = socket(family, sock_type, protocol)
  csocket.SetSocketTimeout(s, 5000)
  return s


def PingSocket(family):
  """Creates a SOCK_DGRAM ICMP (AF_INET) or ICMPv6 (AF_INET6) socket."""
  proto = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}[family]
  return Socket(family, SOCK_DGRAM, proto)


def IPv4PingSocket():
  """Creates an IPv4 ICMP datagram socket."""
  return PingSocket(AF_INET)


def IPv6PingSocket():
  """Creates an IPv6 ICMPv6 datagram socket."""
  return PingSocket(AF_INET6)


def TCPSocket(family):
  """Creates a non-blocking TCP socket of the given family."""
  s = Socket(family, SOCK_STREAM, IPPROTO_TCP)
  SetNonBlocking(s.fileno())
  return s


def IPv4TCPSocket():
  """Creates a non-blocking IPv4 TCP socket."""
  return TCPSocket(AF_INET)


def IPv6TCPSocket():
  """Creates a non-blocking IPv6 TCP socket."""
  return TCPSocket(AF_INET6)


def UDPSocket(family):
  """Creates a UDP socket of the given family."""
  return Socket(family, SOCK_DGRAM, IPPROTO_UDP)


def RawGRESocket(family):
  """Creates a raw socket for IP protocol GRE."""
  return Socket(family, SOCK_RAW, IPPROTO_GRE)


def BindRandomPort(version, sock):
  """Binds sock to the wildcard address on a kernel-chosen port.

  SO_REUSEADDR is set before binding, and TCP sockets are also put into the
  listening state.

  Args:
    version: address version (4, 5 or 6), see GetAddressVersion.
    sock: the socket to bind.

  Returns:
    The bound port number.
  """
  addr = GetWildcardAddress(version)
  sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  sock.bind((addr, 0))
  if sock.getsockopt(SOL_SOCKET, SO_PROTOCOL) == IPPROTO_TCP:
    sock.listen(100)
  return sock.getsockname()[1]


def EnableFinWait(sock):
  """Makes close() on sock perform a normal FIN shutdown."""
  # Disabling SO_LINGER causes sockets to go into FIN_WAIT on close().
  sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 0, 0))


def DisableFinWait(sock):
  """Makes close() on sock send a RST instead of a FIN."""
  # Enabling SO_LINGER with a timeout of zero causes close() to send RST.
  sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 0))


def CreateSocketPair(family, socktype, addr):
  """Creates two sockets connected to each other.

  Args:
    family: socket family, e.g. AF_INET or AF_INET6.
    socktype: SOCK_STREAM or a connectionless type such as SOCK_DGRAM.
    addr: local address the second socket is bound to (port chosen by the
      kernel).

  Returns:
    (clientsock, acceptedsock). For SOCK_STREAM, acceptedsock is the socket
    returned by accept(), the listener is closed, and both sockets are set to
    send RST on close. Otherwise acceptedsock is the bound socket, connected
    back to clientsock.
  """
  clientsock = socket(family, socktype, 0)
  listensock = socket(family, socktype, 0)
  listensock.bind((addr, 0))
  listenaddr = listensock.getsockname()
  if socktype == SOCK_STREAM:
    listensock.listen(1)
  clientsock.connect(listenaddr)
  if socktype == SOCK_STREAM:
    acceptedsock, _ = listensock.accept()
    DisableFinWait(clientsock)
    DisableFinWait(acceptedsock)
    listensock.close()
  else:
    listensock.connect(clientsock.getsockname())
    acceptedsock = listensock
  return clientsock, acceptedsock


def GetInterfaceIndex(ifname):
  """Returns the interface index of ifname, using the SIOCGIFINDEX ioctl."""
  with UDPSocket(AF_INET) as s:
    ifr = struct.pack("%dsi" % IFNAMSIZ, ifname.encode(), 0)
    ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
    return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]


def SetInterfaceHWAddr(ifname, hwaddr):
  """Sets the Ethernet hardware address of ifname.

  Args:
    ifname: interface name.
    hwaddr: 6-byte address as hex text, optionally colon-separated
      (e.g. "02:00:00:00:00:01").

  Raises:
    ValueError: if hwaddr does not decode to exactly 6 bytes.
  """
  with UDPSocket(AF_INET) as s:
    hwaddr = binascii.unhexlify(hwaddr.replace(":", ""))
    if len(hwaddr) != 6:
      raise ValueError("Unknown hardware address length %d" % len(hwaddr))
    ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname.encode(), scapy.ARPHDR_ETHER,
                      hwaddr)
    fcntl.ioctl(s, SIOCSIFHWADDR, ifr)


def SetInterfaceState(ifname, up):
  """Sets (up=True) or clears IFF_UP on ifname, keeping its other flags."""
  ifname_bytes = ifname.encode()
  with UDPSocket(AF_INET) as s:
    # Read-modify-write the interface flags.
    ifr = struct.pack("%dsH" % IFNAMSIZ, ifname_bytes, 0)
    ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
    _, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
    if up:
      flags |= scapy.IFF_UP
    else:
      flags &= ~scapy.IFF_UP
    ifr = struct.pack("%dsH" % IFNAMSIZ, ifname_bytes, flags)
    fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)


def SetInterfaceUp(ifname):
  """Brings ifname up."""
  return SetInterfaceState(ifname, True)


def SetInterfaceDown(ifname):
  """Brings ifname down."""
  return SetInterfaceState(ifname, False)


def CanonicalizeIPv6Address(addr):
  """Returns addr in canonical (compressed) textual IPv6 form."""
  return inet_ntop(AF_INET6, inet_pton(AF_INET6, addr))


def FormatProcAddress(unformatted):
  """Converts a raw hex IPv6 address from /proc to canonical text form."""
  groups = [unformatted[i:i + 4] for i in range(0, len(unformatted), 4)]
  # Join into colon-separated groups, then compress.
  return CanonicalizeIPv6Address(":".join(groups))


def FormatSockStatAddress(address):
  """Formats an IPv4/IPv6 address the way /proc/net socket tables print it.

  Each 32-bit word of the binary address is printed as 8 uppercase hex digits,
  interpreted in native byte order ("=L").
  """
  family = AF_INET6 if ":" in address else AF_INET
  binary = inet_pton(family, address)
  return "".join("%08X" % struct.unpack("=L", binary[i:i + 4])
                 for i in range(0, len(binary), 4))


def GetLinkAddress(ifname, linklocal):
  """Returns an IPv6 address of ifname, read from /proc/net/if_inet6.

  Args:
    ifname: interface name.
    linklocal: if true, look for a link-local (fe80...) address; otherwise
      look for an address that is not link-local.

  Returns:
    The first matching address in canonical form, or None if none matches.
  """
  with open("/proc/net/if_inet6") as if_inet6:
    addresses = if_inet6.readlines()
  for line in addresses:
    fields = line.split()
    # fields[0] is the raw hex address and fields[5] the interface name.
    if fields[5] == ifname and bool(linklocal) == fields[0].startswith("fe80"):
      # Convert the address from raw hex to something with colons in it.
      return FormatProcAddress(fields[0])
  return None


def GetDefaultRoute(version=6):
  """Finds a default route by parsing the /proc routing tables.

  Args:
    version: 4 or 6.

  Returns:
    A (gateway address, interface name) tuple.

  Raises:
    ValueError: if no default route is found, or version is not 4 or 6.
  """
  if version == 6:
    with open("/proc/net/ipv6_route") as ipv6_route:
      routes = ipv6_route.readlines()
    for line in routes:
      route = line.split()
      # route[0]/route[1]: destination and prefix length; route[4]: next hop;
      # route[9]: device.
      if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
          # Routes in non-default tables end up in /proc/net/ipv6_route!!!
          and route[9] != "lo" and not route[9].startswith("nettest")):
        return FormatProcAddress(route[4]), route[9]
    raise ValueError("No IPv6 default route found")
  elif version == 4:
    with open("/proc/net/route") as ipv4_route:
      routes = ipv4_route.readlines()
    for line in routes:
      route = line.split()
      # route[0]: interface; route[1]: destination; route[2]: gateway;
      # route[7]: mask. A default route has destination and mask both zero.
      if route[1] == "00000000" and route[7] == "00000000":
        gw, iface = route[2], route[0]
        # NOTE(review): reversing the hex bytes assumes the kernel printed
        # the address from a little-endian host.
        gw = inet_ntop(AF_INET, binascii.unhexlify(gw)[::-1])
        return gw, iface
    raise ValueError("No IPv4 default route found")
  else:
    raise ValueError("Don't know about IPv%s" % version)


def GetDefaultRouteInterface():
  """Returns the interface name of the IPv6 default route."""
  unused_gw, iface = GetDefaultRoute()
  return iface


def MakeFlowLabelOption(addr, label):
  """Builds an IPV6_FLOWLABEL_MGR request that gets/creates a flow label.

  Args:
    addr: destination IPv6 address (flr_dst), as text.
    label: flow label; only its low 20 bits are used.

  Returns:
    The packed struct in6_flowlabel_req, as bytes.
  """
  # struct in6_flowlabel_req {
  #   struct in6_addr flr_dst;
  #   __be32 flr_label;
  #   __u8 flr_action;
  #   __u8 flr_share;
  #   __u16 flr_flags;
  #   __u16 flr_expires;
  #   __u16 flr_linger;
  #   __u32 __flr_pad;
  #   /* Options in format of IPV6_PKTOPTIONS */
  # };
  fmt = "16sIBBHHH4s"
  assert struct.calcsize(fmt) == 32
  addr = inet_pton(AF_INET6, addr)
  assert len(addr) == 16
  label = htonl(label & 0xfffff)
  action = IPV6_FL_A_GET
  share = IPV6_FL_S_ANY
  flags = IPV6_FL_F_CREATE
  pad = b"\x00" * 4
  return struct.pack(fmt, addr, label, action, share, flags, 0, 0, pad)


def SetFlowLabel(s, addr, label):
  """Requests flow label `label` towards addr on IPv6 socket s.

  The caller also needs to do s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
  """
  opt = MakeFlowLabelOption(addr, label)
  s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, opt)
def GetIptablesBinaryPath(version):
  """Returns the path of the iptables (IPv4) or ip6tables (IPv6) binary.

  Candidates are tried in order: /sbin before /system/bin, preferring the
  -legacy variant within each directory. The first executable one wins.

  Args:
    version: 4 or 6.

  Raises:
    ValueError: if version is not 4 or 6.
    FileNotFoundError: if none of the candidate paths is executable.
  """
  if version == 4:
    paths = (
        "/sbin/iptables-legacy",
        "/sbin/iptables",
        "/system/bin/iptables-legacy",
        "/system/bin/iptables",
    )
  elif version == 6:
    paths = (
        "/sbin/ip6tables-legacy",
        "/sbin/ip6tables",
        "/system/bin/ip6tables-legacy",
        "/system/bin/ip6tables",
    )
  else:
    # Without this, an unknown version fails with UnboundLocalError on
    # `paths` below.
    raise ValueError("Don't know about IPv%s" % version)
  for iptables_path in paths:
    if os.access(iptables_path, os.X_OK):
      return iptables_path
  raise FileNotFoundError(
      "iptables binary for IPv{} not found".format(version) +
      ", checked: {}".format(", ".join(paths)))


def RunIptablesCommand(version, args):
  """Runs iptables/ip6tables with args and waits for it to finish.

  Args:
    version: 4 or 6.
    args: whitespace-separated arguments. No shell is involved, so individual
      arguments cannot contain spaces or quoting.

  Returns:
    The os.spawnvp(os.P_WAIT, ...) result: the exit code, or -signal if the
    process was killed.
  """
  iptables_path = GetIptablesBinaryPath(version)
  # split() rather than split(" "): repeated or trailing spaces would
  # otherwise produce empty argv entries.
  return os.spawnvp(os.P_WAIT, iptables_path, [iptables_path] + args.split())


# Determine network configuration at import time. A /proc routing file that
# cannot be opened (OSError) is treated like having no default route, rather
# than aborting the import.
try:
  GetDefaultRoute(version=4)
  HAVE_IPV4 = True
except (ValueError, OSError):
  HAVE_IPV4 = False

try:
  GetDefaultRoute(version=6)
  HAVE_IPV6 = True
except (ValueError, OSError):
  HAVE_IPV6 = False


class RunAsUidGid(object):
  """Context guard to run a code block as a given UID and GID.

  A falsy uid or gid (e.g. 0) leaves that credential unchanged. While running
  as uid, AID_INET is appended to the supplementary groups so that sockets can
  be opened as non-root. The original credentials are restored on exit.
  NOTE(review): presumably requires root privileges on entry.
  """

  def __init__(self, uid, gid):
    self.uid = uid
    self.gid = gid

  def __enter__(self):
    # Switch the GID first; after dropping the UID, setgid() would no longer
    # be permitted.
    if self.gid:
      self.saved_gid = os.getgid()
      os.setgid(self.gid)
    if self.uid:
      self.saved_uids = os.getresuid()
      self.saved_groups = os.getgroups()
      os.setgroups(self.saved_groups + [AID_INET])
      # Keep the original real UID as the saved set-user-ID so that __exit__
      # can switch back.
      os.setresuid(self.uid, self.uid, self.saved_uids[0])

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Undo in reverse order: regain the original UID before restoring the
    # supplementary groups and the GID.
    if self.uid:
      os.setresuid(*self.saved_uids)
      os.setgroups(self.saved_groups)
    if self.gid:
      os.setgid(self.saved_gid)


class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID, keeping the GID."""

  def __init__(self, uid):
    super().__init__(uid, 0)


class NetworkTest(unittest.TestCase):
  """Base TestCase with errno assertions and /proc/net socket parsing."""

  @contextlib.contextmanager
  def _errnoCheck(self, err_num):
    """Context manager asserting that its body raises OSError with err_num."""
    with self.assertRaises(EnvironmentError) as context:
      yield context
    self.assertEqual(context.exception.errno, err_num)

  def assertRaisesErrno(self, err_num, f=None, *args):
    """Test that the system returns an errno error.

    This works similarly to unittest.TestCase.assertRaises. You can call it as
    an assertion, or use it as a context manager.
    e.g.
        self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    or
        with self.assertRaisesErrno(errno.ENOENT):
          do_things(arg1, arg2)

    Args:
      err_num: an errno constant
      f: (optional) A callable that should result in error
      *args: arguments passed to f
    """
    if f is None:
      return self._errnoCheck(err_num)
    else:
      with self._errnoCheck(err_num):
        f(*args)

  def ReadProcNetSocket(self, protocol):
    """Parses /proc/net/<protocol> into per-socket field lists.

    Args:
      protocol: file name under /proc/net: a tcp*, udp*, raw* or icmp* table,
        e.g. "tcp6" or "udp".

    Returns:
      One [src, dst, state, mem, uid, refcnt, extra] list of strings per
      socket line. src and dst are "<hex address>:<hex port>". mem is the
      tx_queue:rx_queue column. extra is the drops counter for udp/raw/icmp
      and, for TCP, the trailing fields as matched (empty for timewait
      sockets).

    Raises:
      ValueError: for an unsupported protocol or a line that does not parse.
    """
    # Read file.
    filename = "/proc/net/%s" % protocol
    with open(filename) as f:
      lines = f.readlines()

    # Every table starts with a header line. Check it where its exact text is
    # known, and always strip it so the remaining lines can be parsed.
    if protocol in ["icmp6", "raw6", "udp6"]:
      self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    lines = lines[1:]

    # Check contents.
    if protocol.endswith("6"):
      addrlen = 32
    else:
      addrlen = 8

    if protocol.startswith("tcp"):
      # Real sockets have 5 extra numbers, timewait sockets have none.
      end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
    elif re.match("icmp|udp|raw", protocol):
      # Drops.
      end_regexp = " +([0-9]+) *$"
    else:
      raise ValueError("Don't know how to parse %s" % filename)

    regexp = re.compile(r" *(\d+): "                    # bucket (sl)
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "  # srcaddr, port
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "  # dstaddr, port
                        "([0-9A-F][0-9A-F]) "          # state
                        "([0-9A-F]{8}:[0-9A-F]{8}) "   # mem (tx:rx queue)
                        "([0-9A-F]{2}:[0-9A-F]{8}) "   # tr:tm->when
                        "([0-9A-F]{8}) +"              # retrnsmt
                        "([0-9]+) +"                   # uid
                        "([0-9]+) +"                   # timeout
                        "([0-9]+) +"                   # inode
                        "([0-9]+) +"                   # refcnt
                        "([0-9a-f]+)"                  # sp
                        "%s"                           # icmp has spaces
                        % (addrlen, addrlen, end_regexp))
    # Return a list of lists with only source / dest addresses for now.
    # TODO: consider returning a dict or namedtuple instead.
    out = []
    for line in lines:
      m = regexp.match(line)
      if m is None:
        raise ValueError("Failed match on [%s]" % line)
      (_, src, dst, state, mem,
       _, _, uid, _, _, refcnt, _, extra) = m.groups()
      out.append([src, dst, state, mem, uid, refcnt, extra])
    return out

  @staticmethod
  def GetConsoleLogLevel():
    """Returns the current console log level (first field of printk)."""
    with open("/proc/sys/kernel/printk") as printk:
      return int(printk.readline().split()[0])

  @staticmethod
  def SetConsoleLogLevel(level):
    """Writes level to /proc/sys/kernel/printk; returns the write() result."""
    with open("/proc/sys/kernel/printk", "w") as printk:
      return printk.write("%s\n" % level)


if __name__ == "__main__":
  unittest.main()