• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/python
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
17import cstruct
18import ctypes
19import errno
20import os
21import random
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import time           # pylint: disable=unused-import
25import unittest
26
27from scapy import all as scapy
28
29import csocket
30import iproute
31import multinetwork_base
32import net_test
33import netlink
34import packets
35
# For brevity.
UDP_PAYLOAD = net_test.UDP_PAYLOAD

# Ancillary-data (cmsg) type used to attach an IPv6 flow label to sendmsg().
# The value matches Linux's IPV6_FLOWINFO in <linux/in6.h>.
IPV6_FLOWINFO = 11

# NOTE(review): SYNCOOKIES_SYSCTL is not referenced in the visible part of
# this file; confirm whether it is meant to be applied by the TCP accept tests.
SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
# Controls whether accepted TCP sockets take the mark of the incoming SYN.
TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"

# The IP[V6]UNICAST_IF socket option was added between 3.1 and 3.4.
HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)

# RA-installed routes carry RTPROT_RA (rather than RTPROT_BOOT) from 4.14 on.
HAVE_RTPROT_RA = net_test.LINUX_VERSION >= (4, 14, 0)
49
class ConfigurationError(AssertionError):
  """AssertionError subclass for problems with the test setup itself.

  Judging by the name, it flags a misconfigured test environment rather than
  a failure of the code under test. No raise site is visible in this part of
  the file.
  """
52
53
class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
  """Checks that locally generated packets leave on the selected network.

  Each Check* helper builds a socket that selects a network (netid) using one
  of the routing modes "mark", "uid", "oif" or "ucast_oif", sends a packet,
  and asserts via ExpectPacketOn that it appears on that network's interface.
  Every helper closes its socket before returning, even when a check fails.
  """

  # How many times to run outgoing packet tests.
  ITERATIONS = 5

  def CheckPingPacket(self, version, netid, routing_mode, packet):
    """Sends a ping that selects netid and expects the echo on its interface.

    Args:
      version: IP version, passed through to BuildSocket and packets.
      netid: The network to select.
      routing_mode: How the socket selects the network.
      packet: Bytes sent ahead of packets.PING_PAYLOAD (callers pass
        self.IPV4_PING or self.IPV6_PING).
    """
    s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      mysockaddr = self.MySocketAddress(version, netid)
      s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
      s.bind((mysockaddr, packets.PING_IDENT))
      net_test.SetSocketTos(s, packets.PING_TOS)

      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      desc, expected = packets.ICMPEcho(version, myaddr, dstaddr)
      msg = "IPv%d ping: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))

      s.sendto(packet + packets.PING_PAYLOAD, (dstsockaddr, 19321))

      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckTCPSYNPacket(self, version, netid, routing_mode):
    """Starts a TCP connect that selects netid; expects the SYN on its iface.

    Args:
      version: IP version, passed through to BuildSocket and packets.
      netid: The network to select.
      routing_mode: How the socket selects the network.
    """
    s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      # sport and seq are chosen by the kernel, so they are passed as None;
      # presumably the packet matcher treats None as "don't care".
      desc, expected = packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)

      # Non-blocking TCP connects always return EINPROGRESS.
      self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstsockaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckUDPPacket(self, version, netid, routing_mode):
    """Sends UDP that selects netid and expects it on that network's iface.

    Checks sendto() on an unconnected socket. Except in ucast_oif mode, it
    also checks connect() followed by send().

    Args:
      version: IP version, passed through to BuildSocket and packets.
      netid: The network to select.
      routing_mode: How the socket selects the network.
    """
    s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)

      desc, expected = packets.UDP(version, myaddr, dstaddr, sport=None)
      # The %%s is filled in later with the send method being checked.
      msg = "IPv%s UDP %%s: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))

      s.sendto(UDP_PAYLOAD, (dstsockaddr, 53))
      self.ExpectPacketOn(netid, msg % "sendto", expected)

      # IP_UNICAST_IF doesn't seem to work on connected sockets, so skip the
      # connect()/send() path in ucast_oif mode.
      if routing_mode != "ucast_oif":
        s.connect((dstsockaddr, 53))
        s.send(UDP_PAYLOAD)
        self.ExpectPacketOn(netid, msg % "connect/send", expected)
    finally:
      # Close the socket in every routing mode, including ucast_oif.
      s.close()

  def CheckRawGrePacket(self, version, netid, routing_mode):
    """Sends raw GRE that selects netid and expects it on that network's iface.

    The GRE payload is a UDP packet of the other IP version: IPv6 inside
    IPv4, and IPv4 inside IPv6.

    Args:
      version: Outer IP version, 4 or 6.
      netid: The network to select.
      routing_mode: How the socket selects the network.
    """
    s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
    try:
      inner_version = {4: 6, 6: 4}[version]
      inner_src = self.MyAddress(inner_version, netid)
      inner_dst = self.GetRemoteAddress(inner_version)
      inner = str(packets.UDP(inner_version, inner_src, inner_dst,
                              sport=None)[1])

      ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
      # A GRE header can be as simple as two zero bytes and the ethertype.
      # "!i" packs the ethertype into a big-endian 32-bit word, so for a
      # 16-bit ethertype the two leading flags/version bytes are zero.
      packet = struct.pack("!i", ethertype) + inner
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)

      s.sendto(packet, (dstaddr, IPPROTO_GRE))
      desc, expected = packets.GRE(version, myaddr, dstaddr, ethertype, inner)
      msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
          version, inner_version, desc, self.GetInterfaceName(netid))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckOutgoingPackets(self, routing_mode):
    """Runs the outgoing-packet checks on every network, ITERATIONS times.

    Args:
      routing_mode: One of "mark", "uid", "oif" or "ucast_oif".
    """
    for _ in range(self.ITERATIONS):
      for netid in self.tuns:

        self.CheckPingPacket(4, netid, routing_mode, self.IPV4_PING)
        # IPv6 ping is skipped in oif mode because of a kernel bug.
        if routing_mode != "oif":
          self.CheckPingPacket(6, netid, routing_mode, self.IPV6_PING)

        # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
        # Version 5 is presumably net_test's IPv4-mapped-IPv6 variant.
        if routing_mode != "ucast_oif":
          self.CheckTCPSYNPacket(4, netid, routing_mode)
          self.CheckTCPSYNPacket(6, netid, routing_mode)
          self.CheckTCPSYNPacket(5, netid, routing_mode)

        self.CheckUDPPacket(4, netid, routing_mode)
        self.CheckUDPPacket(6, netid, routing_mode)
        self.CheckUDPPacket(5, netid, routing_mode)

        # Creating raw sockets on non-root UIDs requires properly setting
        # capabilities, which is hard to do from Python.
        # IP_UNICAST_IF is not supported on raw sockets.
        if routing_mode not in ["uid", "ucast_oif"]:
          self.CheckRawGrePacket(4, netid, routing_mode)
          self.CheckRawGrePacket(6, netid, routing_mode)

  def testMarkRouting(self):
    """Checks that socket marking selects the right outgoing interface."""
    self.CheckOutgoingPackets("mark")

  def testUidRouting(self):
    """Checks that UID routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("uid")

  def testOifRouting(self):
    """Checks that oif routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("oif")

  @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
  def testUcastOifRouting(self):
    """Checks that ucast oif routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("ucast_oif")

  def CheckRemarking(self, version, use_connect):
    """Checks that re-selecting the network on an open UDP socket reroutes it.

    For each routing mode, one socket is moved across every network in turn.
    Each send is expected on the newly selected network's interface, after
    accounting for dst cache entries that the kernel does not invalidate.

    Args:
      version: 4 or 6.
      use_connect: If True, connect the socket (on the first netid) and use
        send(). Otherwise use sendto() on an unconnected socket.
    """
    modes = ["mark", "oif", "uid"]
    # Setting UNICAST_IF on connected sockets does not work.
    if not use_connect and HAVE_UNICAST_IF:
      modes += ["ucast_oif"]

    for mode in modes:
      s = net_test.UDPSocket(self.GetProtocolFamily(version))
      try:
        # Figure out what packets to expect.
        sport = net_test.BindRandomPort(version, s)
        dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
        unspec = {4: "0.0.0.0", 6: "::"}[version]  # Placeholder.
        desc, expected = packets.UDP(version, unspec, dstaddr, sport)

        # If we're testing connected sockets, connect the socket on the first
        # netid now.
        if use_connect:
          netid = list(self.tuns.keys())[0]
          self.SelectInterface(s, netid, mode)
          s.connect((dstaddr, 53))
          expected.src = self.MyAddress(version, netid)

        # For each netid, select that network without closing the socket, and
        # check that packets sent on that socket go out on the right network.
        #
        # For connected sockets, routing is cached in the socket's destination
        # cache entry. In this case, we check that selecting the network a
        # second time on the same socket (except via SO_BINDTODEVICE, or
        # SO_MARK on 5.0+ kernels) does not change routing, but that
        # subsequently invalidating the destination cache entry does. This is
        # a bug in the kernel because re-selecting the netid should cause
        # routing to change, and future kernels may fix this bug for per-UID
        # routing and ucast_oif routing like they already have for mark-based
        # routing. But until they do, this behaviour provides a convenient way
        # to check that InvalidateDstCache actually works.
        prevnetid = None
        for netid in self.tuns:
          self.SelectInterface(s, netid, mode)
          if not use_connect:
            expected.src = self.MyAddress(version, netid)

          def ExpectSendUsesNetid(netid):
            connected_str = "Connected" if use_connect else "Unconnected"
            msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
                connected_str, version, mode, desc,
                self.GetInterfaceName(netid))
            if use_connect:
              s.send(UDP_PAYLOAD)
            else:
              s.sendto(UDP_PAYLOAD, (dstaddr, 53))
            self.ExpectPacketOn(netid, msg, expected)

          # Does this socket have a stale dst cache entry that we need to clear?
          def SocketHasStaleDstCacheEntry():
            if not prevnetid:
              # This is the first time we're marking the socket.
              return False
            if not use_connect:
              # Non-connected sockets never have dst cache entries.
              return False
            if mode in ["uid", "ucast_oif"]:
              # No kernel invalidates the dst cache entry if the UID or the
              # UCAST_OIF socket option changes.
              return True
            if mode == "oif":
              # Changing SO_BINDTODEVICE always invalidates the dst cache entry.
              return False
            if mode == "mark":
              # Changing the mark invalidates the dst cache entry in 5.0+.
              return net_test.LINUX_VERSION < (5, 0, 0)
            raise AssertionError("%s must be one of %s" % (mode, modes))

          if SocketHasStaleDstCacheEntry():
            # The stale cached route still sends on the previous network...
            ExpectSendUsesNetid(prevnetid)
            # ... until we invalidate it.
            self.InvalidateDstCache(version, prevnetid)

          # In any case, future sends must be correct.
          ExpectSendUsesNetid(netid)

          self.SelectInterface(s, None, mode)
          prevnetid = netid
      finally:
        s.close()

  def testIPv4Remarking(self):
    """Checks that updating the mark on an IPv4 socket changes routing."""
    self.CheckRemarking(4, False)
    self.CheckRemarking(4, True)

  def testIPv6Remarking(self):
    """Checks that updating the mark on an IPv6 socket changes routing."""
    self.CheckRemarking(6, False)
    self.CheckRemarking(6, True)

  def testIPv6StickyPktinfo(self):
    """Checks that a sticky IPV6_PKTINFO option selects the outgoing iface.

    It also checks that the sticky flow label, destination options and hop
    limit are all present in the emitted packet.
    """
    for _ in range(self.ITERATIONS):
      for netid in self.tuns:
        s = net_test.UDPSocket(AF_INET6)
        try:
          # Set a flowlabel.
          net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
          s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)

          # Set some destination options.
          nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
          dstopts = "".join([
              "\x11\x02",              # Next header=UDP, 24-byte header.
              "\x01\x06", "\x00" * 6,  # PadN, 6 bytes of padding.
              "\x8b\x0c",              # ILNP nonce, 12 bytes.
              nonce
          ])
          s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
          s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)

          pktinfo = multinetwork_base.MakePktInfo(6, None,
                                                  self.ifindices[netid])

          # Set the sticky pktinfo option.
          s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)

          # Specify the flowlabel in the destination address.
          s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))

          sport = s.getsockname()[1]
          srcaddr = self.MyAddress(6, netid)
          expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
                                 fl=0xdead, hlim=255) /
                      scapy.IPv6ExtHdrDestOpt(
                          options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
                                   scapy.HBHOptUnknown(otype=0x8b,
                                                       optdata=nonce)]) /
                      scapy.UDP(sport=sport, dport=53) /
                      UDP_PAYLOAD)
          msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
              self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
        finally:
          s.close()

  def CheckPktinfoRouting(self, version):
    """Checks that SendOnNetid routes UDP to each netid with per-packet opts.

    IPv6 passes hop limit, traffic class and flow label as cmsgs. IPv4 sets
    TTL and TOS as socket options instead.

    Args:
      version: 4 or 6.
    """
    for _ in range(self.ITERATIONS):
      for netid in self.tuns:
        family = self.GetProtocolFamily(version)
        s = net_test.UDPSocket(family)
        try:
          if version == 6:
            # Create a flowlabel so we can use it.
            net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)

            # Specify some arbitrary options.
            # The flowlabel is wrapped in ctypes.c_uint (32 bits on Linux)
            # because on a 32-bit Python interpreter an integer greater than
            # 0x7fffffff (such as our chosen flowlabel after being passed
            # through htonl) is converted to long, and _MakeMsgControl
            # doesn't know what to do with longs.
            cmsgs = [
                (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
                (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
                (net_test.SOL_IPV6, IPV6_FLOWINFO,
                 ctypes.c_uint(htonl(0xbeef))),
            ]
          else:
            # Support for setting IPv4 TOS and TTL via cmsg only appeared in
            # 3.13.
            cmsgs = []
            s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
            s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)

          dstaddr = self.GetRemoteAddress(version)
          self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)

          sport = s.getsockname()[1]
          srcaddr = self.MyAddress(version, netid)

          desc, expected = packets.UDPWithOptions(version, srcaddr, dstaddr,
                                                  sport=sport)

          msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
              version, desc, self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
        finally:
          s.close()

  def testIPv4PktinfoRouting(self):
    """Checks pktinfo-based routing for IPv4 UDP."""
    self.CheckPktinfoRouting(4)

  def testIPv6PktinfoRouting(self):
    """Checks pktinfo-based routing for IPv6 UDP."""
    self.CheckPktinfoRouting(6)
358
359
class MarkTest(multinetwork_base.InboundMarkingTest):
  """Checks that kernel-generated replies leave on the incoming interface.

  ICMP errors, ping replies and RSTs are checked with mark reflection both
  enabled and disabled.
  """

  def CheckReflection(self, version, gen_packet, gen_reply):
    """Checks that replies go out on the same interface as the original.

    For each combination:
     - Calls gen_packet to generate a packet to that IP address.
     - Writes the packet generated by gen_packet on the given tun
       interface, causing the kernel to receive it.
     - Checks that the kernel's reply matches the packet generated by
       gen_reply.

    Args:
      version: An integer, 4 or 6.
      gen_packet: A function taking an IP version (an integer), a source
        address and a destination address (strings), and returning a
        (description, scapy packet) pair.
      gen_reply: A function taking the same arguments as gen_packet,
        plus a scapy packet, and returning a (description, scapy packet)
        pair.
    """
    for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
      # Generate a test packet.
      desc, packet = gen_packet(version, remoteaddr, myaddr)

      # Test with mark reflection enabled and disabled.
      for reflect in [0, 1]:
        self.SetMarkReflectSysctls(reflect)
        # HACK: IPv6 ping replies always do a routing lookup with the
        # interface the ping came in on. So even if mark reflection is not
        # working, IPv6 ping replies will be properly reflected. Don't
        # fail when that happens.
        if reflect or desc == "ICMPv6 echo":
          reply_desc, reply = gen_reply(version, myaddr, remoteaddr, packet)
        else:
          # No reply is expected on this interface.
          reply_desc, reply = None, None

        msg = self._FormatMessage(iif, ip_if, "reflect=%d" % reflect,
                                  desc, reply_desc)
        self._ReceiveAndExpectResponse(netid, packet, reply, msg)

  def SYNToClosedPort(self, *args):
    """Returns packets.SYN(...) for destination port 999.

    Port 999 is presumably not listened on during these tests, so the kernel
    should answer with an RST (TODO confirm).
    """
    return packets.SYN(999, *args)

  def testIPv4ICMPErrorsReflectMark(self):
    """ICMP port unreachables for IPv4 UDP follow the incoming interface."""
    self.CheckReflection(4, packets.UDP, packets.ICMPPortUnreachable)

  def testIPv6ICMPErrorsReflectMark(self):
    """ICMP port unreachables for IPv6 UDP follow the incoming interface."""
    self.CheckReflection(6, packets.UDP, packets.ICMPPortUnreachable)

  def testIPv4PingRepliesReflectMarkAndTos(self):
    """IPv4 ping replies follow the incoming interface."""
    self.CheckReflection(4, packets.ICMPEcho, packets.ICMPReply)

  def testIPv6PingRepliesReflectMarkAndTos(self):
    """IPv6 ping replies follow the incoming interface."""
    self.CheckReflection(6, packets.ICMPEcho, packets.ICMPReply)

  def testIPv4RSTsReflectMark(self):
    """RSTs to IPv4 SYNs on a closed port follow the incoming interface."""
    self.CheckReflection(4, self.SYNToClosedPort, packets.RST)

  def testIPv6RSTsReflectMark(self):
    """RSTs to IPv6 SYNs on a closed port follow the incoming interface."""
    self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
420
421
class TCPAcceptTest(multinetwork_base.InboundMarkingTest):
  """Checks routing of incoming TCP connections under several modes.

  The modes decide how the listening socket is tied to a network: via
  SO_BINDTODEVICE, an incoming-mark sysctl, an explicit socket mark, or the
  socket's owning UID.
  """

  MODE_BINDTODEVICE = "SO_BINDTODEVICE"
  MODE_INCOMING_MARK = "incoming mark"
  MODE_EXPLICIT_MARK = "explicit mark"
  MODE_UID = "uid"

  @classmethod
  def setUpClass(cls):
    super(TCPAcceptTest, cls).setUpClass()

    # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    # will accept both IPv4 and IPv6 connections. We do this here instead of in
    # each test so we can use the same socket every time. That way, if a kernel
    # bug causes incoming packets to mark the listening socket instead of the
    # accepted socket, the test will fail as soon as the next address/interface
    # combination is tried.
    cls.listensocket = net_test.IPv6TCPSocket()
    cls.listenport = net_test.BindRandomPort(6, cls.listensocket)

  @classmethod
  def tearDownClass(cls):
    # Release the shared listening socket opened in setUpClass, then let the
    # base class run its own teardown even if close() fails.
    try:
      cls.listensocket.close()
    finally:
      super(TCPAcceptTest, cls).tearDownClass()

  def _SetTCPMarkAcceptSysctl(self, value):
    """Writes value to the tcp_fwmark_accept sysctl."""
    self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)

  def CheckTCPConnection(self, mode, listensocket, netid, version,
                         myaddr, remoteaddr, packet, reply, msg):
    """Completes an incoming TCP connection and checks how it is routed.

    The method:
     - injects the handshake's final ACK and accepts the connection;
     - checks that data and, after close(), the FIN leave on netid's
       interface;
     - checks the accepted socket's mark (MODE_INCOMING_MARK), or that the
       listening socket is unmarked (any mode except MODE_EXPLICIT_MARK).

    Args:
      mode: One of the MODE_* constants, or None.
      listensocket: The listening socket that received the SYN.
      netid: The network the connection arrives on.
      version: 4 or 6.
      myaddr: Local address of the connection.
      remoteaddr: Remote address of the connection.
      packet: The original SYN (not used here).
      reply: The SYN+ACK returned by _ReceiveAndExpectResponse.
      msg: Prefix for assertion messages.
    """
    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]

    # Attempt to confuse the kernel: clear any cached route so that each
    # packet below requires a fresh routing decision.
    self.InvalidateDstCache(version, netid)

    self.ReceivePacketOn(netid, establishing_ack)

    # If we're using UID routing, the accept() call has to be run as a UID that
    # is routed to the specified netid, because the UID of the socket returned
    # by accept() is the effective UID of the process that calls it. It doesn't
    # need to be the same UID; any UID that selects the same interface will do.
    with net_test.RunAsUid(self.UidForNetid(netid)):
      s, _ = listensocket.accept()

    try:
      # Check that data sent on the connection goes out on the right interface.
      desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                               payload=UDP_PAYLOAD)
      s.send(UDP_PAYLOAD)
      self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
      self.InvalidateDstCache(version, netid)

      # Keep up our end of the conversation.
      ack = packets.ACK(version, remoteaddr, myaddr, data)[1]
      self.InvalidateDstCache(version, netid)
      self.ReceivePacketOn(netid, ack)

      mark = self.GetSocketMark(s)
    finally:
      self.InvalidateDstCache(version, netid)
      s.close()
      self.InvalidateDstCache(version, netid)

    if mode == self.MODE_INCOMING_MARK:
      self.assertEqual(netid, mark & self.NETID_FWMASK,
                       msg + ": Accepted socket: Expected mark %d, got %d" % (
                           netid, mark))
    elif mode != self.MODE_EXPLICIT_MARK:
      self.assertEqual(0, self.GetSocketMark(listensocket))

    # Check the FIN was sent on the right interface, and ack it. We don't expect
    # this to fail because by the time the connection is established things are
    # likely working, but a) extra tests are always good and b) extra packets
    # like the FIN (and retransmitted FINs) could cause later tests that expect
    # no packets to fail.
    desc, fin = packets.FIN(version, myaddr, remoteaddr, ack)
    self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)

    desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
    self.ReceivePacketOn(netid, finack)

    # Since we called close() earlier, the userspace socket object is gone, so
    # the socket has no UID. If we're doing UID routing, the ack might be routed
    # incorrectly. Not much we can do here.
    desc, finackack = packets.ACK(version, myaddr, remoteaddr, finack)
    self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)

  def CheckTCP(self, version, modes):
    """Checks that incoming TCP connections work.

    Args:
      version: An integer, 4 or 6.
      modes: A list of modes to exercise.
    """
    # NOTE(review): `syncookies` only appears in the failure message below.
    # SYNCOOKIES_SYSCTL is not written by this method, so both passes run
    # with the system's current tcp_syncookies setting. Confirm whether the
    # sysctl should be set here.
    for syncookies in [0, 2]:
      for mode in modes:
        for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
          listensocket = self.listensocket
          listenport = listensocket.getsockname()[1]

          accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
          self._SetTCPMarkAcceptSysctl(accept_sysctl)
          self.SetMarkReflectSysctls(accept_sysctl)

          bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
          self.BindToDevice(listensocket, bound_dev)

          mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
          self.SetSocketMark(listensocket, mark)

          uid = self.UidForNetid(netid) if mode == self.MODE_UID else 0
          os.fchown(listensocket.fileno(), uid, -1)

          # Generate the packet here instead of in the outer loop, so
          # subsequent TCP connections use different source ports and
          # retransmissions from old connections don't confuse subsequent
          # tests.
          desc, packet = packets.SYN(listenport, version, remoteaddr, myaddr)

          # With mode None nothing ties the listener to this network, so no
          # SYN+ACK is expected on this interface.
          if mode:
            reply_desc, reply = packets.SYNACK(version, myaddr, remoteaddr,
                                               packet)
          else:
            reply_desc, reply = None, None

          extra = "mode=%s, syncookies=%d" % (mode, syncookies)
          msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
          reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
          if reply:
            self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
                                    remoteaddr, packet, reply, msg)

  def testBasicTCP(self):
    """Checks no-binding, SO_BINDTODEVICE and explicit-mark listeners."""
    self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])

  def testIPv4MarkAccept(self):
    """Checks that accepted IPv4 sockets take the incoming mark."""
    self.CheckTCP(4, [self.MODE_INCOMING_MARK])

  def testIPv6MarkAccept(self):
    """Checks that accepted IPv6 sockets take the incoming mark."""
    self.CheckTCP(6, [self.MODE_INCOMING_MARK])

  def testIPv4UidAccept(self):
    """Checks UID-routed IPv4 accepts."""
    self.CheckTCP(4, [self.MODE_UID])

  def testIPv6UidAccept(self):
    """Checks UID-routed IPv6 accepts."""
    self.CheckTCP(6, [self.MODE_UID])

  def testIPv6ExplicitMark(self):
    """Checks IPv6 accepts on an explicitly marked listener."""
    self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
567
@unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                     "need support for per-table autoconf")
class RIOTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests for the IPv6 Route Information Option (RFC 4191).

  Relevant kernel commits:
    upstream:
      f104a567e673 ipv6: use rt6_get_dflt_router to get default router in rt6_route_rcv
      bbea124bc99d net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-4.9:
      d860b2e8a7f1 FROMLIST: net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs

    android-4.4:
      e953f89b8563 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-4.1:
      84f2f47716cd net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-3.18:
      65f8936934fa net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-3.10:
      161e88ebebc7 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

  """

  def setUp(self):
    super(RIOTest, self).setUp()
    # Each test exercises one randomly chosen network.
    self.NETID = random.choice(self.NETIDS)
    self.IFACE = self.GetInterfaceName(self.NETID)
    # Return min/max plen to default values before each test case.
    self.SetAcceptRaRtInfoMinPlen(0)
    self.SetAcceptRaRtInfoMaxPlen(0)

  def GetRoutingTable(self):
    """Returns the routing table used by the test network."""
    return self._TableForNetid(self.NETID)

  def SetAcceptRaRtInfoMinPlen(self, plen):
    """Sets the interface's accept_ra_rt_info_min_plen sysctl."""
    self.SetSysctl(
        "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_min_plen"
        % self.IFACE, plen)

  def GetAcceptRaRtInfoMinPlen(self):
    """Returns the interface's accept_ra_rt_info_min_plen as an int."""
    return int(self.GetSysctl(
        "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_min_plen" % self.IFACE))

  def SetAcceptRaRtInfoMaxPlen(self, plen):
    """Sets the interface's accept_ra_rt_info_max_plen sysctl."""
    self.SetSysctl(
        "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_max_plen"
        % self.IFACE, plen)

  def GetAcceptRaRtInfoMaxPlen(self):
    """Returns the interface's accept_ra_rt_info_max_plen as an int."""
    return int(self.GetSysctl(
        "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_max_plen" % self.IFACE))

  def SendRIO(self, rtlifetime, plen, prefix, prf):
    """Sends an RA carrying a single Route Information Option."""
    options = scapy.ICMPv6NDOptRouteInfo(rtlifetime=rtlifetime, plen=plen,
                                         prefix=prefix, prf=prf)
    self.SendRA(self.NETID, options=(options,))

  def FindRoutesWithDestination(self, destination):
    """Returns IPv6 routes in the test table whose RTA_DST is destination."""
    canonical = net_test.CanonicalizeIPv6Address(destination)
    return [r for _, r in self.iproute.DumpRoutes(6, self.GetRoutingTable())
            if ('RTA_DST' in r and r['RTA_DST'] == canonical)]

  def FindRoutesWithGateway(self):
    """Returns IPv6 routes in the test table that have a gateway."""
    return [r for _, r in self.iproute.DumpRoutes(6, self.GetRoutingTable())
            if 'RTA_GATEWAY' in r]

  def CountRoutes(self):
    """Returns the number of IPv6 routes in the test table."""
    return len(self.iproute.DumpRoutes(6, self.GetRoutingTable()))

  def GetRouteExpiration(self, route):
    """Returns the route's RTA_CACHEINFO expires field divided by 100.

    Presumably the field is in 1/100 s units, so the result is in seconds.
    """
    return float(route['RTA_CACHEINFO'].expires) / 100.0

  def AssertExpirationInRange(self, routes, lifetime, epsilon):
    """Asserts routes is non-empty and one expires within lifetime +/- epsilon.

    Args:
      routes: Route attribute dicts, e.g. from FindRoutesWithGateway().
      lifetime: Expected remaining lifetime, in GetRouteExpiration units.
      epsilon: Allowed deviation, in the same units (inclusive bounds).
    """
    self.assertTrue(routes)
    # Assert that at least one route in routes has the expected lifetime.
    expirations = [self.GetRouteExpiration(route) for route in routes]
    self.assertTrue(
        any(lifetime - epsilon <= e <= lifetime + epsilon
            for e in expirations),
        "no route expires within %s +/- %s: %s" % (lifetime, epsilon,
                                                    expirations))

  def DelRA6(self, prefix, plen):
    """Deletes the route for prefix/plen from the test network's table."""
    version = 6
    netid = self.NETID
    table = self._TableForNetid(netid)
    router = self._RouterAddress(netid, version)
    ifindex = self.ifindices[netid]
    # We actually want to specify RTPROT_RA, however an upstream
    # kernel bug causes RAs to be installed with RTPROT_BOOT.
    if HAVE_RTPROT_RA:
      rtprot = iproute.RTPROT_RA
    else:
      rtprot = iproute.RTPROT_BOOT
    self.iproute._Route(version, rtprot, iproute.RTM_DELROUTE,
                        table, prefix, plen, router, ifindex, None, None)

  def testSetAcceptRaRtInfoMinPlen(self):
    for plen in range(-1, 130):
      self.SetAcceptRaRtInfoMinPlen(plen)
      self.assertEqual(plen, self.GetAcceptRaRtInfoMinPlen())

  def testSetAcceptRaRtInfoMaxPlen(self):
    for plen in range(-1, 130):
      self.SetAcceptRaRtInfoMaxPlen(plen)
      self.assertEqual(plen, self.GetAcceptRaRtInfoMaxPlen())

  def testZeroRtLifetime(self):
    PREFIX = "2001:db8:8901:2300::"
    RTLIFETIME = 73500
    PLEN = 56
    PRF = 0
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    self.assertTrue(self.FindRoutesWithDestination(PREFIX))
    # RIO with rtlifetime = 0 should remove from routing table
    self.SendRIO(0, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    self.assertFalse(self.FindRoutesWithDestination(PREFIX))

  def testMinPrefixLenRejection(self):
    PREFIX = "2001:db8:8902:2345::"
    RTLIFETIME = 70372
    PRF = 0
    # sweep from high to low to avoid spurious failures from late arrivals.
    for plen in range(130, 1, -1):
      self.SetAcceptRaRtInfoMinPlen(plen)
      # RIO with plen < min_plen should be ignored
      self.SendRIO(RTLIFETIME, plen - 1, PREFIX, PRF)
    # Give the kernel time to notice our RAs
    time.sleep(0.1)
    # Expect no routes
    routes = self.FindRoutesWithDestination(PREFIX)
    self.assertFalse(routes)

  def testMaxPrefixLenRejection(self):
    PREFIX = "2001:db8:8903:2345::"
    RTLIFETIME = 73078
    PRF = 0
    # sweep from low to high to avoid spurious failures from late arrivals.
    for plen in range(-1, 128, 1):
      self.SetAcceptRaRtInfoMaxPlen(plen)
      # RIO with plen > max_plen should be ignored
      self.SendRIO(RTLIFETIME, plen + 1, PREFIX, PRF)
    # Give the kernel time to notice our RAs
    time.sleep(0.1)
    # Expect no routes
    routes = self.FindRoutesWithDestination(PREFIX)
    self.assertFalse(routes)

  def testSimpleAccept(self):
    PREFIX = "2001:db8:8904:2345::"
    RTLIFETIME = 9993
    PRF = 0
    PLEN = 56
    self.SetAcceptRaRtInfoMinPlen(48)
    self.SetAcceptRaRtInfoMaxPlen(64)
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    routes = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(routes, RTLIFETIME, 1)
    self.DelRA6(PREFIX, PLEN)

  def testEqualMinMaxAccept(self):
    PREFIX = "2001:db8:8905:2345::"
    RTLIFETIME = 6326
    PLEN = 21
    PRF = 0
    self.SetAcceptRaRtInfoMinPlen(PLEN)
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    routes = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(routes, RTLIFETIME, 1)
    self.DelRA6(PREFIX, PLEN)

  def testZeroLengthPrefix(self):
    PREFIX = "2001:db8:8906:2345::"
    RTLIFETIME = self.RA_VALIDITY * 2
    PLEN = 0
    PRF = 0
    # Max plen = 0 still allows default RIOs!
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    self.SendRA(self.NETID)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    default = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(default, self.RA_VALIDITY, 1)
    # RIO with prefix length = 0, should overwrite default route lifetime
    # note that the RIO lifetime overwrites the RA lifetime.
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    default = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(default, RTLIFETIME, 1)
    self.DelRA6(PREFIX, PLEN)

  def testManyRIOs(self):
    RTLIFETIME = 68012
    PLEN = 56
    PRF = 0
    COUNT = 1000
    prefixes = ["2001:db8:%x:1100::" % i for i in range(COUNT)]
    baseline = self.CountRoutes()
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    # Send many RIOs compared to the expected number on a healthy system.
    for prefix in prefixes:
      self.SendRIO(RTLIFETIME, PLEN, prefix, PRF)
    time.sleep(0.1)
    self.assertEqual(COUNT + baseline, self.CountRoutes())
    for prefix in prefixes:
      self.DelRA6(prefix, PLEN)
    # Expect that we can return to baseline config without lingering routes.
    self.assertEqual(baseline, self.CountRoutes())
795
class RATest(multinetwork_base.MultiNetworkBaseTest):
  """Tests of the kernel's handling of IPv6 router advertisements (RAs)."""

  ND_ROUTER_ADVERT = 134  # ICMPv6 message type of a router advertisement.
  ND_OPT_PREF64 = 38      # ND option type of the PREF64 option.
  # PREF64 option layout: type, length (in 8-octet units), a 16-bit field
  # combining lifetime and prefix-length code, and a 12-byte prefix.
  Pref64Option = cstruct.Struct("pref64_option", "!BBH12s",
                                "type length lft_plc prefix")

  def testDoesNotHaveObsoleteSysctl(self):
    """Checks that the obsolete autoconf_table_offset sysctl is absent."""
    self.assertFalse(os.path.isfile(
        "/proc/sys/net/ipv6/route/autoconf_table_offset"))

  @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                       "no support for per-table autoconf")
  def testPurgeDefaultRouters(self):
    """Checks that enabling forwarding removes IPv6 connectivity.

    With accept_ra = 1, turning on IPv6 forwarding must make every network
    unreachable. Connectivity must return once forwarding is disabled again
    and RAs are re-sent.
    """

    def CheckIPv6Connectivity(expect_connectivity):
      for netid in self.NETIDS:
        s = net_test.UDPSocket(AF_INET6)
        try:
          self.SetSocketMark(s, netid)
          if expect_connectivity:
            self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
          else:
            self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
                                   (net_test.IPV6_ADDR, 1234))
        finally:
          # Don't leak one socket per netid per check.
          s.close()

    try:
      CheckIPv6Connectivity(True)
      self.SetIPv6SysctlOnAllIfaces("accept_ra", 1)
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
      CheckIPv6Connectivity(False)
    finally:
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
      # Re-send RAs so the routes removed above are re-learned.
      for netid in self.NETIDS:
        self.SendRA(netid)
      CheckIPv6Connectivity(True)

  def testOnlinkCommunication(self):
    """Checks that on-link communication goes direct and not through routers."""
    for netid in self.tuns:
      # Send a UDP packet to a random on-link destination.
      s = net_test.UDPSocket(AF_INET6)
      try:
        iface = self.GetInterfaceName(netid)
        self.BindToDevice(s, iface)
        # dstaddr can never be our address because GetRandomDestination only
        # fills in the lower 32 bits, but our address has 0xff in the byte
        # before that (since it's constructed from the EUI-64 and so has
        # ff:fe in the middle).
        dstaddr = self.GetRandomDestination(self.OnlinkPrefix(6, netid))
        s.sendto(UDP_PAYLOAD, (dstaddr, 53))

        # Expect an NS for that destination on the interface.
        myaddr = self.MyAddress(6, netid)
        mymac = self.MyMacAddress(netid)
        desc, expected = packets.NS(myaddr, dstaddr, mymac)
        msg = "Sending UDP packet to on-link destination: expecting %s" % desc
        time.sleep(0.0001)  # Required to make the test work on kernel 3.1(!)
        self.ExpectPacketOn(netid, msg, expected)

        # Send an NA.
        tgtmac = "02:00:00:00:%02x:99" % netid
        _, reply = packets.NA(dstaddr, myaddr, tgtmac)
        # Don't use ReceivePacketOn, since that uses the router's MAC address
        # as the source. Instead, construct our own Ethernet header with
        # source MAC of tgtmac.
        reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
        self.ReceiveEtherPacketOn(netid, reply)

        # Expect the kernel to send the original UDP packet now that the ND
        # cache entry has been populated.
        sport = s.getsockname()[1]
        desc, expected = packets.UDP(6, myaddr, dstaddr, sport=sport)
        msg = "After NA response, expecting %s" % desc
        self.ExpectPacketOn(netid, msg, expected)
      finally:
        s.close()

  # This test documents a known issue: routing tables are never deleted.
  @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                       "no support for per-table autoconf")
  def testLeftoverRoutes(self):
    def GetNumRoutes():
      # Close the /proc file instead of leaking the handle.
      with open("/proc/net/ipv6_route") as route_file:
        return len(route_file.readlines())

    num_routes = GetNumRoutes()
    for i in range(10, 20):
      self.tuns[i] = self.CreateTunInterface(i)
      try:
        self.SendRA(i)
      finally:
        # Always close and forget the tun, even if SendRA raised. Only reached
        # once the tun was created, so a CreateTunInterface failure is no
        # longer masked by a KeyError here.
        self.tuns.pop(i).close()
    self.assertLess(num_routes, GetNumRoutes())

  def SendNdUseropt(self, option, netid=None):
    """Sends an RA carrying the packed ND |option| on |netid|.

    Args:
      option: the packed option bytes to include in the RA.
      netid: the network to send on; a random one of NETIDS if None.
    """
    if netid is None:
      netid = random.choice(self.NETIDS)
    self.SendRA(netid, options=(option,))

  def MakePref64Option(self, prefix, lifetime):
    """Builds a PREF64 option for the /96 |prefix| with |lifetime| seconds.

    The lifetime is rounded down to a multiple of 8 by the 0xfff8 mask; the
    low 3 bits hold the prefix length code, 0 meaning a 96-bit prefix.
    """
    prefix = inet_pton(AF_INET6, prefix)[:12]
    lft_plc = (lifetime & 0xfff8) | 0  # 96-bit prefix length
    return self.Pref64Option((self.ND_OPT_PREF64, 2, lft_plc, prefix))

  @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not backported")
  def testPref64UserOption(self):
    """Checks that a PREF64 RA option is delivered as RTM_NEWNDUSEROPT."""
    # Open a netlink socket to receive RTM_NEWNDUSEROPT messages.
    s = netlink.NetlinkSocket(netlink.NETLINK_ROUTE, iproute.RTMGRP_ND_USEROPT)
    try:
      # Send an RA with the PREF64 option.
      netid = random.choice(self.NETIDS)
      opt = self.MakePref64Option("64:ff9b::", 300)
      self.SendNdUseropt(opt.Pack(), netid)

      # Check that we get an RTM_NEWNDUSEROPT message on the socket with the
      # expected option.
      csocket.SetSocketTimeout(s.sock, 100)
      try:
        data = s._Recv()
      except IOError as e:
        self.fail("Should have received an RTM_NEWNDUSEROPT message. "
                  "Please ensure the kernel supports receiving the "
                  "PREF64 RA option. Error: %s" % e)

      # Check that the message is received correctly.
      nlmsghdr, data = cstruct.Read(data, netlink.NLMsgHdr)
      self.assertEqual(iproute.RTM_NEWNDUSEROPT, nlmsghdr.type)

      # Check the option contents.
      ndopthdr, data = cstruct.Read(data, iproute.NdUseroptMsg)
      self.assertEqual(AF_INET6, ndopthdr.family)
      self.assertEqual(self.ND_ROUTER_ADVERT, ndopthdr.icmp_type)
      self.assertEqual(len(opt), ndopthdr.opts_len)

      actual_opt = self.Pref64Option(data)
      self.assertEqual(opt, actual_opt)
    finally:
      # Don't leak the netlink socket subscribed to ND user options.
      s.sock.close()
929
930
class PMTUTest(multinetwork_base.InboundMarkingTest):
  """Tests IPv4 and IPv6 path MTU discovery across routing modes."""

  PAYLOAD_SIZE = 1400
  # Destinations already used by CheckPMTU. This is a class attribute, so it
  # is shared by every test method in the class.
  dstaddrs = set()

  def GetSocketMTU(self, version, s):
    """Returns the path MTU reported by the kernel for socket |s|."""
    if version == 6:
      # The option value is unpacked as 28 bytes of address data followed by
      # a native 32-bit MTU.
      ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, csocket.IPV6_PATHMTU, 32)
      unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
      return mtu
    else:
      return s.getsockopt(net_test.SOL_IP, csocket.IP_MTU)

  def DisableFragmentationAndReportErrors(self, version, s):
    """Sets DF/DONTFRAG and enables RECVERR on |s| for the given version."""
    if version == 4:
      s.setsockopt(net_test.SOL_IP, csocket.IP_MTU_DISCOVER,
                   csocket.IP_PMTUDISC_DO)
      s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
    else:
      s.setsockopt(net_test.SOL_IPV6, csocket.IPV6_DONTFRAG, 1)
      s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)

  def CheckPMTU(self, version, use_connect, modes):
    """Checks PMTU discovery for every tun and every routing mode in |modes|.

    For each combination, sends an oversized UDP packet and injects an ICMP
    "packet too big" in reply. It then checks that the lowered MTU shows up
    as EMSGSIZE on resend, in the socket MTU of connected sockets, and in
    the RTAX_MTU metric returned by a route lookup.

    Args:
      version: 4 or 6.
      use_connect: if True, connect() the socket and use send(); otherwise
        send each packet with SendOnNetid.
      modes: routing modes to pass to BuildSocket (e.g. "mark", "oif").
    """

    def SendBigPacket(version, s, dstaddr, netid, payload):
      if use_connect:
        s.send(payload)
      else:
        self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])

    for netid in self.tuns:
      for mode in modes:
        s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
        self.DisableFragmentationAndReportErrors(version, s)

        srcaddr = self.MyAddress(version, netid)
        dst_prefix, intermediate = {
            4: ("172.19.", "172.16.9.12"),
            6: ("2001:db8::", "2001:db8::1")
        }[version]

        # Run this test often enough (e.g., in presubmits), and eventually
        # we'll be unlucky enough to pick the same address twice, in which
        # case the test will fail because the kernel will already have seen
        # the lower MTU. Don't do this.
        dstaddr = self.GetRandomDestination(dst_prefix)
        while dstaddr in self.dstaddrs:
          dstaddr = self.GetRandomDestination(dst_prefix)
        self.dstaddrs.add(dstaddr)

        if use_connect:
          s.connect((dstaddr, 1234))

        payload = self.PAYLOAD_SIZE * "a"

        # Send a packet and receive a packet too big.
        SendBigPacket(version, s, dstaddr, netid, payload)
        received = self.ReadAllPacketsOn(netid)
        self.assertEqual(1, len(received),
                         "unexpected packets: %s" % received[1:])
        _, toobig = packets.ICMPPacketTooBig(version, intermediate, srcaddr,
                                             received[0])
        self.ReceivePacketOn(netid, toobig)

        # Check that another send on the same socket returns EMSGSIZE.
        self.assertRaisesErrno(
            errno.EMSGSIZE,
            SendBigPacket, version, s, dstaddr, netid, payload)

        # If this is a connected socket, make sure the socket MTU was set.
        # Note that in IPv4 this only started working in Linux 3.6!
        if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)):
          self.assertEqual(packets.PTB_MTU, self.GetSocketMTU(version, s))

        s.close()

        # Check that other sockets pick up the PMTU we have been told about by
        # connecting another socket to the same destination and getting its
        # MTU. This new socket can use any method to select its outgoing
        # interface; here we use a mark for simplicity.
        s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
        s2.connect((dstaddr, 1234))
        self.assertEqual(packets.PTB_MTU, self.GetSocketMTU(version, s2))

        # Also check the MTU reported by ip route get, this time using the oif.
        routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
        self.assertTrue(routes)
        route = routes[0]
        rtmsg, attributes = route
        self.assertEqual(iproute.RTN_UNICAST, rtmsg.type)
        metrics = attributes["RTA_METRICS"]
        self.assertEqual(packets.PTB_MTU, metrics["RTAX_MTU"])

        # Previously s2 was never closed, leaking one socket per netid/mode.
        s2.close()

  def testIPv4BasicPMTU(self):
    """Tests IPv4 path MTU discovery.

    Relevant kernel commits:
      upstream net-next:
        6a66271 ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif

      android-3.10:
        4bc64dd ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
    """

    self.CheckPMTU(4, True, ["mark", "oif"])
    self.CheckPMTU(4, False, ["mark", "oif"])

  def testIPv6BasicPMTU(self):
    self.CheckPMTU(6, True, ["mark", "oif"])
    self.CheckPMTU(6, False, ["mark", "oif"])

  def testIPv4UIDPMTU(self):
    self.CheckPMTU(4, True, ["uid"])
    self.CheckPMTU(4, False, ["uid"])

  def testIPv6UIDPMTU(self):
    self.CheckPMTU(6, True, ["uid"])
    self.CheckPMTU(6, False, ["uid"])

  # Making Path MTU Discovery work on unmarked sockets requires that mark
  # reflection be enabled. Otherwise the kernel has no way to know what routing
  # table the original packet used, and thus it won't be able to clone the
  # correct route.

  def testIPv4UnmarkedSocketPMTU(self):
    self.SetMarkReflectSysctls(1)
    try:
      self.CheckPMTU(4, False, [None])
    finally:
      self.SetMarkReflectSysctls(0)

  def testIPv6UnmarkedSocketPMTU(self):
    self.SetMarkReflectSysctls(1)
    try:
      self.CheckPMTU(6, False, [None])
    finally:
      self.SetMarkReflectSysctls(0)
1068
1069
class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests that per-UID routing works properly.

  Relevant kernel commits:
    upstream net-next:
      7d99569460 net: ipv4: Don't crash if passing a null sk to ip_do_redirect.
      d109e61bfe net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
      35b80733b3 net: core: add missing check for uid_range in rule_exists.
      e2d118a1cb net: inet: Support UID-based routing in IP protocols.
      622ec2c9d5 net: core: add UID to flows, rules, and routes
      86741ec254 net: core: Add a UID field to struct sock.

    android-3.18:
      b004e79504 net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
      04c0eace81 net: inet: Support UID-based routing in IP protocols.
      18c36d7b71 net: core: add UID to flows, rules, and routes
      80e3440721 net: core: Add a UID field to struct sock.
      fa8cc2c30c Revert "net: core: Support UID-based routing."
      b585141890 Revert "Handle 'sk' being NULL in UID-based routing."
      5115ab7514 Revert "net: core: fix UID-based routing build"
      f9f4281f79 Revert "ANDROID: net: fib: remove duplicate assignment"

    android-4.4:
      341965cf10 net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
      344afd627c net: inet: Support UID-based routing in IP protocols.
      03441d56d8 net: core: add UID to flows, rules, and routes
      eb964bdba7 net: core: Add a UID field to struct sock.
      9789b697c6 Revert "net: core: Support UID-based routing."
  """

  def GetRulesAtPriority(self, version, priority):
    """Returns the (rule, attributes) pairs of all rules at |priority|.

    Rules without an FRA_PRIORITY attribute are treated as priority 0.
    """
    rules = self.iproute.DumpRules(version)
    return [(rule, attributes) for rule, attributes in rules
            if attributes.get("FRA_PRIORITY", 0) == priority]

  def CheckInitialTablesHaveNoUIDs(self, version):
    """Checks that rules at priorities 0, 32766 and 32767 have no UID range."""
    rules = []
    for priority in [0, 32766, 32767]:
      rules.extend(self.GetRulesAtPriority(version, priority))
    for _, attributes in rules:
      self.assertNotIn("FRA_UID_RANGE", attributes)

  def testIPv4InitialTablesHaveNoUIDs(self):
    self.CheckInitialTablesHaveNoUIDs(4)

  def testIPv6InitialTablesHaveNoUIDs(self):
    self.CheckInitialTablesHaveNoUIDs(6)

  @staticmethod
  def _Random():
    return random.randint(1000000, 2000000)

  def CheckGetAndSetRules(self, version):
    """Checks creation, deletion, dumping and EEXIST of UID range rules."""
    start, end = tuple(sorted([self._Random(), self._Random()]))
    table = self._Random()
    priority = self._Random()

    # Can't create a UID range to UID -1 because -1 is INVALID_UID...
    self.assertRaisesErrno(
        errno.EINVAL,
        self.iproute.UidRangeRule, version, True, 100, 0xffffffff, table,
        priority)

    # ... but -2 is valid.
    self.iproute.UidRangeRule(version, True, 100, 0xfffffffe, table, priority)
    self.iproute.UidRangeRule(version, False, 100, 0xfffffffe, table, priority)

    try:
      # Create a UID range rule.
      self.iproute.UidRangeRule(version, True, start, end, table, priority)

      # Check that deleting the wrong UID range doesn't work.
      self.assertRaisesErrno(
          errno.ENOENT,
          self.iproute.UidRangeRule, version, False, start, end + 1, table,
          priority)
      self.assertRaisesErrno(
          errno.ENOENT,
          self.iproute.UidRangeRule, version, False, start + 1, end, table,
          priority)

      # Check that the UID range appears in dumps.
      rules = self.GetRulesAtPriority(version, priority)
      self.assertTrue(rules)
      _, attributes = rules[-1]
      self.assertEqual(priority, attributes["FRA_PRIORITY"])
      uidrange = attributes["FRA_UID_RANGE"]
      self.assertEqual(start, uidrange.start)
      self.assertEqual(end, uidrange.end)
      self.assertEqual(table, attributes["FRA_TABLE"])
    finally:
      self.iproute.UidRangeRule(version, False, start, end, table, priority)
      self.assertRaisesErrno(
          errno.ENOENT,
          self.iproute.UidRangeRule, version, False, start, end, table,
          priority)

    fwmask = 0xfefefefe
    try:
      # Create a rule without a UID range.
      self.iproute.FwmarkRule(version, True, 300, fwmask, 301, priority + 1)

      # Check it doesn't have a UID range.
      rules = self.GetRulesAtPriority(version, priority + 1)
      self.assertTrue(rules)
      for _, attributes in rules:
        self.assertIn("FRA_TABLE", attributes)
        self.assertNotIn("FRA_UID_RANGE", attributes)
    finally:
      self.iproute.FwmarkRule(version, False, 300, fwmask, 301, priority + 1)

    # Test that EEXIST works for UID range rules too. This behaviour was only
    # added in 4.8.
    if net_test.LINUX_VERSION >= (4, 8, 0):
      ranges = [(100, 101), (100, 102), (99, 101), (1234, 5678)]
      dup = ranges[0]
      try:
        # Check that otherwise identical rules with different UID ranges can be
        # created without EEXIST.
        for start, end in ranges:
          self.iproute.UidRangeRule(version, True, start, end, table, priority)
        # ... but EEXIST is returned if the UID range is identical.
        self.assertRaisesErrno(
            errno.EEXIST,
            self.iproute.UidRangeRule, version, True, dup[0], dup[1], table,
            priority)
      finally:
        # Clean up. |dup| is deleted a second time in case a buggy kernel
        # accepted the duplicate; ENOENT (IOError) is expected otherwise.
        for start, end in ranges + [dup]:
          try:
            self.iproute.UidRangeRule(version, False, start, end, table,
                                      priority)
          except IOError:
            pass

  def testIPv4GetAndSetRules(self):
    self.CheckGetAndSetRules(4)

  def testIPv6GetAndSetRules(self):
    self.CheckGetAndSetRules(6)

  @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not backported")
  def testDeleteErrno(self):
    for version in [4, 6]:
      table = self._Random()
      priority = self._Random()
      self.assertRaisesErrno(
          errno.EINVAL,
          self.iproute.UidRangeRule, version, False, 100, 0xffffffff, table,
          priority)

  def ExpectNoRoute(self, addr, oif, mark, uid):
    """Asserts that |addr| is unreachable for the given oif/mark/uid."""
    # The lack of a route may be either an error, or an unreachable route.
    try:
      routes = self.iproute.GetRoutes(addr, oif, mark, uid)
      rtmsg, _ = routes[0]
      self.assertEqual(iproute.RTN_UNREACHABLE, rtmsg.type)
    except IOError as e:
      if int(e.errno) != int(errno.ENETUNREACH):
        # Bare raise keeps the original traceback (lost by "raise e" on py2).
        raise

  def ExpectRoute(self, addr, oif, mark, uid):
    """Asserts that |addr| has a unicast route for the given oif/mark/uid."""
    routes = self.iproute.GetRoutes(addr, oif, mark, uid)
    rtmsg, _ = routes[0]
    self.assertEqual(iproute.RTN_UNICAST, rtmsg.type)

  def CheckGetRoute(self, version, addr):
    """Checks that |addr| is only routable with a per-netid UID."""
    self.ExpectNoRoute(addr, 0, 0, 0)
    for netid in self.NETIDS:
      uid = self.UidForNetid(netid)
      self.ExpectRoute(addr, 0, 0, uid)
    self.ExpectNoRoute(addr, 0, 0, 0)

  def testIPv4RouteGet(self):
    self.CheckGetRoute(4, net_test.IPV4_ADDR)

  def testIPv6RouteGet(self):
    self.CheckGetRoute(6, net_test.IPV6_ADDR)

  def testChangeFdAttributes(self):
    """Checks that routing follows socket ownership changes made by fchown."""
    netid = random.choice(self.NETIDS)
    uid = self._Random()
    table = self._TableForNetid(netid)
    remoteaddr = self.GetRemoteAddress(6)
    s = socket(AF_INET6, SOCK_DGRAM, 0)

    def CheckSendFails():
      self.assertRaisesErrno(errno.ENETUNREACH,
                             s.sendto, "foo", (remoteaddr, 53))
    def CheckSendSucceeds():
      self.assertEqual(len("foo"), s.sendto("foo", (remoteaddr, 53)))

    try:
      CheckSendFails()
      self.iproute.UidRangeRule(6, True, uid, uid, table, self.PRIORITY_UID)
      try:
        CheckSendFails()
        os.fchown(s.fileno(), uid, -1)
        CheckSendSucceeds()
        os.fchown(s.fileno(), -1, -1)
        CheckSendSucceeds()
        os.fchown(s.fileno(), -1, 12345)
        CheckSendSucceeds()
        os.fchmod(s.fileno(), 0o777)
        CheckSendSucceeds()
        os.fchown(s.fileno(), 0, -1)
        CheckSendFails()
      finally:
        self.iproute.UidRangeRule(6, False, uid, uid, table, self.PRIORITY_UID)
    finally:
      # The socket was previously never closed.
      s.close()
1278
1279
class RulesTest(net_test.NetworkTest):
  """Tests for adding and deleting fwmark routing rules over netlink."""

  RULE_PRIORITY = 99999
  FWMASK = 0xffffffff

  def _DeleteTestRules(self):
    """Removes all IPv4 and IPv6 rules at RULE_PRIORITY."""
    for ip_version in (4, 6):
      self.iproute.DeleteRulesAtPriority(ip_version, self.RULE_PRIORITY)

  def setUp(self):
    self.iproute = iproute.IPRoute()
    self._DeleteTestRules()

  def tearDown(self):
    self._DeleteTestRules()

  def testRuleDeletionMatchesTable(self):
    mark = 300
    for ip_version in (4, 6):
      # Add rules with mark 300 pointing at tables 301 and 302.
      # This checks for a kernel bug where a deletion request for tables
      # > 256 ignored the table.
      for table in (301, 302):
        self.iproute.FwmarkRule(ip_version, True, mark, self.FWMASK, table,
                                priority=self.RULE_PRIORITY)
      # Delete only the rule with mark 300 pointing at table 302.
      self.iproute.FwmarkRule(ip_version, False, mark, self.FWMASK, 302,
                              priority=self.RULE_PRIORITY)
      # The rule pointing at table 301 must still be around.
      remaining = []
      for _, attrs in self.iproute.DumpRules(ip_version):
        if attrs.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY:
          remaining.append(attrs)
      self.assertEqual(1, len(remaining))
      self.assertEqual(301, remaining[0]["FRA_TABLE"])
1311
1312
if __name__ == "__main__":
  # When run directly, run the tests defined in this module via unittest.
  unittest.main()
1315