• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import contextlib
18import fcntl
19import os
20import random
21import re
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import sys
25import unittest
26
27from scapy.arch import linux
28from scapy import all as scapy
29
30import binascii
31import csocket
32import gki
33
# TODO: Move these to csocket.py.
# Linux socket-level / IP-level option constants (kernel UAPI values).
SOL_IPV6 = 41
IP_RECVERR = 11
IPV6_RECVERR = 25
IP_TRANSPARENT = 19
IPV6_TRANSPARENT = 75
IPV6_TCLASS = 67
IPV6_FLOWLABEL_MGR = 32
IPV6_FLOWINFO_SEND = 33

SO_BINDTODEVICE = 25
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
SO_COOKIE = 57

# Ethertypes.
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd

IPPROTO_GRE = 47

SIOCSIFHWADDR = 0x8924

# struct in6_flowlabel_req flr_action values (linux/in6.h). Each action needs
# a distinct value; RENEW used to alias PUT (1), which made a renew request
# indistinguishable from a put.
IPV6_FL_A_GET = 0
IPV6_FL_A_PUT = 1
IPV6_FL_A_RENEW = 2

# flr_flags bits.
IPV6_FL_F_CREATE = 1
IPV6_FL_F_EXCL = 2

# flr_share values.
IPV6_FL_S_NONE = 0
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255

IFNAMSIZ = 16

# ICMP (type 8) and ICMPv6 (type 128) echo request headers: code 0,
# checksum 0, id 0x0ace, sequence 3.
IPV4_PING = b"\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = b"\x80\x00\x00\x00\x0a\xce\x00\x03"

# Well-known public addresses (Google Public DNS).
IPV4_ADDR = "8.8.8.8"
IPV4_ADDR2 = "8.8.4.4"
IPV6_ADDR = "2001:4860:4860::8888"
IPV6_ADDR2 = "2001:4860:4860::8844"

# Expected first line of /proc/net/{icmp6,raw6,udp6}.
IPV6_SEQ_DGRAM_HEADER = ("  sl  "
                         "local_address                         "
                         "remote_address                        "
                         "st tx_queue rx_queue tr tm->when retrnsmt"
                         "   uid  timeout inode ref pointer drops\n")

UDP_HDR_LEN = 8
85
# Arbitrary packet payload: the serialized bytes of a DNS AAAA query with a
# random transaction ID.
UDP_PAYLOAD = bytes(
    scapy.DNS(
        rd=1,
        id=random.randint(0, 65535),
        qd=scapy.DNSQR(qname="wWW.GoOGle.CoM", qtype="AAAA"),
    )
)

# Unix group to use if we want to open sockets as non-root.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# (VERSION, PATCHLEVEL, SUBLEVEL), i.e. the numbers at the top of the kernel's
# Makefile, as reported by csocket.
LINUX_VERSION = csocket.LinuxVersion()

# A version tuple that every kernel version compares >= to.
LINUX_ANY_VERSION = (0, 0, 0)

# Linus always releases x.y.0-rcZ or x.y.0, so any stable (incl. LTS) release
# has a non-zero SUBLEVEL (x.y.1+).
IS_STABLE = LINUX_VERSION[2] > 0

# Same check as //system/gsid/libgsi.cpp IsGsiRunning().
IS_GSI = os.access("/metadata/gsi/dsu/booted", os.F_OK)
108
# NonGXI() is useful to run tests starting from a specific kernel version,
# thus allowing one to test for correctly backported fixes,
# without running the tests on non-updatable kernels (as part of GSI tests).
#
# Running vts_net_test on a GSI image basically doesn't make sense, since
# it's not like the unmodified vendor image - including the kernel - can be
# realistically fixed in such a setup. Particularly problematic is GSI
# on an *older* Pixel vendor image: newer Pixel images will have the fixed
# kernel, but running a newer GSI against an ancient vendor image will not see
# those fixes.
#
# Normally you'd also want to run on GKI kernels, but older release branches
# are no longer maintained, so they also need to be excluded.
# Proper GKI testing will happen at the tip of the appropriate ACK/GKI branch.
def NonGXI(major, minor):
  """Checks the kernel version is >= major.minor, and not GKI or GSI.

  Returns:
    False on a GSI image or a GKI kernel; otherwise whether LINUX_VERSION is
    at least (major, minor, 0).
  """
  if not (IS_GSI or gki.IS_GKI):
    return LINUX_VERSION >= (major, minor, 0)
  return False
128
def KernelAtLeast(versions, kernel_version=None):
  """Checks the kernel version matches the specified versions.

  Args:
    versions: a list of versions expressed as tuples,
    e.g., [(5, 10, 108), (5, 15, 31)]. The kernel version matches if it's
    between each specified version and the next minor version with last digit
    set to 0. In this example, the kernel version must match either:
      >= 5.10.108 and < 5.15.0
      >= 5.15.31
    While this is less flexible than matching exact tuples, it allows the caller
    to pass in fewer arguments, because Android only supports certain minor
    versions (4.19, 5.4, 5.10, ...)
    kernel_version: (optional) the (VERSION, PATCHLEVEL, SUBLEVEL) tuple to
      check. Defaults to LINUX_VERSION.

  Returns:
    True if the kernel version matches, False otherwise

  Raises:
    ValueError: if two entries of versions share the same major.minor. This is
      checked for the whole list, regardless of which kernel is running.
  """
  if kernel_version is None:
    kernel_version = LINUX_VERSION

  ordered = sorted(versions, reverse=True)

  # bounds[i] is the exclusive upper bound for ordered[i]: the x.y.0 of the
  # next-higher listed minor version (or "infinity" for the highest one).
  bounds = [(1000, 255, 65535)]
  for version in ordered:
    if version[:2] == bounds[-1][:2]:
      raise ValueError(
          "Duplicate minor version: %s %s" % (version, bounds[-1]))
    bounds.append((version[0], version[1], 0))

  for version, maxversion in zip(ordered, bounds):
    if version <= kernel_version < maxversion:
      return True
  return False
154
def ByteToHex(b):
  """Formats one byte as two lowercase hex digits.

  Args:
    b: the byte, either as an int or as a one-character str.
  """
  value = b if not isinstance(b, str) else ord(b)
  return "%02x" % value
157
def GetWildcardAddress(version):
  """Returns the wildcard ("any") address string for an IP version.

  Args:
    version: 4, 6, or 5. Version 5 is this file's convention for IPv4-mapped
      IPv6 (see GetAddressVersion). Such addresses live on AF_INET6 sockets
      (see GetAddressFamily), so their wildcard is "::", as in BindRandomPort.

  Raises:
    KeyError: for any other version.
  """
  return {4: "0.0.0.0", 5: "::", 6: "::"}[version]
160
def GetIpHdrLength(version):
  """Returns the base IP header size in bytes for version 4 or 6.

  Raises KeyError for any other version.
  """
  header_bytes = {4: 20, 6: 40}
  return header_bytes[version]
163
def GetAddressFamily(version):
  """Maps an IP version (4, 5 or 6) to its socket address family.

  Raises KeyError for any other version.
  """
  family_by_version = {
      4: AF_INET,
      5: AF_INET6,  # IPv4-mapped IPv6 addresses use an AF_INET6 socket.
      6: AF_INET6,
  }
  return family_by_version[version]
166
167
def AddressLengthBits(version):
  """Returns the address width in bits (32 for IPv4, 128 for IPv6).

  Raises KeyError for any other version.
  """
  bits_per_version = {4: 32, 6: 128}
  return bits_per_version[version]
170
def GetAddressVersion(address):
  """Classifies an address string as 4 (IPv4), 5 (IPv4-mapped IPv6) or 6.

  Any string without a colon is treated as IPv4. A colon-containing string
  starting with "::ffff" is treated as IPv4-mapped.
  """
  if ":" in address:
    return 5 if address.startswith("::ffff") else 6
  return 4
177
def SetSocketTos(s, tos):
  """Sets the IPv4 TOS byte or IPv6 traffic class of socket s.

  Raises KeyError if s is neither AF_INET nor AF_INET6.
  """
  level, option = {
      AF_INET: (SOL_IP, IP_TOS),
      AF_INET6: (SOL_IPV6, IPV6_TCLASS),
  }[s.family]
  s.setsockopt(level, option, tos)
182
183
def SetNonBlocking(fd):
  """Adds O_NONBLOCK to the file status flags of file descriptor fd."""
  current = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
  updated = current | os.O_NONBLOCK
  fcntl.fcntl(fd, fcntl.F_SETFL, updated)
187
188
189# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
  """Creates a socket and applies csocket.SetSocketTimeout(sock, 5000) to it.

  The value 5000 is passed straight through to csocket (presumably
  milliseconds; confirm against csocket).
  """
  new_sock = socket(family, sock_type, protocol)
  csocket.SetSocketTimeout(new_sock, 5000)
  return new_sock
194
195
def PingSocket(family):
  """Creates a SOCK_DGRAM ICMP (AF_INET) or ICMPv6 (AF_INET6) socket.

  Raises KeyError for any other family.
  """
  icmp_protocols = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}
  return Socket(family, SOCK_DGRAM, icmp_protocols[family])
199
200
def IPv4PingSocket():
  """Returns a new AF_INET ping (ICMP datagram) socket."""
  return PingSocket(family=AF_INET)
203
204
def IPv6PingSocket():
  """Returns a new AF_INET6 ping (ICMPv6 datagram) socket."""
  return PingSocket(family=AF_INET6)
207
208
def TCPSocket(family):
  """Returns a new TCP socket of the given family with O_NONBLOCK set."""
  tcp_sock = Socket(family, SOCK_STREAM, IPPROTO_TCP)
  SetNonBlocking(fd=tcp_sock.fileno())
  return tcp_sock
213
214
def IPv4TCPSocket():
  """Returns a new non-blocking AF_INET TCP socket."""
  return TCPSocket(family=AF_INET)
217
218
def IPv6TCPSocket():
  """Returns a new non-blocking AF_INET6 TCP socket."""
  return TCPSocket(family=AF_INET6)
221
222
def UDPSocket(family):
  """Returns a new UDP socket of the given address family."""
  return Socket(family, sock_type=SOCK_DGRAM, protocol=IPPROTO_UDP)
225
226
def RawGRESocket(family):
  """Returns a new SOCK_RAW socket for IP protocol IPPROTO_GRE (47)."""
  return Socket(family, SOCK_RAW, IPPROTO_GRE)
230
231
def BindRandomPort(version, sock):
  """Binds sock to the wildcard address on a kernel-chosen port.

  SO_REUSEADDR is enabled first. TCP sockets are additionally put into the
  listening state with a backlog of 100.

  Args:
    version: 4, 5 (IPv4-mapped IPv6) or 6; selects the wildcard address.
    sock: the socket to bind.

  Returns:
    The bound port number.
  """
  wildcard = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
  sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  sock.bind((wildcard, 0))
  is_tcp = sock.getsockopt(SOL_SOCKET, SO_PROTOCOL) == IPPROTO_TCP
  if is_tcp:
    sock.listen(100)
  return sock.getsockname()[1]
240
241
def EnableFinWait(sock):
  """Turns SO_LINGER off so that close() lets the socket go into FIN_WAIT."""
  linger_off = struct.pack("ii", 0, 0)  # l_onoff = 0, l_linger = 0
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_off)
245
246
def DisableFinWait(sock):
  """Enables SO_LINGER with a zero timeout so that close() sends RST."""
  linger_abort = struct.pack("ii", 1, 0)  # l_onoff = 1, l_linger = 0
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_abort)
250
251
def CreateSocketPair(family, socktype, addr):
  """Creates two sockets of the given family/type connected to each other.

  One socket is bound to an ephemeral port on addr and a client socket
  connects to it.
  - For SOCK_STREAM, the connection is accepted and the listening socket is
    closed. Both connected ends get DisableFinWait, so close() sends RST.
  - For other socket types, the bound socket is connected back to the client
    and returned as the second socket.

  Args:
    family: address family, e.g. AF_INET or AF_INET6.
    socktype: SOCK_STREAM, or a connectionless type such as SOCK_DGRAM.
    addr: local address to bind to, e.g. "127.0.0.1" or "::1".

  Returns:
    A (clientsock, acceptedsock) tuple of connected sockets.

  Raises:
    OSError: if any socket operation fails. On any exception, the sockets
      created so far are closed before it propagates.
  """
  with contextlib.ExitStack() as cleanup:
    # Until pop_all() below, every socket opened here is closed on failure.
    clientsock = socket(family, socktype, 0)
    cleanup.callback(clientsock.close)
    listensock = socket(family, socktype, 0)
    cleanup.callback(listensock.close)

    listensock.bind((addr, 0))
    if socktype == SOCK_STREAM:
      listensock.listen(1)
    clientsock.connect(listensock.getsockname())
    if socktype == SOCK_STREAM:
      acceptedsock, _ = listensock.accept()
      cleanup.callback(acceptedsock.close)
      DisableFinWait(clientsock)
      DisableFinWait(acceptedsock)
      listensock.close()
    else:
      listensock.connect(clientsock.getsockname())
      acceptedsock = listensock

    # Success: ownership of the returned sockets passes to the caller.
    cleanup.pop_all()
  return clientsock, acceptedsock
269
270
def GetInterfaceIndex(ifname):
  """Returns the interface index of ifname, queried via SIOCGIFINDEX."""
  with UDPSocket(AF_INET) as s:
    ifreq_fmt = "%dsi" % IFNAMSIZ
    request = struct.pack(ifreq_fmt, ifname.encode(), 0)
    reply = fcntl.ioctl(s, linux.SIOCGIFINDEX, request)
    _, index = struct.unpack(ifreq_fmt, reply)
    return index
276
277
def SetInterfaceHWAddr(ifname, hwaddr):
  """Sets the Ethernet hardware address of ifname via SIOCSIFHWADDR.

  Args:
    ifname: interface name.
    hwaddr: MAC address in hex, colons optional, e.g. "02:00:00:00:00:01".

  Raises:
    ValueError: if hwaddr is not valid hex (binascii.Error is a ValueError)
      or does not decode to exactly 6 bytes.
  """
  with UDPSocket(AF_INET) as s:
    mac = binascii.unhexlify(hwaddr.replace(":", ""))
    if len(mac) != 6:
      raise ValueError("Unknown hardware address length %d" % len(mac))
    ifreq = struct.pack("%dsH6s" % IFNAMSIZ, ifname.encode(),
                        scapy.ARPHDR_ETHER, mac)
    fcntl.ioctl(s, SIOCSIFHWADDR, ifreq)
287
288
def SetInterfaceState(ifname, up):
  """Sets or clears IFF_UP on ifname.

  The current flags are read with SIOCGIFFLAGS, IFF_UP is updated, and the
  result is written back with SIOCSIFFLAGS.

  Args:
    ifname: interface name.
    up: truthy to set IFF_UP, falsy to clear it.
  """
  name = ifname.encode()
  with UDPSocket(AF_INET) as s:
    ifreq_fmt = "%dsH" % IFNAMSIZ
    reply = fcntl.ioctl(s, linux.SIOCGIFFLAGS, struct.pack(ifreq_fmt, name, 0))
    flags = struct.unpack(ifreq_fmt, reply)[1]
    flags = (flags | linux.IFF_UP) if up else (flags & ~linux.IFF_UP)
    fcntl.ioctl(s, linux.SIOCSIFFLAGS, struct.pack(ifreq_fmt, name, flags))
301
302
def SetInterfaceUp(ifname):
  """Sets IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=True)
305
306
def SetInterfaceDown(ifname):
  """Clears IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=False)
309
310
def CanonicalizeIPv6Address(addr):
  """Returns addr round-tripped through inet_pton/inet_ntop (compressed form)."""
  packed = inet_pton(AF_INET6, addr)
  return inet_ntop(AF_INET6, packed)
313
314
def FormatProcAddress(unformatted):
  """Converts a raw hex IPv6 address (as in /proc files) to compressed form.

  The hex string is split into 4-digit groups, joined with colons, and
  canonicalized.
  """
  groups = [unformatted[i:i + 4] for i in range(0, len(unformatted), 4)]
  return CanonicalizeIPv6Address(":".join(groups))
323
324
def FormatSockStatAddress(address):
  """Formats an IP address the way /proc/net socket tables print it.

  The packed address is read as consecutive native-endian 32-bit words, and
  each word is printed as 8 uppercase hex digits. IPv6 is assumed whenever
  the string contains a colon.
  """
  family = AF_INET6 if ":" in address else AF_INET
  packed = inet_pton(family, address)
  words = (struct.unpack("=L", packed[i:i + 4])[0]
           for i in range(0, len(packed), 4))
  return "".join("%08X" % word for word in words)
335
336
def GetLinkAddress(ifname, linklocal):
  """Returns an IPv6 address of ifname from /proc/net/if_inet6, or None.

  Args:
    ifname: interface name (sixth column of /proc/net/if_inet6).
    linklocal: if truthy, return the first fe80-prefixed address; otherwise
      return the first address that is not fe80-prefixed.

  Returns:
    The address with colons (see FormatProcAddress), or None if none match.
  """
  with open("/proc/net/if_inet6") as if_inet6:
    entries = if_inet6.readlines()
  for entry in entries:
    fields = [f for f in entry.strip().split(" ") if f]
    if fields[5] != ifname:
      continue
    is_link_local = fields[0].startswith("fe80")
    if is_link_local == bool(linklocal):
      # Convert the address from raw hex to something with colons in it.
      return FormatProcAddress(fields[0])
  return None
348
349
def GetDefaultRoute(version=6):
  """Returns (gateway, interface) of the default route for an IP version.

  IPv6 routes come from /proc/net/ipv6_route, skipping "lo" and "nettest*"
  devices. IPv4 routes come from /proc/net/route.

  Raises:
    ValueError: if no default route is found, or version is not 4 or 6.
  """
  if version == 6:
    with open("/proc/net/ipv6_route") as ipv6_route:
      lines = ipv6_route.readlines()
    for line in lines:
      fields = [f for f in line.strip().split(" ") if f]
      if fields[0] != "00000000000000000000000000000000" or fields[1] != "00":
        continue
      # Routes in non-default tables end up in /proc/net/ipv6_route!!!
      device = fields[9]
      if device == "lo" or device.startswith("nettest"):
        continue
      return FormatProcAddress(fields[4]), device
    raise ValueError("No IPv6 default route found")

  if version == 4:
    with open("/proc/net/route") as ipv4_route:
      lines = ipv4_route.readlines()
    for line in lines:
      fields = [f for f in line.strip().split("\t") if f]
      # Destination and mask both zero: the default route.
      if fields[1] == "00000000" and fields[7] == "00000000":
        # The gateway column is hex in little-endian byte order.
        gateway = inet_ntop(AF_INET, binascii.unhexlify(fields[2])[::-1])
        return gateway, fields[0]
    raise ValueError("No IPv4 default route found")

  raise ValueError("Don't know about IPv%s" % version)
373
374
def GetDefaultRouteInterface():
  """Returns the interface name of the IPv6 default route."""
  return GetDefaultRoute()[1]
378
379
def MakeFlowLabelOption(addr, label):
  """Packs a struct in6_flowlabel_req for the IPV6_FLOWLABEL_MGR option.

  The request uses action IPV6_FL_A_GET, share mode IPV6_FL_S_ANY and flag
  IPV6_FL_F_CREATE, with zero expiry and linger.

  Args:
    addr: destination IPv6 address string (flr_dst).
    label: flow label; only its low 20 bits are used.

  Returns:
    The packed 32-byte request.
  """
  # struct in6_flowlabel_req {
  #         struct in6_addr flr_dst;
  #         __be32  flr_label;
  #         __u8    flr_action;
  #         __u8    flr_share;
  #         __u16   flr_flags;
  #         __u16   flr_expires;
  #         __u16   flr_linger;
  #         __u32   __flr_pad;
  #         /* Options in format of IPV6_PKTOPTIONS */
  # };
  layout = "16sIBBHHH4s"
  assert struct.calcsize(layout) == 32
  dst = inet_pton(AF_INET6, addr)
  assert len(dst) == 16
  # flr_label is __be32: byte-swap to network order before native "I" packing.
  be_label = htonl(label & 0xfffff)
  return struct.pack(layout,
                     dst,
                     be_label,
                     IPV6_FL_A_GET,     # flr_action
                     IPV6_FL_S_ANY,     # flr_share
                     IPV6_FL_F_CREATE,  # flr_flags
                     0,                 # flr_expires
                     0,                 # flr_linger
                     b"\x00" * 4)       # __flr_pad
402
403
def SetFlowLabel(s, addr, label):
  """Passes a flow label request for (addr, label) to IPV6_FLOWLABEL_MGR on s.

  Caller also needs to do s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
  """
  request = MakeFlowLabelOption(addr, label)
  s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, request)
408
409
def GetIptablesBinaryPath(version):
  """Returns the path of an executable iptables binary for an IP version.

  Candidates are tried in order: /sbin before /system/bin, and within each
  directory the -legacy variant before the plain one.

  Args:
    version: 4 for iptables, 6 for ip6tables.

  Returns:
    The first candidate path that os.access() reports as executable.

  Raises:
    ValueError: if version is neither 4 nor 6.
    FileNotFoundError: if no candidate binary is executable.
  """
  if version == 4:
    paths = (
        "/sbin/iptables-legacy",
        "/sbin/iptables",
        "/system/bin/iptables-legacy",
        "/system/bin/iptables",
    )
  elif version == 6:
    paths = (
        "/sbin/ip6tables-legacy",
        "/sbin/ip6tables",
        "/system/bin/ip6tables-legacy",
        "/system/bin/ip6tables",
    )
  else:
    # Reject unknown versions explicitly instead of failing on an unbound name.
    raise ValueError("Don't know about IPv%s" % version)
  for iptables_path in paths:
    if os.access(iptables_path, os.X_OK):
      return iptables_path
  raise FileNotFoundError(
      "iptables binary for IPv{} not found".format(version) +
      ", checked: {}".format(", ".join(paths)))
431
432
def RunIptablesCommand(version, args):
  """Runs iptables (version 4) or ip6tables (version 6) and waits for it.

  Args:
    version: 4 or 6; see GetIptablesBinaryPath.
    args: the iptables arguments as one whitespace-separated string, e.g.
      "-A OUTPUT -j ACCEPT". No shell-style quoting is supported.

  Returns:
    The result of os.spawnvp(os.P_WAIT, ...): the exit code if the process
    exited normally, or -signal if it was killed by a signal.
  """
  iptables_path = GetIptablesBinaryPath(version)
  # "-w" is iptables' --wait (wait for the xtables lock). split() rather than
  # split(" ") so repeated, leading or trailing spaces in args do not turn
  # into empty argv entries.
  argv = [iptables_path, "-w"] + args.split()
  return os.spawnvp(os.P_WAIT, iptables_path, argv)
438
# Determine network configuration: an IP version is considered available if
# it has a default route. A missing or unreadable /proc routing table
# (presumably a kernel built without that IP version) counts as "no default
# route" rather than making this module fail to import.
try:
  GetDefaultRoute(version=4)
  HAVE_IPV4 = True
except (ValueError, EnvironmentError):
  HAVE_IPV4 = False

try:
  GetDefaultRoute(version=6)
  HAVE_IPV6 = True
except (ValueError, EnvironmentError):
  HAVE_IPV6 = False
451
class RunAsUidGid(object):
  """Context guard to run a code block as a given UID and/or GID.

  A falsy uid or gid (e.g. 0) leaves that identity unchanged. When the UID is
  switched, AID_INET is added to the supplementary groups. Everything changed
  on entry is restored on exit.
  """

  def __init__(self, uid, gid):
    self.uid = uid
    self.gid = gid

  def __enter__(self):
    # The GID is switched before the UID.
    if self.gid:
      self.saved_gid = os.getgid()
      os.setgid(self.gid)
    if not self.uid:
      return
    self.saved_uids = os.getresuid()
    self.saved_groups = os.getgroups()
    # AID_INET lets the non-root UID open sockets.
    os.setgroups(self.saved_groups + [AID_INET])
    original_ruid = self.saved_uids[0]
    # The original real UID becomes the saved set-user-ID, so that __exit__
    # can switch back.
    os.setresuid(self.uid, self.uid, original_ruid)

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Undo in reverse: UIDs and groups first, then the GID.
    if self.uid:
      os.setresuid(*self.saved_uids)
      os.setgroups(self.saved_groups)
    if self.gid:
      os.setgid(self.saved_gid)
475
class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID, keeping the GID."""

  def __init__(self, uid):
    # gid=0 is falsy, so RunAsUidGid leaves the group ID untouched.
    super(RunAsUid, self).__init__(uid, 0)
481
class NetworkTest(unittest.TestCase):
  """Base TestCase with errno assertions and /proc/net parsing helpers."""

  @contextlib.contextmanager
  def _errnoCheck(self, err_num):
    """Asserts the body raises EnvironmentError whose errno is err_num."""
    with self.assertRaises(EnvironmentError) as raised:
      yield raised
    self.assertEqual(raised.exception.errno, err_num)

  def assertRaisesErrno(self, err_num, f=None, *args):
    """Test that the system returns an errno error.

    This works similarly to unittest.TestCase.assertRaises. You can call it as
    an assertion, or use it as a context manager.
    e.g.
        self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    or
        with self.assertRaisesErrno(errno.ENOENT):
          do_things(arg1, arg2)

    Args:
      err_num: an errno constant
      f: (optional) A callable that should result in error
      *args: arguments passed to f
    """
    checker = self._errnoCheck(err_num)
    if f is None:
      return checker
    with checker:
      f(*args)

  def ReadProcNetSocket(self, protocol):
    """Parses the socket table in /proc/net/<protocol>.

    Args:
      protocol: a file name under /proc/net, e.g. "tcp", "udp6" or "raw6".

    Returns:
      One list per socket line: [src, dst, state, mem, uid, refcnt, extra].
      src/dst are "ADDR:PORT" hex strings and mem is "tx_queue:rx_queue".
      For icmp/udp/raw, extra is the drops count. For tcp, extra is the five
      trailing numbers as one string ("" for timewait sockets).

    Raises:
      ValueError: if protocol is not tcp*/icmp*/udp*/raw*, or a line does not
        match the expected format.
    """
    filename = "/proc/net/%s" % protocol
    with open(filename) as proc_file:
      lines = proc_file.readlines()

    # Possibly check, and strip, header.
    if protocol in ["icmp6", "raw6", "udp6"]:
      self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    entries = lines[1:]

    # Hex digits per address: 32 for IPv6 tables, 8 for IPv4 ones.
    addrlen = 32 if protocol.endswith("6") else 8

    if protocol.startswith("tcp"):
      # Real sockets have 5 extra numbers, timewait sockets have none.
      tail = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
    elif re.match("icmp|udp|raw", protocol):
      # Drops.
      tail = " +([0-9]+) *$"
    else:
      raise ValueError("Don't know how to parse %s" % filename)

    pattern = "".join((
        r" *(\d+): ",                   # bucket
        "([0-9A-F]{%d}:[0-9A-F]{4}) ",  # srcaddr, port
        "([0-9A-F]{%d}:[0-9A-F]{4}) ",  # dstaddr, port
        "([0-9A-F][0-9A-F]) ",          # state
        "([0-9A-F]{8}:[0-9A-F]{8}) ",   # mem (tx_queue:rx_queue)
        "([0-9A-F]{2}:[0-9A-F]{8}) ",   # tr:tm->when
        "([0-9A-F]{8}) +",              # retrnsmt
        "([0-9]+) +",                   # uid
        "([0-9]+) +",                   # timeout
        "([0-9]+) +",                   # inode
        "([0-9]+) +",                   # refcnt
        "([0-9a-f]+)",                  # sp
        "%s",                           # protocol-specific tail
    )) % (addrlen, addrlen, tail)
    regexp = re.compile(pattern)

    # Keep src, dst, state, mem, uid, refcnt and the tail of each line.
    # TODO: consider returning a dict or namedtuple instead.
    wanted = (1, 2, 3, 4, 7, 10, 12)
    parsed = []
    for line in entries:
      match = regexp.match(line)
      if match is None:
        raise ValueError("Failed match on [%s]" % line)
      fields = match.groups()
      parsed.append([fields[i] for i in wanted])
    return parsed

  @staticmethod
  def GetConsoleLogLevel():
    """Returns the first (console) level in /proc/sys/kernel/printk."""
    with open("/proc/sys/kernel/printk") as printk:
      first_line = printk.readline()
    return int(first_line.split()[0])

  @staticmethod
  def SetConsoleLogLevel(level):
    """Writes level to /proc/sys/kernel/printk; returns the bytes written."""
    line = "%s\n" % level
    with open("/proc/sys/kernel/printk", "w") as printk:
      return printk.write(line)
573
574
if __name__ == "__main__":
  # Running this file directly hands control to unittest's test runner.
  unittest.main()
577