• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import ctypes
18import errno
19import os
20import random
21from socket import *  # pylint: disable=wildcard-import
22import struct
23import time           # pylint: disable=unused-import
24import unittest
25
26from scapy import all as scapy
27
28import csocket
29import iproute
30import multinetwork_base
31import net_test
32import packets
33
# Short alias for the payload carried by the UDP (and TCP data) packets that
# these tests send and expect.
UDP_PAYLOAD = net_test.UDP_PAYLOAD

# Numeric value of the IPV6_FLOWINFO socket option / cmsg type, defined here
# rather than taken from the wildcard `socket` import (presumably because that
# module does not export it -- TODO confirm).
IPV6_FLOWINFO = 11

# Sysctl paths for the TCP tests.
SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"

# Kernel feature gates, keyed on the running kernel version:
#  - IP_UNICAST_IF / IPV6_UNICAST_IF appeared somewhere between 3.1 and 3.4.
#  - Routes learned from RAs carry RTPROT_RA correctly starting with 4.14.
HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)
HAVE_RTPROT_RA = net_test.LINUX_VERSION >= (4, 14, 0)
47
class ConfigurationError(AssertionError):
  """AssertionError subtype for test-configuration problems; adds no behavior."""
50
51
class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
  """Checks that locally-generated packets leave on the selected network.

  Each Check* helper creates a socket that selects a network using one of the
  routing modes handed to BuildSocket/SelectInterface ("mark", "uid", "oif"
  or "ucast_oif"), sends a packet, and uses ExpectPacketOn to assert that the
  expected packet appears on that network's interface. Every socket created
  here is closed before its helper returns, even if an assertion fails, so
  sockets do not pile up across ITERATIONS x netids.
  """

  # How many times to run outgoing packet tests.
  ITERATIONS = 5

  def CheckPingPacket(self, version, netid, routing_mode, packet):
    """Sends a ping routed via routing_mode and expects it on netid.

    Args:
      version: IP version passed to BuildSocket and the address helpers.
      netid: The network the socket selects and the packet must appear on.
      routing_mode: How the socket selects netid (see BuildSocket).
      packet: Bytes sent ahead of packets.PING_PAYLOAD; the callers in this
        class pass self.IPV4_PING or self.IPV6_PING.
    """
    s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      mysockaddr = self.MySocketAddress(version, netid)
      s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
      s.bind((mysockaddr, packets.PING_IDENT))
      net_test.SetSocketTos(s, packets.PING_TOS)

      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      desc, expected = packets.ICMPEcho(version, myaddr, dstaddr)
      msg = "IPv%d ping: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))

      s.sendto(packet + packets.PING_PAYLOAD, (dstsockaddr, 19321))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckTCPSYNPacket(self, version, netid, routing_mode):
    """Starts a TCP connect routed via routing_mode and expects a SYN on netid.

    Args:
      version: IP version passed to BuildSocket and the address helpers.
      netid: The network the socket selects and the SYN must appear on.
      routing_mode: How the socket selects netid (see BuildSocket).
    """
    s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      # sport and seq are left as None because the kernel chooses them
      # (presumably packets.SYN treats None as "don't care" -- TODO confirm).
      desc, expected = packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)

      # Non-blocking TCP connects always return EINPROGRESS.
      self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstsockaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckUDPPacket(self, version, netid, routing_mode):
    """Sends UDP routed via routing_mode and expects each datagram on netid.

    Always checks sendto(); additionally checks connect()+send() unless
    routing_mode is "ucast_oif".

    Args:
      version: IP version passed to BuildSocket and the address helpers.
      netid: The network the socket selects and the packets must appear on.
      routing_mode: How the socket selects netid (see BuildSocket).
    """
    s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)

      desc, expected = packets.UDP(version, myaddr, dstaddr, sport=None)
      # The %%s survives formatting and is later filled with the send method.
      msg = "IPv%s UDP %%s: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))

      s.sendto(UDP_PAYLOAD, (dstsockaddr, 53))
      self.ExpectPacketOn(netid, msg % "sendto", expected)

      # IP_UNICAST_IF doesn't seem to work on connected sockets, so skip the
      # connect/send check in ucast_oif mode.
      if routing_mode != "ucast_oif":
        s.connect((dstsockaddr, 53))
        s.send(UDP_PAYLOAD)
        self.ExpectPacketOn(netid, msg % "connect/send", expected)
    finally:
      # Close on every path, including ucast_oif mode and assertion failures.
      s.close()

  def CheckRawGrePacket(self, version, netid, routing_mode):
    """Sends GRE over a raw socket and expects it on netid.

    The GRE payload is a UDP packet of the other IP version (IPv6 inside
    IPv4, or IPv4 inside IPv6).

    Args:
      version: Outer IP version, 4 or 6.
      netid: The network the socket selects and the packet must appear on.
      routing_mode: How the socket selects netid (see BuildSocket).
    """
    s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
    try:
      inner_version = {4: 6, 6: 4}[version]
      inner_src = self.MyAddress(inner_version, netid)
      inner_dst = self.GetRemoteAddress(inner_version)
      inner = str(packets.UDP(inner_version, inner_src, inner_dst,
                              sport=None)[1])

      ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
      # A GRE header can be as simple as two zero bytes and the ethertype:
      # "!i" packs a 4-byte big-endian int whose top two bytes are zero.
      packet = struct.pack("!i", ethertype) + inner
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)

      s.sendto(packet, (dstaddr, IPPROTO_GRE))
      desc, expected = packets.GRE(version, myaddr, dstaddr, ethertype, inner)
      msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
          version, inner_version, desc, self.GetInterfaceName(netid))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckOutgoingPackets(self, routing_mode):
    """Runs every outgoing-packet check on every netid, ITERATIONS times.

    Version 5 is passed straight through to the helpers; presumably it is
    multinetwork_base's IPv4-mapped-IPv6 (dual-stack) variant -- TODO confirm.
    """
    for _ in xrange(self.ITERATIONS):
      for netid in self.tuns:

        self.CheckPingPacket(4, netid, routing_mode, self.IPV4_PING)
        # Kernel bug.
        if routing_mode != "oif":
          self.CheckPingPacket(6, netid, routing_mode, self.IPV6_PING)

        # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
        if routing_mode != "ucast_oif":
          self.CheckTCPSYNPacket(4, netid, routing_mode)
          self.CheckTCPSYNPacket(6, netid, routing_mode)
          self.CheckTCPSYNPacket(5, netid, routing_mode)

        self.CheckUDPPacket(4, netid, routing_mode)
        self.CheckUDPPacket(6, netid, routing_mode)
        self.CheckUDPPacket(5, netid, routing_mode)

        # Creating raw sockets on non-root UIDs requires properly setting
        # capabilities, which is hard to do from Python.
        # IP_UNICAST_IF is not supported on raw sockets.
        if routing_mode not in ["uid", "ucast_oif"]:
          self.CheckRawGrePacket(4, netid, routing_mode)
          self.CheckRawGrePacket(6, netid, routing_mode)

  def testMarkRouting(self):
    """Checks that socket marking selects the right outgoing interface."""
    self.CheckOutgoingPackets("mark")

  def testUidRouting(self):
    """Checks that UID routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("uid")

  def testOifRouting(self):
    """Checks that oif routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("oif")

  @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
  def testUcastOifRouting(self):
    """Checks that ucast oif routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("ucast_oif")

  def CheckRemarking(self, version, use_connect):
    """Checks that re-selecting a UDP socket's network changes its routing.

    For each routing mode, one socket is moved through every netid in turn
    with SelectInterface, and each datagram is expected on the netid that was
    just selected (subject to the dst-cache caveat described below).

    Args:
      version: 4 or 6.
      use_connect: If True, connect the socket on the first netid and send
        with send(); otherwise send with sendto().
    """
    modes = ["mark", "oif", "uid"]
    # Setting UNICAST_IF on connected sockets does not work.
    if not use_connect and HAVE_UNICAST_IF:
      modes += ["ucast_oif"]

    for mode in modes:
      s = net_test.UDPSocket(self.GetProtocolFamily(version))
      try:
        # Figure out what packets to expect.
        sport = net_test.BindRandomPort(version, s)
        dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
        unspec = {4: "0.0.0.0", 6: "::"}[version]  # Placeholder.
        desc, expected = packets.UDP(version, unspec, dstaddr, sport)

        def ExpectSendUsesNetid(target_netid):
          """Sends one datagram on s and expects it on target_netid."""
          connected_str = "Connected" if use_connect else "Unconnected"
          msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
              connected_str, version, mode, desc,
              self.GetInterfaceName(target_netid))
          if use_connect:
            s.send(UDP_PAYLOAD)
          else:
            s.sendto(UDP_PAYLOAD, (dstaddr, 53))
          self.ExpectPacketOn(target_netid, msg, expected)

        # If we're testing connected sockets, connect the socket on the first
        # netid (in self.tuns iteration order) now.
        if use_connect:
          netid = next(iter(self.tuns))
          self.SelectInterface(s, netid, mode)
          s.connect((dstaddr, 53))
          expected.src = self.MyAddress(version, netid)

        # For each netid, select that network without closing the socket, and
        # check that the packets sent on that socket go out on the right
        # network.
        #
        # For connected sockets, routing is cached in the socket's destination
        # cache entry. In this case, we check that just re-selecting the netid
        # (except via SO_BINDTODEVICE) does not change routing, but that
        # subsequently invalidating the destination cache entry does. Arguably
        # this is a bug in the kernel because re-selecting the netid should
        # cause routing to change. But it is a convenient way to check that
        # InvalidateDstCache actually works.
        prevnetid = None
        for netid in self.tuns:
          self.SelectInterface(s, netid, mode)
          if not use_connect:
            expected.src = self.MyAddress(version, netid)

          if (use_connect and mode in ["mark", "uid", "ucast_oif"] and
              prevnetid is not None):
            # With a destination cache entry, packets are not rerouted...
            ExpectSendUsesNetid(prevnetid)
            # ... until we invalidate it.
            self.InvalidateDstCache(version, prevnetid)
          ExpectSendUsesNetid(netid)

          self.SelectInterface(s, None, mode)
          prevnetid = netid
      finally:
        s.close()

  def testIPv4Remarking(self):
    """Checks that updating the mark on an IPv4 socket changes routing."""
    self.CheckRemarking(4, False)
    self.CheckRemarking(4, True)

  def testIPv6Remarking(self):
    """Checks that updating the mark on an IPv6 socket changes routing."""
    self.CheckRemarking(6, False)
    self.CheckRemarking(6, True)

  def testIPv6StickyPktinfo(self):
    """Checks that a sticky IPV6_PKTINFO option routes and shapes packets.

    Also sets a flow label, destination options and a hop limit on the socket
    and expects all of them in the packet seen on the selected interface.
    """
    for _ in xrange(self.ITERATIONS):
      for netid in self.tuns:
        s = net_test.UDPSocket(AF_INET6)
        try:
          # Set a flowlabel.
          net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
          s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)

          # Set some destination options.
          nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
          dstopts = "".join([
              "\x11\x02",              # Next header=UDP, 24 bytes of options.
              "\x01\x06", "\x00" * 6,  # PadN, 6 bytes of padding.
              "\x8b\x0c",              # ILNP nonce, 12 bytes.
              nonce
          ])
          s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
          s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)

          pktinfo = multinetwork_base.MakePktInfo(6, None,
                                                  self.ifindices[netid])

          # Set the sticky pktinfo option.
          s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)

          # Specify the flowlabel in the destination address.
          s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))

          sport = s.getsockname()[1]
          srcaddr = self.MyAddress(6, netid)
          expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
                                 fl=0xdead, hlim=255) /
                      scapy.IPv6ExtHdrDestOpt(
                          options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
                                   scapy.HBHOptUnknown(otype=0x8b,
                                                       optdata=nonce)]) /
                      scapy.UDP(sport=sport, dport=53) /
                      UDP_PAYLOAD)
          msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
              self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
        finally:
          s.close()

  def CheckPktinfoRouting(self, version):
    """Checks that SendOnNetid routes UDP to each netid with extra options.

    IPv6 passes hop limit, traffic class and flow info as cmsgs; IPv4 sets
    TTL and TOS with setsockopt instead.

    Args:
      version: 4 or 6.
    """
    for _ in xrange(self.ITERATIONS):
      for netid in self.tuns:
        family = self.GetProtocolFamily(version)
        s = net_test.UDPSocket(family)
        try:
          if version == 6:
            # Create a flowlabel so we can use it.
            net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)

            # Specify some arbitrary options.
            # We declare the flowlabel as ctypes.c_uint32 because on a 32-bit
            # Python interpreter an integer greater than 0x7fffffff (such as
            # our chosen flowlabel after being passed through htonl) is
            # converted to long, and _MakeMsgControl doesn't know what to do
            # with longs.
            cmsgs = [
                (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
                (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
                (net_test.SOL_IPV6, IPV6_FLOWINFO,
                 ctypes.c_uint32(htonl(0xbeef))),
            ]
          else:
            # Support for setting IPv4 TOS and TTL via cmsg only appeared in
            # 3.13.
            cmsgs = []
            s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
            s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)

          dstaddr = self.GetRemoteAddress(version)
          self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)

          sport = s.getsockname()[1]
          srcaddr = self.MyAddress(version, netid)

          desc, expected = packets.UDPWithOptions(version, srcaddr, dstaddr,
                                                  sport=sport)

          msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
              version, desc, self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
        finally:
          s.close()

  def testIPv4PktinfoRouting(self):
    self.CheckPktinfoRouting(4)

  def testIPv6PktinfoRouting(self):
    self.CheckPktinfoRouting(6)
335
336
class MarkTest(multinetwork_base.InboundMarkingTest):
  """Checks that kernel-generated replies leave on the triggering interface."""

  def CheckReflection(self, version, gen_packet, gen_reply):
    """Checks that replies go out on the same interface as the original.

    For every (netid, interface, address) combination, the packet built by
    gen_packet is injected on the tun interface twice: once with the mark
    reflection sysctls off and once with them on. Each time, the kernel's
    reply is checked against the packet built by gen_reply. When no reply is
    expected, None is passed to _ReceiveAndExpectResponse instead.

    Args:
      version: An integer, 4 or 6.
      gen_packet: Callable(version, src, dst) -> (description, scapy packet),
        where src and dst are address strings.
      gen_reply: Callable(version, src, dst, packet) -> (description, scapy
        packet), where packet is the one gen_packet produced.
    """
    for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
      # One test packet per combination, reused for both sysctl settings.
      desc, packet = gen_packet(version, remoteaddr, myaddr)

      # HACK: IPv6 ping replies always do a routing lookup with the interface
      # the ping came in on. So even if mark reflection is not working, IPv6
      # ping replies will be properly reflected. Don't fail when that happens.
      always_reflected = desc == "ICMPv6 echo"

      for reflect in (0, 1):
        self.SetMarkReflectSysctls(reflect)
        if not (reflect or always_reflected):
          reply_desc = reply = None
        else:
          reply_desc, reply = gen_reply(version, myaddr, remoteaddr, packet)

        extra = "reflect=%d" % reflect
        msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
        self._ReceiveAndExpectResponse(netid, packet, reply, msg)

  def SYNToClosedPort(self, *args):
    """Returns packets.SYN(...) aimed at port 999, assumed to be closed."""
    closed_port = 999
    return packets.SYN(closed_port, *args)

  def testIPv4ICMPErrorsReflectMark(self):
    self.CheckReflection(4, packets.UDP, packets.ICMPPortUnreachable)

  def testIPv6ICMPErrorsReflectMark(self):
    self.CheckReflection(6, packets.UDP, packets.ICMPPortUnreachable)

  def testIPv4PingRepliesReflectMarkAndTos(self):
    self.CheckReflection(4, packets.ICMPEcho, packets.ICMPReply)

  def testIPv6PingRepliesReflectMarkAndTos(self):
    self.CheckReflection(6, packets.ICMPEcho, packets.ICMPReply)

  def testIPv4RSTsReflectMark(self):
    self.CheckReflection(4, self.SYNToClosedPort, packets.RST)

  def testIPv6RSTsReflectMark(self):
    self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
391
392  def testIPv4RSTsReflectMark(self):
393    self.CheckReflection(4, self.SYNToClosedPort, packets.RST)
394
395  def testIPv6RSTsReflectMark(self):
396    self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
397
398
class TCPAcceptTest(multinetwork_base.InboundMarkingTest):
  """Checks that accepted TCP connections are routed on the incoming network.

  The network is selected by one of the MODE_* mechanisms applied to a single
  shared, dual-stack listening socket.
  """

  MODE_BINDTODEVICE = "SO_BINDTODEVICE"
  MODE_INCOMING_MARK = "incoming mark"
  MODE_EXPLICIT_MARK = "explicit mark"
  MODE_UID = "uid"

  @classmethod
  def setUpClass(cls):
    super(TCPAcceptTest, cls).setUpClass()

    # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    # will accept both IPv4 and IPv6 connections. We do this here instead of in
    # each test so we can use the same socket every time. That way, if a kernel
    # bug causes incoming packets to mark the listening socket instead of the
    # accepted socket, the test will fail as soon as the next address/interface
    # combination is tried.
    cls.listensocket = net_test.IPv6TCPSocket()
    cls.listenport = net_test.BindRandomPort(6, cls.listensocket)
    # CheckTCPConnection calls accept() on this socket, which requires it to be
    # listening. On Linux, calling listen() again on an already-listening
    # socket only adjusts the backlog, so this is harmless if a helper already
    # did it.
    cls.listensocket.listen(100)

  @classmethod
  def tearDownClass(cls):
    try:
      cls.listensocket.close()
    finally:
      super(TCPAcceptTest, cls).tearDownClass()

  def _SetTCPMarkAcceptSysctl(self, value):
    """Writes value to the tcp_fwmark_accept sysctl."""
    self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)

  def CheckTCPConnection(self, mode, listensocket, netid, version,
                         myaddr, remoteaddr, packet, reply, msg):
    """Completes a handshake, exchanges data and tears down the connection.

    Checks that data and the FIN sent on the accepted socket leave on netid,
    and checks the resulting socket marks as required by mode.

    Args:
      mode: One of the MODE_* constants, or None.
      listensocket: The listening socket to accept() on.
      netid: The network the connection arrived on.
      version: 4 or 6.
      myaddr: Local address of the connection.
      remoteaddr: Remote (simulated peer) address of the connection.
      packet: The SYN that started the connection (not used here; kept for
        interface compatibility).
      reply: The SYN+ACK the kernel sent in response to the SYN.
      msg: Prefix for assertion messages.
    """
    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]

    # Attempt to confuse the kernel.
    self.InvalidateDstCache(version, netid)

    self.ReceivePacketOn(netid, establishing_ack)

    # If we're using UID routing, the accept() call has to be run as a UID that
    # is routed to the specified netid, because the UID of the socket returned
    # by accept() is the effective UID of the process that calls it. It doesn't
    # need to be the same UID; any UID that selects the same interface will do.
    with net_test.RunAsUid(self.UidForNetid(netid)):
      s, _ = listensocket.accept()

    try:
      # Check that data sent on the connection goes out on the right interface.
      desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                               payload=UDP_PAYLOAD)
      s.send(UDP_PAYLOAD)
      self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
      self.InvalidateDstCache(version, netid)

      # Keep up our end of the conversation.
      ack = packets.ACK(version, remoteaddr, myaddr, data)[1]
      self.InvalidateDstCache(version, netid)
      self.ReceivePacketOn(netid, ack)

      mark = self.GetSocketMark(s)
    finally:
      self.InvalidateDstCache(version, netid)
      s.close()
      self.InvalidateDstCache(version, netid)

    if mode == self.MODE_INCOMING_MARK:
      self.assertEquals(netid, mark & self.NETID_FWMASK,
                        msg + ": Accepted socket: Expected mark %d, got %d" % (
                            netid, mark))
    elif mode != self.MODE_EXPLICIT_MARK:
      self.assertEquals(0, self.GetSocketMark(listensocket))

    # Check the FIN was sent on the right interface, and ack it. We don't expect
    # this to fail because by the time the connection is established things are
    # likely working, but a) extra tests are always good and b) extra packets
    # like the FIN (and retransmitted FINs) could cause later tests that expect
    # no packets to fail.
    desc, fin = packets.FIN(version, myaddr, remoteaddr, ack)
    self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)

    desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
    self.ReceivePacketOn(netid, finack)

    # Since we called close() earlier, the userspace socket object is gone, so
    # the socket has no UID. If we're doing UID routing, the ack might be routed
    # incorrectly. Not much we can do here.
    desc, finackack = packets.ACK(version, myaddr, remoteaddr, finack)
    self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)

  def CheckTCP(self, version, modes):
    """Checks that incoming TCP connections work.

    For each mode and each (netid, interface, address) combination, the
    shared listening socket is configured for that mode and a SYN is injected.
    If a SYN+ACK is expected and received, the connection is driven through
    CheckTCPConnection.

    Args:
      version: An integer, 4 or 6.
      modes: A list of modes to exercise (MODE_* constants or None). With
        mode None, no SYN+ACK is expected.
    """
    # The listening socket never changes, so look it up once.
    listensocket = self.listensocket
    listenport = listensocket.getsockname()[1]

    # NOTE(review): nothing in this method writes SYNCOOKIES_SYSCTL, so this
    # loop only repeats the checks and labels the messages; it does not by
    # itself change the kernel's syncookie mode. TODO confirm intent.
    for syncookies in [0, 2]:
      for mode in modes:
        for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
          accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
          self._SetTCPMarkAcceptSysctl(accept_sysctl)
          self.SetMarkReflectSysctls(accept_sysctl)

          bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
          self.BindToDevice(listensocket, bound_dev)

          mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
          self.SetSocketMark(listensocket, mark)

          uid = self.UidForNetid(netid) if mode == self.MODE_UID else 0
          os.fchown(listensocket.fileno(), uid, -1)

          # Generate the packet here instead of in the outer loop, so
          # subsequent TCP connections use different source ports and
          # retransmissions from old connections don't confuse subsequent
          # tests.
          desc, packet = packets.SYN(listenport, version, remoteaddr, myaddr)

          if mode:
            reply_desc, reply = packets.SYNACK(version, myaddr, remoteaddr,
                                               packet)
          else:
            reply_desc, reply = None, None

          extra = "mode=%s, syncookies=%d" % (mode, syncookies)
          msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
          reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
          if reply:
            self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
                                    remoteaddr, packet, reply, msg)

  def testBasicTCP(self):
    self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])

  def testIPv4MarkAccept(self):
    self.CheckTCP(4, [self.MODE_INCOMING_MARK])

  def testIPv6MarkAccept(self):
    self.CheckTCP(6, [self.MODE_INCOMING_MARK])

  def testIPv4UidAccept(self):
    self.CheckTCP(4, [self.MODE_UID])

  def testIPv6UidAccept(self):
    self.CheckTCP(6, [self.MODE_UID])

  def testIPv6ExplicitMark(self):
    self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
544
@unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                     "need support for per-table autoconf")
class RIOTest(multinetwork_base.MultiNetworkBaseTest):
  """Test for IPv6 RFC 4191 route information option

  Relevant kernel commits:
    upstream:
      f104a567e673 ipv6: use rt6_get_dflt_router to get default router in rt6_route_rcv
      bbea124bc99d net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-4.9:
      d860b2e8a7f1 FROMLIST: net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs

    android-4.4:
      e953f89b8563 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-4.1:
      84f2f47716cd net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-3.18:
      65f8936934fa net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

    android-3.10:
      161e88ebebc7 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.

  """

  def setUp(self):
    super(RIOTest, self).setUp()
    # Each test case runs against one randomly chosen network.
    self.NETID = random.choice(self.NETIDS)
    self.IFACE = self.GetInterfaceName(self.NETID)
    # Return min/max plen to default values before each test case.
    self.SetAcceptRaRtInfoMinPlen(0)
    self.SetAcceptRaRtInfoMaxPlen(0)

  def _RtInfoPlenSysctl(self, bound):
    """Returns the path of IFACE's accept_ra_rt_info_<bound>_plen sysctl.

    Args:
      bound: "min" or "max".
    """
    return ("/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_%s_plen" %
            (self.IFACE, bound))

  def GetRoutingTable(self):
    """Returns the routing table used by NETID."""
    return self._TableForNetid(self.NETID)

  def SetAcceptRaRtInfoMinPlen(self, plen):
    self.SetSysctl(self._RtInfoPlenSysctl("min"), plen)

  def GetAcceptRaRtInfoMinPlen(self):
    return int(self.GetSysctl(self._RtInfoPlenSysctl("min")))

  def SetAcceptRaRtInfoMaxPlen(self, plen):
    self.SetSysctl(self._RtInfoPlenSysctl("max"), plen)

  def GetAcceptRaRtInfoMaxPlen(self):
    return int(self.GetSysctl(self._RtInfoPlenSysctl("max")))

  def SendRIO(self, rtlifetime, plen, prefix, prf):
    """Sends an RA on NETID carrying a single Route Information Option."""
    options = scapy.ICMPv6NDOptRouteInfo(rtlifetime=rtlifetime, plen=plen,
                                         prefix=prefix, prf=prf)
    self.SendRA(self.NETID, options=(options,))

  def FindRoutesWithDestination(self, destination):
    """Returns routes in NETID's table whose RTA_DST equals destination."""
    canonical = net_test.CanonicalizeIPv6Address(destination)
    return [r for _, r in self.iproute.DumpRoutes(6, self.GetRoutingTable())
            if ('RTA_DST' in r and r['RTA_DST'] == canonical)]

  def FindRoutesWithGateway(self):
    """Returns routes in NETID's table that have an RTA_GATEWAY attribute."""
    return [r for _, r in self.iproute.DumpRoutes(6, self.GetRoutingTable())
            if 'RTA_GATEWAY' in r]

  def CountRoutes(self):
    """Returns the number of IPv6 routes in NETID's table."""
    return len(self.iproute.DumpRoutes(6, self.GetRoutingTable()))

  def GetRouteExpiration(self, route):
    # Divides the kernel's `expires` value by 100 to get seconds; presumably
    # it is reported in USER_HZ (100 Hz) clock ticks -- TODO confirm.
    return float(route['RTA_CACHEINFO'].expires) / 100.0

  def AssertExpirationInRange(self, routes, lifetime, epsilon):
    """Asserts that at least one route expires in lifetime +/- epsilon sec."""
    self.assertTrue(routes, "No matching routes in table %s" %
                    self.GetRoutingTable())
    expirations = [self.GetRouteExpiration(route) for route in routes]
    self.assertTrue(
        any(lifetime - epsilon <= e <= lifetime + epsilon
            for e in expirations),
        "Expected a route expiring in %s +/- %s seconds, got %s" % (
            lifetime, epsilon, expirations))

  def DelRA6(self, prefix, plen):
    """Deletes the RA-installed route prefix/plen via NETID's router."""
    version = 6
    netid = self.NETID
    table = self._TableForNetid(netid)
    router = self._RouterAddress(netid, version)
    ifindex = self.ifindices[netid]
    # We actually want to specify RTPROT_RA, however an upstream kernel bug
    # causes RAs to be installed with RTPROT_BOOT on kernels older than 4.14
    # (see HAVE_RTPROT_RA).
    rtprot = iproute.RTPROT_RA if HAVE_RTPROT_RA else iproute.RTPROT_BOOT
    self.iproute._Route(version, rtprot, iproute.RTM_DELROUTE,
                        table, prefix, plen, router, ifindex, None, None)

  def testSetAcceptRaRtInfoMinPlen(self):
    for plen in xrange(-1, 130):
      self.SetAcceptRaRtInfoMinPlen(plen)
      self.assertEquals(plen, self.GetAcceptRaRtInfoMinPlen())

  def testSetAcceptRaRtInfoMaxPlen(self):
    for plen in xrange(-1, 130):
      self.SetAcceptRaRtInfoMaxPlen(plen)
      self.assertEquals(plen, self.GetAcceptRaRtInfoMaxPlen())

  def testZeroRtLifetime(self):
    PREFIX = "2001:db8:8901:2300::"
    RTLIFETIME = 73500
    PLEN = 56
    PRF = 0
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    self.assertTrue(self.FindRoutesWithDestination(PREFIX))
    # RIO with rtlifetime = 0 should remove from routing table
    self.SendRIO(0, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    self.assertFalse(self.FindRoutesWithDestination(PREFIX))

  def testMinPrefixLenRejection(self):
    PREFIX = "2001:db8:8902:2345::"
    RTLIFETIME = 70372
    PRF = 0
    # sweep from high to low to avoid spurious failures from late arrivals.
    for plen in xrange(130, 1, -1):
      self.SetAcceptRaRtInfoMinPlen(plen)
      # RIO with plen < min_plen should be ignored
      self.SendRIO(RTLIFETIME, plen - 1, PREFIX, PRF)
    # Give the kernel time to notice our RAs
    time.sleep(0.1)
    # Expect no routes
    routes = self.FindRoutesWithDestination(PREFIX)
    self.assertFalse(routes)

  def testMaxPrefixLenRejection(self):
    PREFIX = "2001:db8:8903:2345::"
    RTLIFETIME = 73078
    PRF = 0
    # sweep from low to high to avoid spurious failures from late arrivals.
    for plen in xrange(-1, 128, 1):
      self.SetAcceptRaRtInfoMaxPlen(plen)
      # RIO with plen > max_plen should be ignored
      self.SendRIO(RTLIFETIME, plen + 1, PREFIX, PRF)
    # Give the kernel time to notice our RAs
    time.sleep(0.1)
    # Expect no routes
    routes = self.FindRoutesWithDestination(PREFIX)
    self.assertFalse(routes)

  def testSimpleAccept(self):
    PREFIX = "2001:db8:8904:2345::"
    RTLIFETIME = 9993
    PRF = 0
    PLEN = 56
    self.SetAcceptRaRtInfoMinPlen(48)
    self.SetAcceptRaRtInfoMaxPlen(64)
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    routes = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(routes, RTLIFETIME, 1)
    self.DelRA6(PREFIX, PLEN)

  def testEqualMinMaxAccept(self):
    PREFIX = "2001:db8:8905:2345::"
    RTLIFETIME = 6326
    PLEN = 21
    PRF = 0
    self.SetAcceptRaRtInfoMinPlen(PLEN)
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    routes = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(routes, RTLIFETIME, 1)
    self.DelRA6(PREFIX, PLEN)

  def testZeroLengthPrefix(self):
    PREFIX = "2001:db8:8906:2345::"
    RTLIFETIME = self.RA_VALIDITY * 2
    PLEN = 0
    PRF = 0
    # Max plen = 0 still allows default RIOs!
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    self.SendRA(self.NETID)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    default = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(default, self.RA_VALIDITY, 1)
    # RIO with prefix length = 0, should overwrite default route lifetime
    # note that the RIO lifetime overwrites the RA lifetime.
    self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    # Give the kernel time to notice our RA
    time.sleep(0.01)
    default = self.FindRoutesWithGateway()
    self.AssertExpirationInRange(default, RTLIFETIME, 1)
    self.DelRA6(PREFIX, PLEN)

  def testManyRIOs(self):
    RTLIFETIME = 68012
    PLEN = 56
    PRF = 0
    COUNT = 1000
    baseline = self.CountRoutes()
    self.SetAcceptRaRtInfoMaxPlen(PLEN)
    # Send many RIOs compared to the expected number on a healthy system.
    for i in xrange(0, COUNT):
      prefix = "2001:db8:%x:1100::" % i
      self.SendRIO(RTLIFETIME, PLEN, prefix, PRF)
    time.sleep(0.1)
    self.assertEquals(COUNT + baseline, self.CountRoutes())
    for i in xrange(0, COUNT):
      prefix = "2001:db8:%x:1100::" % i
      self.DelRA6(prefix, PLEN)
    # Expect that we can return to baseline config without lingering routes.
    self.assertEquals(baseline, self.CountRoutes())
772
class RATest(multinetwork_base.MultiNetworkBaseTest):
  """Tests of IPv6 router advertisement handling and on-link routing."""

  def testDoesNotHaveObsoleteSysctl(self):
    """Checks that the obsolete autoconf_table_offset sysctl is not present."""
    self.assertFalse(os.path.isfile(
        "/proc/sys/net/ipv6/route/autoconf_table_offset"))

  @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                       "no support for per-table autoconf")
  def testPurgeDefaultRouters(self):
    """Checks that enabling forwarding removes IPv6 connectivity via RAs.

    The test sets accept_ra=1 on all interfaces and enables forwarding, then
    expects marked UDP sends on every netid to fail with ENETUNREACH. Cleanup
    disables forwarding, resends an RA on each netid, and expects connectivity
    to be restored.
    """

    def CheckIPv6Connectivity(expect_connectivity):
      for netid in self.NETIDS:
        s = net_test.UDPSocket(AF_INET6)
        try:
          self.SetSocketMark(s, netid)
          if expect_connectivity:
            self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
          else:
            self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
                                   (net_test.IPV6_ADDR, 1234))
        finally:
          s.close()

    try:
      CheckIPv6Connectivity(True)
      self.SetIPv6SysctlOnAllIfaces("accept_ra", 1)
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
      CheckIPv6Connectivity(False)
    finally:
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
      for netid in self.NETIDS:
        self.SendRA(netid)
      CheckIPv6Connectivity(True)

  def testOnlinkCommunication(self):
    """Checks that on-link communication goes direct and not through routers."""
    for netid in self.tuns:
      # Send a UDP packet to a random on-link destination.
      s = net_test.UDPSocket(AF_INET6)
      try:
        iface = self.GetInterfaceName(netid)
        self.BindToDevice(s, iface)
        # dstaddr can never be our address because GetRandomDestination only
        # fills in the lower 32 bits, but our address has 0xff in the byte
        # before that (since it's constructed from the EUI-64 and so has ff:fe
        # in the middle).
        dstaddr = self.GetRandomDestination(self.OnlinkPrefix(6, netid))
        s.sendto(UDP_PAYLOAD, (dstaddr, 53))

        # Expect an NS for that destination on the interface.
        myaddr = self.MyAddress(6, netid)
        mymac = self.MyMacAddress(netid)
        desc, expected = packets.NS(myaddr, dstaddr, mymac)
        msg = "Sending UDP packet to on-link destination: expecting %s" % desc
        time.sleep(0.0001)  # Required to make the test work on kernel 3.1(!)
        self.ExpectPacketOn(netid, msg, expected)

        # Send an NA.
        tgtmac = "02:00:00:00:%02x:99" % netid
        _, reply = packets.NA(dstaddr, myaddr, tgtmac)
        # Don't use ReceivePacketOn, since that uses the router's MAC address
        # as the source. Instead, construct our own Ethernet header with source
        # MAC of tgtmac.
        reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
        self.ReceiveEtherPacketOn(netid, reply)

        # Expect the kernel to send the original UDP packet now that the ND
        # cache entry has been populated.
        sport = s.getsockname()[1]
        desc, expected = packets.UDP(6, myaddr, dstaddr, sport=sport)
        msg = "After NA response, expecting %s" % desc
        self.ExpectPacketOn(netid, msg, expected)
      finally:
        s.close()

  # This test documents a known issue: routing tables are never deleted.
  @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                       "no support for per-table autoconf")
  def testLeftoverRoutes(self):
    """Checks that routes remain after temporary tun interfaces are closed."""

    def GetNumRoutes():
      # Close the procfs file promptly instead of relying on GC.
      with open("/proc/net/ipv6_route") as routes_file:
        return len(routes_file.readlines())

    num_routes = GetNumRoutes()
    for i in xrange(10, 20):
      # Register the tun only once it exists. A failed CreateTunInterface
      # then propagates as-is, instead of being masked by a KeyError from
      # cleaning up a missing entry.
      self.tuns[i] = self.CreateTunInterface(i)
      try:
        self.SendRA(i)
      finally:
        # Always unregister and close the tun, even if SendRA fails.
        self.tuns.pop(i).close()
    self.assertLess(num_routes, GetNumRoutes())
857
858
class PMTUTest(multinetwork_base.InboundMarkingTest):
  """Tests path MTU discovery for IPv4 and IPv6 UDP sockets."""

  PAYLOAD_SIZE = 1400
  # Deliberately a class attribute: it is shared by all instances of this
  # class, so destinations already used in this run are never picked again
  # (see CheckPMTU).
  dstaddrs = set()

  def GetSocketMTU(self, version, s):
    """Returns the path MTU the kernel reports for socket s.

    For IPv6 this reads IPV6_PATHMTU and unpacks it as a 28-byte sockaddr
    followed by a native 32-bit MTU. For IPv4 it reads IP_MTU.
    """
    if version == 6:
      ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, csocket.IPV6_PATHMTU, 32)
      unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
      return mtu
    else:
      return s.getsockopt(net_test.SOL_IP, csocket.IP_MTU)

  def DisableFragmentationAndReportErrors(self, version, s):
    """Sets don't-fragment behaviour and enables error queue reporting on s."""
    if version == 4:
      s.setsockopt(net_test.SOL_IP, csocket.IP_MTU_DISCOVER,
                   csocket.IP_PMTUDISC_DO)
      s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
    else:
      s.setsockopt(net_test.SOL_IPV6, csocket.IPV6_DONTFRAG, 1)
      s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)

  def CheckPMTU(self, version, use_connect, modes):
    """Checks PMTU discovery on every test network for each routing mode.

    For each netid and mode, the test sends an oversized UDP packet with
    fragmentation disabled and injects an ICMP Packet Too Big in response. It
    then checks that:
      - a further send on the same socket fails with EMSGSIZE;
      - a connected socket reports the reduced MTU (IPv6, or IPv4 on 3.6+);
      - a fresh socket connected to the same destination sees the reduced MTU;
      - the route returned by iproute.GetRoutes carries the reduced MTU.

    Args:
      version: 4 or 6.
      use_connect: if True, connect() the socket and use send(); otherwise
        send each packet with SendOnNetid.
      modes: routing modes passed to BuildSocket (e.g. "mark", "oif", "uid"
        or None).
    """

    def SendBigPacket(version, s, dstaddr, netid, payload):
      if use_connect:
        s.send(payload)
      else:
        self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])

    for netid in self.tuns:
      for mode in modes:
        s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
        try:
          self.DisableFragmentationAndReportErrors(version, s)

          srcaddr = self.MyAddress(version, netid)
          dst_prefix, intermediate = {
              4: ("172.19.", "172.16.9.12"),
              6: ("2001:db8::", "2001:db8::1")
          }[version]

          # Run this test often enough (e.g., in presubmits), and eventually
          # we'll be unlucky enough to pick the same address twice, in which
          # case the test will fail because the kernel will already have seen
          # the lower MTU. Don't do this.
          dstaddr = self.GetRandomDestination(dst_prefix)
          while dstaddr in self.dstaddrs:
            dstaddr = self.GetRandomDestination(dst_prefix)
          self.dstaddrs.add(dstaddr)

          if use_connect:
            s.connect((dstaddr, 1234))

          payload = self.PAYLOAD_SIZE * "a"

          # Send a packet and receive a packet too big.
          SendBigPacket(version, s, dstaddr, netid, payload)
          received = self.ReadAllPacketsOn(netid)
          self.assertEquals(1, len(received),
                            "unexpected packets: %s" % received[1:])
          _, toobig = packets.ICMPPacketTooBig(version, intermediate, srcaddr,
                                               received[0])
          self.ReceivePacketOn(netid, toobig)

          # Check that another send on the same socket returns EMSGSIZE.
          self.assertRaisesErrno(
              errno.EMSGSIZE,
              SendBigPacket, version, s, dstaddr, netid, payload)

          # If this is a connected socket, make sure the socket MTU was set.
          # Note that in IPv4 this only started working in Linux 3.6!
          if use_connect and (version == 6 or
                              net_test.LINUX_VERSION >= (3, 6)):
            self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s))
        finally:
          # Close the socket even if an assertion above fails.
          s.close()

        # Check that other sockets pick up the PMTU we have been told about by
        # connecting another socket to the same destination and getting its
        # MTU. This new socket can use any method to select its outgoing
        # interface; here we use a mark for simplicity.
        s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
        try:
          s2.connect((dstaddr, 1234))
          self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s2))
        finally:
          s2.close()

        # Also check the MTU reported by ip route get, this time using the oif.
        routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
        self.assertTrue(routes)
        rtmsg, attributes = routes[0]
        self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
        metrics = attributes["RTA_METRICS"]
        self.assertEquals(packets.PTB_MTU, metrics["RTAX_MTU"])

  def testIPv4BasicPMTU(self):
    """Tests IPv4 path MTU discovery.

    Relevant kernel commits:
      upstream net-next:
        6a66271 ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif

      android-3.10:
        4bc64dd ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
    """

    self.CheckPMTU(4, True, ["mark", "oif"])
    self.CheckPMTU(4, False, ["mark", "oif"])

  def testIPv6BasicPMTU(self):
    self.CheckPMTU(6, True, ["mark", "oif"])
    self.CheckPMTU(6, False, ["mark", "oif"])

  def testIPv4UIDPMTU(self):
    self.CheckPMTU(4, True, ["uid"])
    self.CheckPMTU(4, False, ["uid"])

  def testIPv6UIDPMTU(self):
    self.CheckPMTU(6, True, ["uid"])
    self.CheckPMTU(6, False, ["uid"])

  # Making Path MTU Discovery work on unmarked sockets requires that mark
  # reflection be enabled. Otherwise the kernel has no way to know what routing
  # table the original packet used, and thus it won't be able to clone the
  # correct route.

  def testIPv4UnmarkedSocketPMTU(self):
    self.SetMarkReflectSysctls(1)
    try:
      self.CheckPMTU(4, False, [None])
    finally:
      self.SetMarkReflectSysctls(0)

  def testIPv6UnmarkedSocketPMTU(self):
    self.SetMarkReflectSysctls(1)
    try:
      self.CheckPMTU(6, False, [None])
    finally:
      self.SetMarkReflectSysctls(0)
996
997
class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests that per-UID routing works properly.

  Relevant kernel commits:
    upstream net-next:
      7d99569460 net: ipv4: Don't crash if passing a null sk to ip_do_redirect.
      d109e61bfe net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
      35b80733b3 net: core: add missing check for uid_range in rule_exists.
      e2d118a1cb net: inet: Support UID-based routing in IP protocols.
      622ec2c9d5 net: core: add UID to flows, rules, and routes
      86741ec254 net: core: Add a UID field to struct sock.

    android-3.18:
      b004e79504 net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
      04c0eace81 net: inet: Support UID-based routing in IP protocols.
      18c36d7b71 net: core: add UID to flows, rules, and routes
      80e3440721 net: core: Add a UID field to struct sock.
      fa8cc2c30c Revert "net: core: Support UID-based routing."
      b585141890 Revert "Handle 'sk' being NULL in UID-based routing."
      5115ab7514 Revert "net: core: fix UID-based routing build"
      f9f4281f79 Revert "ANDROID: net: fib: remove duplicate assignment"

    android-4.4:
      341965cf10 net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
      344afd627c net: inet: Support UID-based routing in IP protocols.
      03441d56d8 net: core: add UID to flows, rules, and routes
      eb964bdba7 net: core: Add a UID field to struct sock.
      9789b697c6 Revert "net: core: Support UID-based routing."
  """

  def GetRulesAtPriority(self, version, priority):
    """Returns the dumped (rule, attributes) pairs whose FRA_PRIORITY matches.

    Rules without an FRA_PRIORITY attribute are treated as priority 0.
    """
    rules = self.iproute.DumpRules(version)
    out = [(rule, attributes) for rule, attributes in rules
           if attributes.get("FRA_PRIORITY", 0) == priority]
    return out

  def CheckInitialTablesHaveNoUIDs(self, version):
    """Checks that the rules at priorities 0, 32766 and 32767 have no UID range."""
    rules = []
    for priority in [0, 32766, 32767]:
      rules.extend(self.GetRulesAtPriority(version, priority))
    for _, attributes in rules:
      self.assertNotIn("FRA_UID_RANGE", attributes)

  def testIPv4InitialTablesHaveNoUIDs(self):
    self.CheckInitialTablesHaveNoUIDs(4)

  def testIPv6InitialTablesHaveNoUIDs(self):
    self.CheckInitialTablesHaveNoUIDs(6)

  @staticmethod
  def _Random():
    """Returns a random integer in [1000000, 2000000] for UIDs/tables/priorities."""
    return random.randint(1000000, 2000000)

  def CheckGetAndSetRules(self, version):
    """Exercises creating, dumping and deleting UID range rules.

    Args:
      version: 4 or 6.
    """
    start, end = tuple(sorted([self._Random(), self._Random()]))
    table = self._Random()
    priority = self._Random()

    # Can't create a UID range to UID -1 because -1 is INVALID_UID...
    self.assertRaisesErrno(
        errno.EINVAL,
        self.iproute.UidRangeRule, version, True, 100, 0xffffffff, table,
        priority)

    # ... but -2 is valid.
    self.iproute.UidRangeRule(version, True, 100, 0xfffffffe, table, priority)
    self.iproute.UidRangeRule(version, False, 100, 0xfffffffe, table, priority)

    try:
      # Create a UID range rule.
      self.iproute.UidRangeRule(version, True, start, end, table, priority)

      # Check that deleting the wrong UID range doesn't work.
      self.assertRaisesErrno(
          errno.ENOENT,
          self.iproute.UidRangeRule, version, False, start, end + 1, table,
          priority)
      self.assertRaisesErrno(
          errno.ENOENT,
          self.iproute.UidRangeRule, version, False, start + 1, end, table,
          priority)

      # Check that the UID range appears in dumps.
      rules = self.GetRulesAtPriority(version, priority)
      self.assertTrue(rules)
      _, attributes = rules[-1]
      self.assertEquals(priority, attributes["FRA_PRIORITY"])
      uidrange = attributes["FRA_UID_RANGE"]
      self.assertEquals(start, uidrange.start)
      self.assertEquals(end, uidrange.end)
      self.assertEquals(table, attributes["FRA_TABLE"])
    finally:
      self.iproute.UidRangeRule(version, False, start, end, table, priority)
      self.assertRaisesErrno(
          errno.ENOENT,
          self.iproute.UidRangeRule, version, False, start, end, table,
          priority)

    fwmask = 0xfefefefe
    try:
      # Create a rule without a UID range.
      self.iproute.FwmarkRule(version, True, 300, fwmask, 301, priority + 1)

      # Check it doesn't have a UID range.
      rules = self.GetRulesAtPriority(version, priority + 1)
      self.assertTrue(rules)
      for _, attributes in rules:
        self.assertIn("FRA_TABLE", attributes)
        self.assertNotIn("FRA_UID_RANGE", attributes)
    finally:
      self.iproute.FwmarkRule(version, False, 300, fwmask, 301, priority + 1)

    # Test that EEXIST works for UID range rules too. This behaviour was only
    # added in 4.8.
    if net_test.LINUX_VERSION >= (4, 8, 0):
      ranges = [(100, 101), (100, 102), (99, 101), (1234, 5678)]
      dup = ranges[0]
      try:
        # Check that otherwise identical rules with different UID ranges can be
        # created without EEXIST.
        for start, end in ranges:
          self.iproute.UidRangeRule(version, True, start, end, table, priority)
        # ... but EEXIST is returned if the UID range is identical.
        self.assertRaisesErrno(
            errno.EEXIST,
            self.iproute.UidRangeRule, version, True, dup[0], dup[1], table,
            priority)
      finally:
        # Clean up. dup is deleted twice in case the duplicate was wrongly
        # created; IOErrors from deleting rules that no longer exist are
        # ignored.
        for start, end in ranges + [dup]:
          try:
            self.iproute.UidRangeRule(version, False, start, end, table,
                                      priority)
          except IOError:
            pass

  def testIPv4GetAndSetRules(self):
    self.CheckGetAndSetRules(4)

  def testIPv6GetAndSetRules(self):
    self.CheckGetAndSetRules(6)

  @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not backported")
  def testDeleteErrno(self):
    """Checks that deleting a rule with UID range end 0xffffffff gives EINVAL."""
    for version in [4, 6]:
      table = self._Random()
      priority = self._Random()
      self.assertRaisesErrno(
          errno.EINVAL,
          self.iproute.UidRangeRule, version, False, 100, 0xffffffff, table,
          priority)

  def ExpectNoRoute(self, addr, oif, mark, uid):
    """Asserts that looking up addr yields an unreachable route or ENETUNREACH."""
    # The lack of a route may be either an error, or an unreachable route.
    try:
      routes = self.iproute.GetRoutes(addr, oif, mark, uid)
      rtmsg, _ = routes[0]
      self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
    except IOError as e:
      if int(e.errno) != int(errno.ENETUNREACH):
        # Bare raise keeps the original traceback (raise e would not in Py2).
        raise

  def ExpectRoute(self, addr, oif, mark, uid):
    """Asserts that the first route returned for addr is a unicast route."""
    routes = self.iproute.GetRoutes(addr, oif, mark, uid)
    rtmsg, _ = routes[0]
    self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)

  def CheckGetRoute(self, version, addr):
    """Checks that addr is routable only when a per-netid UID is supplied."""
    self.ExpectNoRoute(addr, 0, 0, 0)
    for netid in self.NETIDS:
      uid = self.UidForNetid(netid)
      self.ExpectRoute(addr, 0, 0, uid)
    self.ExpectNoRoute(addr, 0, 0, 0)

  def testIPv4RouteGet(self):
    self.CheckGetRoute(4, net_test.IPV4_ADDR)

  def testIPv6RouteGet(self):
    self.CheckGetRoute(6, net_test.IPV6_ADDR)

  def testChangeFdAttributes(self):
    """Checks that fchown on a socket changes the UID used for routing it.

    Sends succeed only while the socket's owner matches the UID range rule.
    Changing only the group, or the mode via fchmod, does not affect this.
    """
    netid = random.choice(self.NETIDS)
    uid = self._Random()
    table = self._TableForNetid(netid)
    remoteaddr = self.GetRemoteAddress(6)
    s = socket(AF_INET6, SOCK_DGRAM, 0)

    def CheckSendFails():
      self.assertRaisesErrno(errno.ENETUNREACH,
                             s.sendto, "foo", (remoteaddr, 53))
    def CheckSendSucceeds():
      self.assertEquals(len("foo"), s.sendto("foo", (remoteaddr, 53)))

    try:
      CheckSendFails()
      self.iproute.UidRangeRule(6, True, uid, uid, table, self.PRIORITY_UID)
      try:
        CheckSendFails()
        os.fchown(s.fileno(), uid, -1)
        CheckSendSucceeds()
        os.fchown(s.fileno(), -1, -1)
        CheckSendSucceeds()
        os.fchown(s.fileno(), -1, 12345)
        CheckSendSucceeds()
        os.fchmod(s.fileno(), 0o777)
        CheckSendSucceeds()
        os.fchown(s.fileno(), 0, -1)
        CheckSendFails()
      finally:
        self.iproute.UidRangeRule(6, False, uid, uid, table, self.PRIORITY_UID)
    finally:
      s.close()
1206
1207
class RulesTest(net_test.NetworkTest):
  """Tests of policy routing rule manipulation at a private priority."""

  RULE_PRIORITY = 99999
  FWMASK = 0xffffffff

  def _DeleteTestRules(self):
    # Remove every rule at our private priority, for both address families.
    for family in (4, 6):
      self.iproute.DeleteRulesAtPriority(family, self.RULE_PRIORITY)

  def setUp(self):
    self.iproute = iproute.IPRoute()
    self._DeleteTestRules()

  def tearDown(self):
    self._DeleteTestRules()

  def testRuleDeletionMatchesTable(self):
    """Checks that deleting an fwmark rule only removes the rule for its table.

    This guards against a kernel bug where deletion requests for tables > 256
    ignored the table.
    """
    mark = 300
    for family in (4, 6):
      # Two rules, same mark, pointing at tables 301 and 302.
      for table in (301, 302):
        self.iproute.FwmarkRule(family, True, mark, self.FWMASK, table,
                                priority=self.RULE_PRIORITY)
      # Remove only the one pointing at table 302.
      self.iproute.FwmarkRule(family, False, mark, self.FWMASK, 302,
                              priority=self.RULE_PRIORITY)
      # Only the rule for table 301 should be left at our priority.
      remaining = []
      for _, attrs in self.iproute.DumpRules(family):
        if attrs.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY:
          remaining.append(attrs)
      self.assertEquals(1, len(remaining))
      self.assertEquals(301, remaining[0]["FRA_TABLE"])
1239
1240
if __name__ == "__main__":
  # Run all unittest test cases defined in this module when executed directly.
  unittest.main()
1243