• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import binascii
20import os
21import random
22import select
23from socket import *  # pylint: disable=wildcard-import
24import struct
25import threading
26import time
27import unittest
28
29import cstruct
30import multinetwork_base
31import net_test
32import packets
33import sock_diag
34import tcp_test
35
# Number of socket pairs created by the bulk dump/destroy tests.
NUM_SOCKETS = 30

# Passed to the dump helpers when no filter bytecode should be attached.
NO_BYTECODE = bytes()

# True when the running kernel is 4.19 or newer.
LINUX_4_19_OR_ABOVE = (4, 19, 0) <= net_test.LINUX_VERSION

# Partial view of the kernel's tcp_info: only the fields these tests read are
# named. The "64x" in the format presumably skips everything that precedes
# tcpi_rcv_ssthresh -- TODO confirm against the kernel's struct tcp_info.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")
42
IPPROTO_SCTP = 132


def HaveSctp():
  """Reports whether an IPv4 SCTP stream socket can be created.

  Returns:
    True if the probe socket was created and closed without an IOError,
    False otherwise.
  """
  try:
    socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP).close()
  except IOError:
    return False
  return True


# Probed once at import time.
HAVE_SCTP = HaveSctp()
54
55
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

    Relevant kernel commits:
      android-3.4:
        ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
        99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-3.10:
        3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
        f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-3.18:
        e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-4.4:
        525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS loopback socket pairs of the given socket type.

    Each pair uses a randomly chosen family/address out of IPv4 loopback,
    IPv6 loopback and the IPv4-mapped IPv6 loopback address.

    Args:
      socktype: The socket type (e.g. SOCK_STREAM) passed to
        net_test.CreateSocketPair.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that sock has no peer, i.e. getpeername fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is connected; getpeername raises if it is not."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark.

    A missing attribute is treated as None.
    """
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Asserts that a (diag_msg, attrs) tuple describes the socket s.

    Compares family, source address/port, destination address/port (or checks
    that the socket is unconnected if the reported destination is the
    unspecified address) and the socket mark.

    Args:
      s: The socket to compare against.
      info: A (diag_msg, attrs) tuple as returned by the sock_diag helpers.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs filter instructions and checks that they decode cleanly.

    Args:
      instructions: A list of instruction tuples accepted by
        sock_diag.PackBytecode.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    self.assertNotIn("???", decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread a moment to enter the blocking call before the event.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock via SOCK_DIAG while call is blocked on it.

    Args:
      sock: The socket to use.
      call: A function, the blocking call to make. Takes one parameter, sock.
      expected_errno: The errno call is expected to fail with, or None if it
        is expected to succeed.
    """
    self._EventDuringBlockingCall(
        sock, call, expected_errno, self.sock_diag.CloseSocketFromFd)

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Sockets registered here are closed by tearDown.
    self.socketpairs = {}

  def tearDown(self):
    try:
      for socketpair in list(self.socketpairs.values()):
        for s in socketpair:
          s.close()
    finally:
      # Always run the base class cleanup, even if closing a socket failed.
      super(SockDiagBaseTest, self).tearDown()
169
170
class SockDiagTest(SockDiagBaseTest):
  """Tests dumping, looking up and filtering sockets through SOCK_DIAG."""

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    # Register the pair so that tearDown closes it even if a lookup fails.
    self.socketpairs["mapped"] = socketpair
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works.

    Args:
      socktype: The socket type of the sockets to create (e.g. SOCK_STREAM).
      proto: The IP protocol to dump (e.g. IPPROTO_TCP).
    """
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. Each pair is registered under
    # one port ordering only, so look the dumped socket up in both orderings.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def assertItemsEqual(self, expected, actual):
    """Asserts that both sequences hold the same items, regardless of order."""
    parent = super(SockDiagTest, self)
    # This was renamed in python3 but has the same behaviour. Only the lookup
    # is guarded, so an AttributeError raised while comparing is not masked.
    compare = getattr(parent, "assertItemsEqual", None)
    if compare is None:
      compare = parent.assertCountEqual
    compare(expected, actual)

  def testFindsAllMySocketsTcp(self):
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  def testFindsAllMySocketsUdp(self):
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    """Checks the packed form of a filter program and that filters select."""
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        b"0208500000000000"
        b"050848000000ffff"
        b"071c20000a800000ffffffff00000000000000000000000000000001"
        b"01041c00"
        b"0718200002200000ffffffff7f000001"
        b"0508100000006566"
        b"00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertEqual(expected, binascii.hexlify(bytecode))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                       states=states))

    # The pairs only need to exist while the dumps run; registering them makes
    # tearDown close them explicitly.
    self.socketpairs["v4"] = net_test.CreateSocketPair(
        AF_INET, SOCK_STREAM, "127.0.0.1")
    self.socketpairs["v6"] = net_test.CreateSocketPair(
        AF_INET6, SOCK_STREAM, "::1")

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    self.socketpairs["mapped"] = pair5
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that a dump with an unknown netlink command fails with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets.

    Args:
      inet: The address family of the socket pair to create.
      addr: The address the socket pair is created on.
    """
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    # Register the pair so that tearDown closes it even if an assertion fails.
    self.socketpairs[(inet, addr)] = socketpair
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      self.assertEqual(diag_msg.id.cookie, cookie)

  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
    # by passing the source address as the source address argument.
    # Unfortunately those functions are intended to match local sockets based
    # on received packets, and the argument that ends up being compared with
    # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
    # have this bug.  Upstream has confirmed that this will not be fixed:
    # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      try:
        self.SelectInterface(s, self.RandomNetid(), "mark")
        s.connect((self.GetRemoteSocketAddress(version), 53))

        # Create a fully-specified diag req from our socket, including cookie
        # if we can get it.
        req = self.sock_diag.DiagReqFromSocket(s)
        req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE,
                                     8)

        # As is, this request does not find anything.
        with self.assertRaisesErrno(ENOENT):
          self.sock_diag.GetSockInfo(req)

        # But if we swap src and dst, the kernel finds our socket.
        req.id.sport, req.id.dport = req.id.dport, req.id.sport
        req.id.src, req.id.dst = req.id.dst, req.id.src

        self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
      finally:
        # Close explicitly rather than relying on garbage collection.
        s.close()
402
403
class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Destroys one socket of each pair and checks that both ends close."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for pair in self.socketpairs.values():
      # Destroy one end, picked at random. The resulting RST takes down the
      # other end as well.
      victim = random.choice(pair)
      destroy_by_fd = random.randrange(0, 2) == 1
      if destroy_by_fd:
        self.sock_diag.CloseSocketFromFd(victim)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(victim)

        # A request carrying a bogus cookie must fail and must leave the
        # socket connected.
        good_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(good_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(victim)

        # With the real cookie restored, the destroy goes through.
        req.id.cookie = good_cookie
        self.sock_diag.CloseSocket(req)

      # Neither end of the pair may still be connected.
      self.assertSocketsClosed(pair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
481
482
class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs one operation on a socket and records failures.

  After the thread has finished, `exception` holds the IOError or
  AssertionError raised by the operation, or None if it did not raise one.
  """

  def __init__(self, sock, operation):
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation
    self.exception = None

  def run(self):
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as err:
      # Stash the error so the spawning thread can inspect it after join().
      self.exception = err
497
498
class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """SOCK_DIAG tests that need TCP connections driven by the test harness."""

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(list(self.tuns.keys()))
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)

    # Dump AF_INET6 TCP sockets in SYN_RECV on the port we are listening on.
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    syn_recv_only = 1 << tcp_test.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2(
        (AF_INET6, IPPROTO_TCP, 0, syn_recv_only, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)

    self.assertTrue(children)
    for diag_msg, unused_attrs in children:
      self.assertEqual(tcp_test.TCP_SYN_RECV, diag_msg.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr),
                       diag_msg.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr),
                       diag_msg.id.src)
523
524
class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests that the initial TCP receive window is larger than RWND_SIZE."""

  # Lower bound that the observed window must exceed.
  RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000
  TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"

  def setUp(self):
    """Writes "60" to the tcp_default_init_rwnd sysctl where it exists.

    On 4.19 and above the sysctl is expected to be absent, and the test only
    checks that opening it fails with ENOENT.
    """
    super(TcpRcvWindowTest, self).setUp()
    if LINUX_4_19_OR_ABOVE:
      self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w")
      return

    try:
      f = open(self.TCP_DEFAULT_INIT_RWND, "w")
    except IOError as e:
      # sysctl was namespace-ified on May 25, 2020 in android-4.14-stable [R]
      # just after 4.14.181 by:
      #   https://android-review.googlesource.com/c/kernel/common/+/1312623
      #   ANDROID: namespace'ify tcp_default_init_rwnd implementation
      # But that commit might be missing in Q era kernels even when > 4.14.181
      # when running T vts.
      if net_test.LINUX_VERSION >= (4, 15, 0):
        raise
      if e.errno != ENOENT:
        raise
      # we rely on the network namespace creation code
      # modifying the root netns sysctl before the namespace is even created
      return

    # Only the open() is guarded above: write errors must propagate. The
    # context manager closes the file even if the write fails.
    with f:
      f.write("60")

  def checkInitRwndSize(self, version, netid):
    """Checks tcpi_rcv_ssthresh of a freshly accepted connection.

    Args:
      version: The IP version of the incoming connection (4, 5 or 6).
      netid: The network to receive the connection on.
    """
    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP,
                                               net_test.TCP_INFO, len(TcpInfo)))
    self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh,
                    "Tcp rwnd of netid=%d, version=%d is not enough. "
                    "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE,
                                                tcpInfo.tcpi_rcv_ssthresh))

  def checkSynPacketWindowSize(self, version, netid):
    """Checks the window advertised in the SYN of an outgoing connection.

    Args:
      version: The IP version to connect with (4, 5 or 6).
      netid: The network to send the SYN on.
    """
    s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      desc, expected = packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)
      self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      syn = self.ExpectPacketOn(netid, msg, expected)
      self.assertLess(self.RWND_SIZE, syn.window)
    finally:
      # Close the socket even when one of the checks above fails.
      s.close()

  def testTcpCwndSize(self):
    for version in [4, 5, 6]:
      for netid in self.NETIDS:
        self.checkInitRwndSize(version, netid)
        self.checkSynPacketWindowSize(version, netid)
584
585
586class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
587
588  def setUp(self):
589    super(SockDestroyTcpTest, self).setUp()
590    self.netid = random.choice(list(self.tuns.keys()))
591
592  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
593    """Closes the socket and checks whether a RST is sent or not."""
594    if sock is not None:
595      self.assertIsNone(req, "Must specify sock or req, not both")
596      self.sock_diag.CloseSocketFromFd(sock)
597      self.assertRaisesErrno(EINVAL, sock.accept)
598    else:
599      self.assertIsNone(sock, "Must specify sock or req, not both")
600      self.sock_diag.CloseSocket(req)
601
602    if expect_reset:
603      desc, rst = self.RstPacket()
604      msg = "%s: expecting %s: " % (msg, desc)
605      self.ExpectPacketOn(self.netid, msg, rst)
606    else:
607      msg = "%s: " % msg
608      self.ExpectNoPacketsOn(self.netid, msg)
609
610    if sock is not None and do_close:
611      sock.close()
612
613  def CheckTcpReset(self, state, statename):
614    for version in [4, 5, 6]:
615      msg = "Closing incoming IPv%d %s socket" % (version, statename)
616      self.IncomingConnection(version, state, self.netid)
617      self.CheckRstOnClose(self.s, None, False, msg)
618      if state != tcp_test.TCP_LISTEN:
619        msg = "Closing accepted IPv%d %s socket" % (version, statename)
620        self.CheckRstOnClose(self.accepted, None, True, msg)
621
622  def testTcpResets(self):
623    """Checks that closing sockets in appropriate states sends a RST."""
624    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
625    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
626    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
627
628  def testFinWait1Socket(self):
629    for version in [4, 5, 6]:
630      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
631
632      # Get the cookie so we can find this socket after we close it.
633      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
634      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
635
636      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
637      net_test.EnableFinWait(self.accepted)
638      self.accepted.close()
639      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
640      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
641      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
642      desc, fin = self.FinPacket()
643      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)
644
645      # Destroy the socket and expect no RST.
646      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
647      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
648
649      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
650      # because userspace had already closed it.
651      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
652
653      # ACK the FIN so we don't trip over retransmits in future tests.
654      finversion = 4 if version == 5 else version
655      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
656      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
657      self.ReceivePacketOn(self.netid, finack)
658
659      # See if we can find the resulting FIN_WAIT2 socket.
660      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
661      infos = self.sock_diag.Dump(diag_req, NO_BYTECODE)
662      self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
663                          for diag_msg, attrs in infos),
664                      "Expected to find FIN_WAIT2 socket in %s" % infos)
665
666  def FindChildSockets(self, s):
667    """Finds the SYN_RECV child sockets of a given listening socket."""
668    d = self.sock_diag.FindSockDiagFromFd(self.s)
669    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
670    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
671    req.id.cookie = b"\x00" * 8
672
673    bad_bytecode = self.PackAndCheckBytecode(
674        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
675    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))
676
677    bytecode = self.PackAndCheckBytecode(
678        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
679    children = self.sock_diag.Dump(req, bytecode)
680    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
681            for d, _ in children]
682
683  def CheckChildSocket(self, version, statename, parent_first):
684    state = getattr(tcp_test, statename)
685
686    self.IncomingConnection(version, state, self.netid)
687
688    d = self.sock_diag.FindSockDiagFromFd(self.s)
689    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
690    children = self.FindChildSockets(self.s)
691    self.assertEqual(1, len(children))
692
693    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
694    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state
695
696    for child in children:
697      diag_msg, attrs = self.sock_diag.GetSockInfo(child)
698      self.assertEqual(diag_msg.state, expected_state)
699      self.assertMarkIs(self.netid, attrs)
700
701    def CloseParent(expect_reset):
702      msg = "Closing parent IPv%d %s socket %s child" % (
703          version, statename, "before" if parent_first else "after")
704      self.CheckRstOnClose(self.s, None, expect_reset, msg)
705      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)
706
707    def CheckChildrenClosed():
708      for child in children:
709        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
710
711    def CloseChildren():
712      for child in children:
713        msg = "Closing child IPv%d %s socket %s parent" % (
714            version, statename, "after" if parent_first else "before")
715        self.sock_diag.GetSockInfo(child)
716        self.CheckRstOnClose(None, child, is_established, msg)
717        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
718      CheckChildrenClosed()
719
720    if parent_first:
721      # Closing the parent will close child sockets, which will send a RST,
722      # iff they are already established.
723      CloseParent(is_established)
724      if is_established:
725        CheckChildrenClosed()
726      else:
727        CloseChildren()
728        CheckChildrenClosed()
729      self.s.close()
730    else:
731      CloseChildren()
732      CloseParent(False)
733      self.s.close()
734
735  def testChildSockets(self):
736    for version in [4, 5, 6]:
737      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
738      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
739      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
740      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
741
742  def testAcceptInterrupted(self):
743    """Tests that accept() is interrupted by SOCK_DESTROY."""
744    for version in [4, 5, 6]:
745      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
746      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
747      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
748      self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo")
749      self.assertRaisesErrno(EINVAL, self.s.accept)
750      # TODO: this should really return an error such as ENOTCONN...
751      self.assertEqual(b"", self.s.recv(4096))
752
753  def testReadInterrupted(self):
754    """Tests that read() is interrupted by SOCK_DESTROY."""
755    for version in [4, 5, 6]:
756      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
757      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
758                                   ECONNABORTED)
759      # Writing returns EPIPE, and reading returns EOF.
760      self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
761      self.assertEqual(b"", self.accepted.recv(4096))
762      self.assertEqual(b"", self.accepted.recv(4096))
763
764  def testConnectInterrupted(self):
765    """Tests that connect() is interrupted by SOCK_DESTROY."""
766    for version in [4, 5, 6]:
767      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
768      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
769      self.SelectInterface(s, self.netid, "mark")
770
771      remotesockaddr = self.GetRemoteSocketAddress(version)
772      remoteaddr = self.GetRemoteAddress(version)
773      s.bind(("", 0))
774      _, sport = s.getsockname()[:2]
775      self.CloseDuringBlockingCall(
776          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
777      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
778                              remoteaddr, sport=sport, seq=None)
779      self.ExpectPacketOn(self.netid, desc, syn)
780      msg = "SOCK_DESTROY of socket in connect, expected no RST"
781      self.ExpectNoPacketsOn(self.netid, msg)
782
783
class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.

  The behaviour of poll() in these cases is not what we might expect: if only
  POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
  is (also) specified, it will only return POLLOUT.
  """

  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  # Poll event bits paired with the short names used to render them.
  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def setUp(self):
    """Runs the parent setUp, then picks a random network from self.tuns."""
    super().setUp()
    netids = list(self.tuns.keys())
    self.netid = random.choice(netids)

  def PollResultToString(self, poll_events, ignoremask):
    """Renders poll results as (fd, "FLAG|FLAG...") tuples.

    Args:
      poll_events: An iterable of (fd, event bitmask) pairs.
      ignoremask: Event bits that must be left out of the rendered names.

    Returns:
      A list with one (fd, string) tuple per input pair. The string joins the
      POLL_FLAGS names of the bits that are set and not ignored with "|".
    """
    visible = ~ignoremask
    return [
        (fd, "|".join(name for flag, name in self.POLL_FLAGS
                      if event & flag & visible))
        for fd, event in poll_events
    ]

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """Polls sock for mask and asserts that the expected events come back.

    Args:
      sock: The socket to poll.
      mask: The events to register for.
      expected: The event bitmask poll() is expected to report for sock.
      ignoremask: Event bits excluded from the comparison.
    """
    poller = select.poll()
    poller.register(sock, mask)
    want = [(sock.fileno(), expected)]
    # Don't block forever or we'll hang continuous test runs on failure.
    # A 5-second timeout should be long enough not to be flaky.
    got = poller.poll(5000)
    self.assertEqual(self.PollResultToString(want, ignoremask),
                     self.PollResultToString(got, ignoremask))

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Runs call on sock and delivers a RST on self.netid while it blocks."""
    def ReceiveRst(_):
      return self.ReceiveRstPacketOn(self.netid)

    self._EventDuringBlockingCall(sock, call, expected_errno, ReceiveRst)

  def assertSocketErrors(self, errno):
    """Checks the errors self.accepted reports once it has been torn down.

    Args:
      errno: The error the first recv() on the socket must fail with.
    """
    # The first operation returns the expected errno.
    self.assertRaisesErrno(errno, self.accepted.recv, 4096)

    # Subsequent operations behave as normal.
    self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
    for _ in range(2):
      self.assertEqual(b"", self.accepted.recv(4096))

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    def Poll(sock):
      return self.BlockingPoll(sock, mask, expected, ignoremask)

    for version in (4, 5, 6):
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, Poll, None)
      self.assertSocketErrors(ECONNABORTED)

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    def Poll(sock):
      return self.BlockingPoll(sock, mask, expected, ignoremask)

    for version in (4, 5, 6):
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.RstDuringBlockingCall(self.accepted, Poll, None)
      self.assertSocketErrors(ECONNRESET)

  def testReadPollRst(self):
    self.CheckPollRst(mask=select.POLLIN, expected=self.POLLIN_ERR_HUP,
                      ignoremask=0)

  def testWritePollRst(self):
    self.CheckPollRst(mask=select.POLLOUT, expected=select.POLLOUT,
                      ignoremask=0)

  def testReadWritePollRst(self):
    self.CheckPollRst(mask=self.POLLIN_OUT, expected=select.POLLOUT,
                      ignoremask=0)

  def testReadPollDestroy(self):
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    racy_bits = select.POLLIN | select.POLLHUP
    self.CheckPollDestroy(mask=select.POLLIN, expected=self.POLLIN_ERR_HUP,
                          ignoremask=racy_bits)

  def testWritePollDestroy(self):
    self.CheckPollDestroy(mask=select.POLLOUT, expected=select.POLLOUT,
                          ignoremask=0)

  def testReadWritePollDestroy(self):
    self.CheckPollDestroy(mask=self.POLLIN_OUT, expected=select.POLLOUT,
                          ignoremask=0)
873
874
class SockDestroyUdpTest(SockDiagBaseTest):

  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    # Destroy both ends of every socket pair, checking that each socket is
    # connected beforehand and closed afterwards.
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for s1, s2 in self.socketpairs.values():
      for s in (s1, s2):
        self.assertSocketConnected(s)
        self.sock_diag.CloseSocketFromFd(s)
        self.assertSocketClosed(s)

  def BindToRandomPort(self, s, addr, attempts=20):
    """Binds s to addr on a randomly chosen port.

    Args:
      s: The socket to bind.
      addr: The address to bind to, passed to bind() unchanged.
      attempts: How many random ports to try before giving up.

    Returns:
      The port the socket was bound to.

    Raises:
      ValueError: Every attempted port was already in use.
      error: bind() failed for a reason other than EADDRINUSE.
    """
    for _ in range(attempts):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
      except error as e:
        # Only a port collision is worth retrying; anything else is fatal.
        if e.errno != EADDRINUSE:
          raise
      else:
        return port
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, attempts))

  def testSocketAddressesAfterClose(self):
    # Checks what getsockname() reports after SOCK_DESTROY, depending on how
    # the socket obtained its local address and port.
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])

      # Closing a socket bound to an IP address leaves the address as is.
      # The port was only assigned by connect(), so it is expected to be
      # cleared.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteSocketAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send(b"foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo")
      self.assertEqual(3, s.sendto(b"foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
970
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY fails with EPERM when run as another UID."""

  def CheckPermissions(self, socktype):
    """Checks who is allowed to destroy a socket of the given type.

    Creates an IPv6 socket that is listening (SOCK_STREAM) or connected (any
    other type), checks that destroying it while running as UID 12345 fails
    with EPERM, and then destroys it as the current user.

    Args:
      socktype: The socket type to test, e.g. SOCK_STREAM or SOCK_DGRAM.
    """
    s = socket(AF_INET6, socktype, 0)
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      if socktype == SOCK_STREAM:
        s.listen(1)
      else:
        s.connect((self.GetRemoteAddress(6), 53))

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      self.sock_diag.CloseSocketFromFd(s)
      # A second destroy must raise ValueError, presumably because the socket
      # can no longer be found - confirm against CloseSocketFromFd.
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      # Release the file descriptor even if one of the checks above fails.
      s.close()

  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
996
997
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps the established TCP sockets selected by a mark filter.

    Args:
      mark: The mark value placed in the INET_DIAG_BC_MARK_COND instruction.
      mask: The mask placed in the same instruction.

    Returns:
      Whatever DumpAllInetSockets returns for IPPROTO_TCP sockets in
      TCP_ESTABLISHED state with the packed bytecode applied.
    """
    # NOTE(review): the 1 and 2 are presumably the instruction's yes/no jump
    # values - confirm against sock_diag.PackBytecode.
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts that the source ports in diag_msgs are exactly ports."""
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns True iff assertSockInfoMatchesSocket accepts s and info."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    """Returns a "local -> remote" string describing a connected socket."""
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos describes exactly the given sockets.

    Fails if a socket matches no entry or more than one entry of infos, or if
    infos contains an entry that matches none of the sockets.

    Args:
      infos: The entries of a socket dump.
      sockets: The sockets that are expected to appear in the dump.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember the match so that a second one is reported as an error.
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    # Kept as a list: nothing visible here guarantees the entries are hashable.
    matched_infos = list(matches.values())
    for i in infos:
      if i not in matched_infos:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    # Marks two ends of a loopback TCP connection differently, then checks
    # which of them various (mark, mask) filters return.
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEqual(0, len(infos))

    # Running as UID 12345, the same mark filter must fail with EPERM.
    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

  @staticmethod
  def SetRandomMark(s):
    """Sets a random SO_MARK on s and returns the mark that was set."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts the mark reported for s, both as this user and as UID 12345.

    The current user must see mark in the socket's attributes; when running
    as UID 12345 the mark must be reported as None.
    """
    _, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      _, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    # Checks that the mark is reported for TCP listen, client and accepted
    # sockets, for UDP sockets and, where available, for SCTP sockets.
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets. The accepted socket must start out with the mark
      # of the listening socket.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      # Changing the accepted socket's mark must not affect the listener.
      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      server.close()
      client.close()
      accepted.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
1140
1141
# Hand control to unittest's command-line runner when the file is executed
# directly rather than imported.
if __name__ == "__main__":
  unittest.main()
1144