• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import contextlib
18import fcntl
19import os
20import random
21import re
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import sys
25import unittest
26
27from scapy import all as scapy
28
29import binascii
30import csocket
31
# TODO: Move these to csocket.py.

# IPv6 socket option level, plus IP/IPv6-level socket option numbers.
SOL_IPV6 = 41
IP_RECVERR = 11
IPV6_RECVERR = 25
IP_TRANSPARENT = 19
IPV6_TRANSPARENT = 75
IPV6_TCLASS = 67
IPV6_FLOWLABEL_MGR = 32
IPV6_FLOWINFO_SEND = 33

# SOL_SOCKET-level socket option numbers.
SO_BINDTODEVICE = 25
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
SO_COOKIE = 57

# Ethernet protocol (ethertype) numbers.
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd

# IP protocol number for GRE.
IPPROTO_GRE = 47

# ioctl request number used by SetInterfaceHWAddr.
SIOCSIFHWADDR = 0x8924
54
# Actions for the flr_action field of struct in6_flowlabel_req.
IPV6_FL_A_GET = 0
IPV6_FL_A_PUT = 1
# Must be distinct from IPV6_FL_A_PUT; linux/in6.h defines RENEW as 2.
IPV6_FL_A_RENEW = 2

# Flags for the flr_flags field of struct in6_flowlabel_req.
IPV6_FL_F_CREATE = 1
IPV6_FL_F_EXCL = 2

# Sharing modes for the flr_share field of struct in6_flowlabel_req.
IPV6_FL_S_NONE = 0
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255

# Width of the interface-name field packed into ifreq-style ioctl buffers.
IFNAMSIZ = 16
67
# ICMP / ICMPv6 echo request headers: type (8 / 128), code 0, zero checksum,
# identifier 0x0ace, sequence number 3.
IPV4_PING = bytes.fromhex("08 00 00 00 0a ce 00 03")
IPV6_PING = bytes.fromhex("80 00 00 00 0a ce 00 03")

# Well-known public resolver addresses used as remote endpoints.
IPV4_ADDR, IPV4_ADDR2 = "8.8.8.8", "8.8.4.4"
IPV6_ADDR, IPV6_ADDR2 = "2001:4860:4860::8888", "2001:4860:4860::8844"

# Header line of the IPv6 datagram /proc/net files (icmp6, raw6, udp6). Each
# address column is 38 characters wide: 32 hex digits, ':', a 4-digit port
# and a separating space.
IPV6_SEQ_DGRAM_HEADER = (
    "  sl  "
    + "local_address".ljust(38)
    + "remote_address".ljust(38)
    + "st tx_queue rx_queue tr tm->when retrnsmt"
    + "   uid  timeout inode ref pointer drops\n"
)

UDP_HDR_LEN = 8
83
# Arbitrary packet payload: the serialized bytes of a recursive DNS AAAA query
# with a random 16-bit transaction ID.
UDP_PAYLOAD = bytes(
    scapy.DNS(
        rd=1,
        id=random.randint(0, 65535),
        qd=scapy.DNSQR(qname="wWW.GoOGle.CoM", qtype="AAAA"),
    )
)

# Unix group to use if we want to open sockets as non-root.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# Version of the running kernel, as reported by csocket.LinuxVersion().
LINUX_VERSION = csocket.LinuxVersion()
# Lowest possible version; presumably used as a "matches any kernel" bound.
LINUX_ANY_VERSION = (0, 0)
98
def KernelAtLeast(versions, kernel_version=None):
  """Checks the kernel version matches the specified versions.

  Args:
    versions: a list of versions expressed as tuples,
      e.g., [(5, 10, 108), (5, 15, 31)]. Each version covers the range from
      itself up to (but not including) the next-higher specified minor
      version with its last component set to 0; the highest specified
      version has no upper bound. In this example, the kernel version must
      match either:
        >= 5.10.108 and < 5.15.0
        >= 5.15.31
      While this is less flexible than matching exact tuples, it allows the
      caller to pass in fewer arguments, because Android only supports certain
      minor versions (4.19, 5.4, 5.10, ...)
    kernel_version: optional version tuple to check instead of the running
      kernel's version (LINUX_VERSION). Mainly useful for testing.

  Returns:
    True if the kernel version matches, False otherwise.

  Raises:
    ValueError: if two entries in versions share the same (major, minor).
  """
  if kernel_version is None:
    kernel_version = LINUX_VERSION
  maxversion = (1000, 255, 65535)
  for version in sorted(versions, reverse=True):
    if version[:2] == maxversion[:2]:
      # Interpolate the versions into the message itself, so the exception
      # text is readable instead of carrying an unformatted template.
      raise ValueError("Duplicate minor version: %s %s" % (version, maxversion))
    if version <= kernel_version < maxversion:
      return True
    # Lower entries only cover kernels older than this entry's minor version.
    maxversion = (version[0], version[1], 0)
  return False
124
def ByteToHex(b):
  """Formats a single byte as two lowercase hex digits.

  Accepts either a one-character str (formatted by its code point) or an int.
  """
  if isinstance(b, str):
    b = ord(b)
  return "%02x" % b
127
def GetWildcardAddress(version):
  """Returns the wildcard ("any") bind address for an IP version.

  Args:
    version: 4 for IPv4, 6 for IPv6, or 5 for IPv4-mapped IPv6. Version 5
      sockets are AF_INET6 (see GetAddressFamily), so they use the IPv6
      wildcard, matching BindRandomPort.

  Raises:
    KeyError: for any other version.
  """
  return {4: "0.0.0.0", 5: "::", 6: "::"}[version]
130
def GetIpHdrLength(version):
  """Returns the base IP header length in bytes for version 4 or 6.

  Raises KeyError for any other version.
  """
  header_bytes = {4: 20, 6: 40}
  return header_bytes[version]
133
def GetAddressFamily(version):
  """Maps an IP version to its socket address family.

  Versions 5 (IPv4-mapped IPv6) and 6 both use AF_INET6; 4 uses AF_INET.
  Raises KeyError for any other version.
  """
  families = {
      4: AF_INET,
      5: AF_INET6,
      6: AF_INET6,
  }
  return families[version]
136
137
def AddressLengthBits(version):
  """Returns the address width in bits (32 for IPv4, 128 for IPv6).

  Raises KeyError for any other version.
  """
  bits_by_version = {4: 32, 6: 128}
  return bits_by_version[version]
140
def GetAddressVersion(address):
  """Classifies an address string by its text form.

  Returns 4 for strings without a colon, 5 for strings starting with
  "::ffff" (IPv4-mapped IPv6), and 6 for any other string with a colon.
  """
  if ":" in address:
    return 5 if address.startswith("::ffff") else 6
  return 4
147
def SetSocketTos(s, tos):
  """Sets the IPv4 TOS byte or IPv6 traffic class of socket s to tos.

  Raises KeyError if s is neither AF_INET nor AF_INET6.
  """
  level, option = {
      AF_INET: (SOL_IP, IP_TOS),
      AF_INET6: (SOL_IPV6, IPV6_TCLASS),
  }[s.family]
  s.setsockopt(level, option, tos)
152
153
def SetNonBlocking(fd):
  """Adds O_NONBLOCK to fd's file status flags, keeping the other flags."""
  current_flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
  new_flags = current_flags | os.O_NONBLOCK
  fcntl.fcntl(fd, fcntl.F_SETFL, new_flags)
157
158
# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
  """Creates a socket and applies csocket.SetSocketTimeout(s, 5000) to it.

  The timeout value 5000 is presumably in milliseconds (5 s); confirm
  against csocket.SetSocketTimeout.

  Args:
    family: address family, e.g. AF_INET or AF_INET6.
    sock_type: socket type, e.g. SOCK_STREAM or SOCK_DGRAM.
    protocol: protocol number, e.g. IPPROTO_TCP.

  Returns:
    The new socket.

  Raises:
    Whatever socket() or csocket.SetSocketTimeout raise. In the latter case
    the socket is closed first so its file descriptor is not leaked.
  """
  s = socket(family, sock_type, protocol)
  try:
    csocket.SetSocketTimeout(s, 5000)
  except BaseException:
    s.close()
    raise
  return s
164
165
def PingSocket(family):
  """Creates an ICMP (AF_INET) or ICMPv6 (AF_INET6) datagram "ping" socket.

  Raises KeyError for any other family.
  """
  icmp_protocols = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}
  return Socket(family, SOCK_DGRAM, icmp_protocols[family])
169
170
def IPv4PingSocket():
  """Creates an IPv4 ICMP datagram (ping) socket."""
  return PingSocket(family=AF_INET)
173
174
def IPv6PingSocket():
  """Creates an IPv6 ICMPv6 datagram (ping) socket."""
  return PingSocket(family=AF_INET6)
177
178
def TCPSocket(family):
  """Creates a TCP socket of the given family in non-blocking mode."""
  tcp_sock = Socket(family, SOCK_STREAM, IPPROTO_TCP)
  SetNonBlocking(tcp_sock.fileno())
  return tcp_sock
183
184
def IPv4TCPSocket():
  """Creates a non-blocking IPv4 TCP socket."""
  return TCPSocket(family=AF_INET)
187
188
def IPv6TCPSocket():
  """Creates a non-blocking IPv6 TCP socket."""
  return TCPSocket(family=AF_INET6)
191
192
def UDPSocket(family):
  """Creates a UDP datagram socket of the given family."""
  return Socket(family, sock_type=SOCK_DGRAM, protocol=IPPROTO_UDP)
195
196
def RawGRESocket(family):
  """Creates a raw socket of the given family for protocol IPPROTO_GRE."""
  return Socket(family, SOCK_RAW, IPPROTO_GRE)
200
201
def BindRandomPort(version, sock):
  """Binds sock to the wildcard address on a kernel-chosen port.

  SO_REUSEADDR is enabled first, and TCP sockets are additionally put into
  the listening state with a backlog of 100.

  Args:
    version: 4, 5 (IPv4-mapped IPv6) or 6; anything else raises KeyError.
    sock: the socket to bind.

  Returns:
    The port the socket ended up bound to.
  """
  wildcard = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
  sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  sock.bind((wildcard, 0))
  is_tcp = sock.getsockopt(SOL_SOCKET, SO_PROTOCOL) == IPPROTO_TCP
  if is_tcp:
    sock.listen(100)
  _, bound_port = sock.getsockname()[:2]
  return bound_port
210
211
def EnableFinWait(sock):
  """Makes close() on sock perform a normal shutdown.

  Disabling SO_LINGER (l_onoff = 0, l_linger = 0) causes sockets to go into
  FIN_WAIT on close().
  """
  linger_disabled = struct.pack("ii", 0, 0)
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_disabled)
215
216
def DisableFinWait(sock):
  """Makes close() on sock abort the connection.

  Enabling SO_LINGER (l_onoff = 1) with a timeout of zero (l_linger = 0)
  causes close() to send RST.
  """
  linger_zero = struct.pack("ii", 1, 0)
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_zero)
220
221
def CreateSocketPair(family, socktype, addr):
  """Creates two sockets connected to each other over local address addr.

  For SOCK_STREAM:
    - A listener is bound to an ephemeral port on addr.
    - The client connects and the connection is accepted.
    - The listener is closed.
    - Both returned sockets get DisableFinWait (close() sends RST).

  For other socket types (e.g. SOCK_DGRAM), the two sockets are simply
  connect()ed to each other.

  Args:
    family: address family, e.g. AF_INET or AF_INET6.
    socktype: SOCK_STREAM or a connectionless type such as SOCK_DGRAM.
    addr: local address to bind to, e.g. "127.0.0.1" or "::1".

  Returns:
    A (clientsock, acceptedsock) tuple.

  Raises:
    OSError from the underlying socket calls. Any sockets created before the
    failure are closed rather than leaked.
  """
  clientsock = socket(family, socktype, 0)
  listensock = acceptedsock = None
  try:
    listensock = socket(family, socktype, 0)
    listensock.bind((addr, 0))
    listenaddr = listensock.getsockname()
    if socktype == SOCK_STREAM:
      listensock.listen(1)
    clientsock.connect(listenaddr)
    if socktype == SOCK_STREAM:
      acceptedsock, _ = listensock.accept()
      DisableFinWait(clientsock)
      DisableFinWait(acceptedsock)
      listensock.close()
    else:
      listensock.connect(clientsock.getsockname())
      acceptedsock = listensock
  except BaseException:
    # close() is idempotent, so closing a socket twice here is harmless.
    for s in (clientsock, listensock, acceptedsock):
      if s is not None:
        s.close()
    raise
  return clientsock, acceptedsock
239
240
def GetInterfaceIndex(ifname):
  """Returns the kernel interface index of ifname, via SIOCGIFINDEX."""
  ifreq_format = "%dsi" % IFNAMSIZ
  with UDPSocket(AF_INET) as s:
    request = struct.pack(ifreq_format, ifname.encode(), 0)
    reply = fcntl.ioctl(s, scapy.SIOCGIFINDEX, request)
    _, index = struct.unpack(ifreq_format, reply)
    return index
246
247
def SetInterfaceHWAddr(ifname, hwaddr):
  """Sets the Ethernet hardware address of ifname with SIOCSIFHWADDR.

  Args:
    ifname: interface name.
    hwaddr: hex string, optionally colon-separated, e.g. "02:00:00:00:00:01".

  Raises:
    ValueError: if hwaddr does not decode to exactly 6 bytes.
  """
  with UDPSocket(AF_INET) as s:
    raw_mac = binascii.unhexlify(hwaddr.replace(":", ""))
    if len(raw_mac) != 6:
      raise ValueError("Unknown hardware address length %d" % len(raw_mac))
    ifreq_format = "%dsH6s" % IFNAMSIZ
    request = struct.pack(ifreq_format, ifname.encode(), scapy.ARPHDR_ETHER,
                          raw_mac)
    fcntl.ioctl(s, SIOCSIFHWADDR, request)
257
258
def SetInterfaceState(ifname, up):
  """Sets (up truthy) or clears (up falsy) the IFF_UP flag on ifname.

  Reads the current flags with SIOCGIFFLAGS and writes them back with
  SIOCSIFFLAGS, changing only IFF_UP.
  """
  name = ifname.encode()
  ifreq_format = "%dsH" % IFNAMSIZ
  with UDPSocket(AF_INET) as s:
    request = struct.pack(ifreq_format, name, 0)
    reply = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, request)
    _, flags = struct.unpack(ifreq_format, reply)
    if up:
      new_flags = flags | scapy.IFF_UP
    else:
      new_flags = flags & ~scapy.IFF_UP
    fcntl.ioctl(s, scapy.SIOCSIFFLAGS,
                struct.pack(ifreq_format, name, new_flags))
271
272
def SetInterfaceUp(ifname):
  """Sets IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=True)
275
276
def SetInterfaceDown(ifname):
  """Clears IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=False)
279
280
def CanonicalizeIPv6Address(addr):
  """Returns addr in the compressed IPv6 text form produced by inet_ntop."""
  packed = inet_pton(AF_INET6, addr)
  return inet_ntop(AF_INET6, packed)
283
284
def FormatProcAddress(unformatted):
  """Converts an unseparated hex IPv6 address (as found in /proc) to text.

  The hex string is split into 4-digit groups, joined with colons and then
  compressed via CanonicalizeIPv6Address.
  """
  groups = [unformatted[start:start + 4]
            for start in range(0, len(unformatted), 4)]
  return CanonicalizeIPv6Address(":".join(groups))
293
294
def FormatSockStatAddress(address):
  """Formats an IP address the way the regexps in ReadProcNetSocket expect.

  Addresses containing ":" are treated as IPv6, others as IPv4. Each 32-bit
  word of the packed address is read in native byte order ("=L") and
  printed as 8 uppercase hex digits.
  """
  family = AF_INET6 if ":" in address else AF_INET
  packed = inet_pton(family, address)
  words = (struct.unpack("=L", packed[offset:offset + 4])
           for offset in range(0, len(packed), 4))
  return "".join("%08X" % word for word in words)
305
306
def GetLinkAddress(ifname, linklocal):
  """Returns an IPv6 address of ifname as listed in /proc/net/if_inet6.

  Args:
    ifname: interface name, matched against the sixth column.
    linklocal: if truthy, return the first address starting with "fe80";
      otherwise the first address that does not.

  Returns:
    The address in compressed text form, or None if nothing matches.
  """
  with open("/proc/net/if_inet6") as if_inet6:
    entries = if_inet6.readlines()
  want_linklocal = bool(linklocal)
  for entry in entries:
    fields = [field for field in entry.strip().split(" ") if field]
    if fields[5] != ifname:
      continue
    if fields[0].startswith("fe80") == want_linklocal:
      # Convert the address from raw hex to something with colons in it.
      return FormatProcAddress(fields[0])
  return None
318
319
def GetDefaultRoute(version=6):
  """Finds the default route for an IP version by parsing /proc.

  Args:
    version: 4 or 6 (default 6).

  Returns:
    A (gateway, interface) tuple: the gateway address as text and the name
    of the interface the route uses.

  Raises:
    ValueError: if no default route is found, or version is not 4 or 6.
    OSError: if the /proc route file cannot be opened.
  """
  if version == 6:
    with open("/proc/net/ipv6_route") as ipv6_route:
      routes = ipv6_route.readlines()
    for route in routes:
      route = [s for s in route.strip().split(" ") if s]
      # An all-zero destination with prefix length "00" is ::/0. route[4] is
      # returned as the gateway and route[9] as the device.
      if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
          # Routes in non-default tables end up in /proc/net/ipv6_route!!!
          and route[9] != "lo" and not route[9].startswith("nettest")):
        return FormatProcAddress(route[4]), route[9]
    raise ValueError("No IPv6 default route found")
  elif version == 4:
    with open("/proc/net/route") as ipv4_route:
      routes = ipv4_route.readlines()
    for route in routes:
      route = [s for s in route.strip().split("\t") if s]
      # All-zero destination (route[1]) and mask (route[7]) select 0.0.0.0/0.
      if route[1] == "00000000" and route[7] == "00000000":
        gw, iface = route[2], route[0]
        # NOTE: the kernel prints this column as the address's 32-bit value
        # in host byte order, which is why a plain byte reversal only works
        # on little-endian hosts. Re-packing with native byte order ("=L")
        # decodes it correctly on any host.
        gw = inet_ntop(AF_INET, struct.pack("=L", int(gw, 16)))
        return gw, iface
    raise ValueError("No IPv4 default route found")
  else:
    raise ValueError("Don't know about IPv%s" % version)
343
344
def GetDefaultRouteInterface():
  """Returns the name of the interface used by the IPv6 default route."""
  return GetDefaultRoute()[1]
348
349
def MakeFlowLabelOption(addr, label):
  """Builds a packed struct in6_flowlabel_req for IPV6_FLOWLABEL_MGR.

  The request uses action IPV6_FL_A_GET, share mode IPV6_FL_S_ANY and flag
  IPV6_FL_F_CREATE, with zero expiry and linger times.

  Layout being packed:
    struct in6_flowlabel_req {
            struct in6_addr flr_dst;
            __be32  flr_label;
            __u8    flr_action;
            __u8    flr_share;
            __u16   flr_flags;
            __u16   flr_expires;
            __u16   flr_linger;
            __u32   __flr_pad;
            /* Options in format of IPV6_PKTOPTIONS */
    };

  Args:
    addr: IPv6 destination address as text.
    label: flow label; only its low 20 bits are used.

  Returns:
    The 32-byte packed request.
  """
  fmt = "16sIBBHHH4s"
  assert struct.calcsize(fmt) == 32
  dst = inet_pton(AF_INET6, addr)
  assert len(dst) == 16
  return struct.pack(
      fmt,
      dst,                       # flr_dst
      htonl(label & 0xfffff),    # flr_label: low 20 bits, network byte order
      IPV6_FL_A_GET,             # flr_action
      IPV6_FL_S_ANY,             # flr_share
      IPV6_FL_F_CREATE,          # flr_flags
      0,                         # flr_expires
      0,                         # flr_linger
      b"\x00" * 4)               # __flr_pad
372
373
def SetFlowLabel(s, addr, label):
  """Submits a flow label request for addr/label on socket s.

  Uses setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR) with the request built by
  MakeFlowLabelOption. The caller also needs to do
  s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
  """
  s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, MakeFlowLabelOption(addr, label))
378
379
def GetIptablesBinaryPath(version):
  """Returns the path of an executable iptables binary for an IP version.

  Candidates are tried in order: the -legacy variant, then the plain binary,
  first under /sbin and then under /system/bin.

  Args:
    version: 4 for iptables or 6 for ip6tables.

  Returns:
    The first candidate path that os.access() reports as executable.

  Raises:
    ValueError: if version is not 4 or 6.
    FileNotFoundError: if none of the candidates is executable.
  """
  if version == 4:
    paths = (
        "/sbin/iptables-legacy",
        "/sbin/iptables",
        "/system/bin/iptables-legacy",
        "/system/bin/iptables",
    )
  elif version == 6:
    paths = (
        "/sbin/ip6tables-legacy",
        "/sbin/ip6tables",
        "/system/bin/ip6tables-legacy",
        "/system/bin/ip6tables",
    )
  else:
    # Fail clearly instead of hitting an unbound `paths` below.
    raise ValueError("Unsupported IP version: %s" % version)
  for iptables_path in paths:
    if os.access(iptables_path, os.X_OK):
      return iptables_path
  raise FileNotFoundError(
      "iptables binary for IPv{} not found".format(version) +
      ", checked: {}".format(", ".join(paths)))
401
402
def RunIptablesCommand(version, args):
  """Runs iptables (version 4) or ip6tables (version 6) and waits for it.

  Args:
    version: 4 or 6; the binary is chosen by GetIptablesBinaryPath.
    args: either one whitespace-separated string such as "-t mangle -F", or
      a sequence of individual arguments. Quoting inside the string is not
      interpreted.

  Returns:
    The result of os.spawnvp(os.P_WAIT, ...), i.e. the exit status.
  """
  iptables_path = GetIptablesBinaryPath(version)
  if isinstance(args, str):
    # split() with no separator ignores repeated, leading and trailing
    # whitespace; split(" ") would pass those through as empty arguments.
    argv = args.split()
  else:
    argv = list(args)
  return os.spawnvp(os.P_WAIT, iptables_path, [iptables_path] + argv)
406
# Determine network configuration.
# GetDefaultRoute raises ValueError when no default route exists, and OSError
# when the /proc route file cannot be opened (presumably when that IP version
# is unavailable in the kernel). Either way the protocol is treated as absent,
# instead of making this module fail to import.
try:
  GetDefaultRoute(version=4)
  HAVE_IPV4 = True
except (ValueError, OSError):
  HAVE_IPV4 = False

try:
  GetDefaultRoute(version=6)
  HAVE_IPV6 = True
except (ValueError, OSError):
  HAVE_IPV6 = False
419
class RunAsUidGid(object):
  """Context guard to run a code block as a given UID and GID.

  A falsy uid or gid (including 0) leaves that ID unchanged. When switching
  UID:
    - AID_INET is appended to the supplementary groups.
    - The original real UID is kept as the saved set-user-ID, so that
      __exit__ can restore the saved real/effective/saved UIDs.
  """

  def __init__(self, uid, gid):
    self.uid = uid
    self.gid = gid

  def __enter__(self):
    if self.gid:
      self.saved_gid = os.getgid()
      os.setgid(self.gid)
    try:
      if self.uid:
        self.saved_uids = os.getresuid()
        self.saved_groups = os.getgroups()
        os.setgroups(self.saved_groups + [AID_INET])
        try:
          os.setresuid(self.uid, self.uid, self.saved_uids[0])
        except BaseException:
          os.setgroups(self.saved_groups)
          raise
    except BaseException:
      # __exit__ is not called when __enter__ raises, so undo any partial
      # credential change here before propagating the error.
      if self.gid:
        os.setgid(self.saved_gid)
      raise

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Restore the UID first: changing groups and GID needs the original
    # (privileged) credentials back.
    if self.uid:
      os.setresuid(*self.saved_uids)
      os.setgroups(self.saved_groups)
    if self.gid:
      os.setgid(self.saved_gid)
443
class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID.

  Equivalent to RunAsUidGid(uid, 0), i.e. the GID is not changed.
  """

  def __init__(self, uid):
    super().__init__(uid, 0)
449
class NetworkTest(unittest.TestCase):
  """Base test case with errno assertions and /proc/net parsing helpers."""

  @contextlib.contextmanager
  def _errnoCheck(self, err_num):
    """Asserts that the managed block raises EnvironmentError with err_num."""
    with self.assertRaises(EnvironmentError) as ctx:
      yield ctx
    self.assertEqual(ctx.exception.errno, err_num)

  def assertRaisesErrno(self, err_num, f=None, *args):
    """Test that the system returns an errno error.

    Works like unittest.TestCase.assertRaises: either call it as an
    assertion, or use it as a context manager, e.g.
        self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    or
        with self.assertRaisesErrno(errno.ENOENT):
          do_things(arg1, arg2)

    Args:
      err_num: an errno constant
      f: (optional) A callable that should result in error
      *args: arguments passed to f

    Returns:
      A context manager when f is None, otherwise None.
    """
    if f is not None:
      with self._errnoCheck(err_num):
        f(*args)
      return None
    return self._errnoCheck(err_num)

  def ReadProcNetSocket(self, protocol):
    """Parses /proc/net/<protocol> into one field list per socket line.

    For icmp6, raw6 and udp6 the header line is first checked against
    IPV6_SEQ_DGRAM_HEADER; the first line is dropped for every protocol.

    Args:
      protocol: file name under /proc/net, e.g. "tcp", "udp6" or "raw".

    Returns:
      A list of [src, dst, state, mem, uid, refcnt, extra] string lists.

    Raises:
      ValueError: if the protocol is not tcp*/icmp*/udp*/raw*, or a line
        does not match the expected format.
    """
    filename = "/proc/net/%s" % protocol
    with open(filename) as procfile:
      lines = procfile.readlines()

    if protocol in ("icmp6", "raw6", "udp6"):
      self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    entries = lines[1:]

    # Hex digits in an address: 32 for IPv6 files, 8 for IPv4 ones.
    addrlen = 32 if protocol.endswith("6") else 8

    if protocol.startswith("tcp"):
      # Real sockets have 5 extra numbers, timewait sockets have none.
      end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
    elif re.match("icmp|udp|raw", protocol):
      # Drops.
      end_regexp = " +([0-9]+) *$"
    else:
      raise ValueError("Don't know how to parse %s" % filename)

    pattern = "".join((
        r" *(\d+): ",                    # bucket
        "([0-9A-F]{%d}:[0-9A-F]{4}) ",   # srcaddr, port
        "([0-9A-F]{%d}:[0-9A-F]{4}) ",   # dstaddr, port
        "([0-9A-F][0-9A-F]) ",           # state
        "([0-9A-F]{8}:[0-9A-F]{8}) ",    # mem (tx_queue:rx_queue)
        "([0-9A-F]{2}:[0-9A-F]{8}) ",    # ? (tr:tm->when)
        "([0-9A-F]{8}) +",               # ? (retrnsmt)
        "([0-9]+) +",                    # uid
        "([0-9]+) +",                    # timeout
        "([0-9]+) +",                    # inode
        "([0-9]+) +",                    # refcnt
        "([0-9a-f]+)",                   # sp
        "%s",                            # protocol-specific tail
    )) % (addrlen, addrlen, end_regexp)
    regexp = re.compile(pattern)

    # Group indices kept: src, dst, state, mem, uid, refcnt, trailing group.
    # TODO: consider returning a dict or namedtuple instead.
    wanted = (1, 2, 3, 4, 7, 10, 12)
    rows = []
    for line in entries:
      match = regexp.match(line)
      if not match:
        raise ValueError("Failed match on [%s]" % line)
      groups = match.groups()
      rows.append([groups[i] for i in wanted])
    return rows

  @staticmethod
  def GetConsoleLogLevel():
    """Returns the first field of /proc/sys/kernel/printk as an int."""
    with open("/proc/sys/kernel/printk") as printk:
      first_line = printk.readline()
      return int(first_line.split()[0])

  @staticmethod
  def SetConsoleLogLevel(level):
    """Writes level to /proc/sys/kernel/printk; returns characters written."""
    with open("/proc/sys/kernel/printk", "w") as printk:
      written = printk.write("%s\n" % level)
      return written
541
542
if __name__ == "__main__":
  # When executed directly, run any unittest.TestCase tests defined in this
  # module (NetworkTest itself defines no test_* methods).
  unittest.main()
545