• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import binascii
20import os
21import random
22import select
23from socket import *  # pylint: disable=wildcard-import
24import struct
25import threading
26import time
27import unittest
28
29import cstruct
30import multinetwork_base
31import net_test
32import packets
33import sock_diag
34import tcp_test
35
# Mostly empty structure definition containing only the fields we currently
# use. Format "64xI": skip 64 pad bytes of the TCP_INFO getsockopt result,
# then read one unsigned int, exposed as tcpi_rcv_ssthresh.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")

NUM_SOCKETS = 30  # Socket pairs created per _CreateLotsOfSockets call.
NO_BYTECODE = bytes()  # Empty filter program: dumps are not filtered.

IPPROTO_SCTP = 132  # Protocol number used to probe for SCTP support.
43
def HaveSctp():
  """Reports whether an IPv4 SCTP stream socket can be created.

  Returns:
    True if socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP) succeeds (the probe
    socket is closed immediately), False if it fails with IOError.
  """
  try:
    with socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP):
      pass
  except IOError:
    return False
  return True

HAVE_SCTP = HaveSctp()
53
54
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS loopback socket pairs of type socktype.

    Each pair uses a randomly chosen loopback address: 127.0.0.1 (AF_INET),
    ::1 (AF_INET6) or ::ffff:127.0.0.1 (AF_INET6, IPv4-mapped).

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the pair's first and second socket.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock still has a peer."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Asserts that a (diag_msg, attrs) tuple describes socket s.

    Compares the address family, local address and port, and mark. If the
    diag message has a non-wildcard destination, the peer address and port
    are compared too; otherwise s must have no peer.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs inet_diag filter bytecode and checks that it decodes cleanly.

    Args:
      instructions: list of instruction tuples for sock_diag.PackBytecode.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    # Every instruction must decode, and none may decode as "???" (which is
    # what DecodeBytecode yields for bytecode it cannot understand).
    self.assertEqual(len(instructions), len(decoded))
    self.assertNotIn("???", decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive(),
                     "Blocking call did not return after the event")
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock via SOCK_DESTROY while call(sock) is blocked on it."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed by tearDown.
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()
168
169
class SockDiagTest(SockDiagBaseTest):
  """Tests socket dumping, lookup, bytecode filtering and socket cookies."""

  def _CloseOnCleanup(self, sockets):
    """Registers every socket in sockets to be closed when the test ends.

    Tests still close their sockets explicitly on success. This only ensures
    that a failed assertion does not leak open sockets into later tests:
    testCrossFamilyBytecode, for instance, requires that no other established
    TCP sockets exist. Closing an already-closed socket is harmless (tearDown
    already does so for the pairs CheckFindsAllMySockets closes).

    Args:
      sockets: a sequence of sockets, e.g. a socket pair.

    Returns:
      sockets, unchanged, so that calls can wrap socket creation.
    """
    for sock in sockets:
      self.addCleanup(sock.close)
    return sockets

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = self._CloseOnCleanup(
        net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::ffff:127.0.0.1"))
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

    for sock in socketpair:
      sock.close()

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works.

    Creates many socket pairs, checks that a full dump contains all of them,
    and that each socket can be found both by scanning a dump and by a
    fully-specified request that includes its cookie.
    """
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. Each socket is reported with
    # its own port as sport, so the second socket of a pair shows up with the
    # pair's (sport, dport) key swapped.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if (addr, sport, dport) in self.socketpairs:
        cookies[(addr, sport, dport)] = diag_msg.id.cookie
      elif (addr, dport, sport) in self.socketpairs:
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

    for socketpair in socketpairs:
      for sock in socketpair:
        sock.close()

  def assertItemsEqual(self, expected, actual):
    """Asserts that expected and actual contain the same items in any order.

    Python 2's unittest calls this assertItemsEqual; Python 3 renamed it to
    assertCountEqual with the same behaviour. The method is looked up rather
    than called inside a try block, so an AttributeError raised while
    comparing the items is not mistaken for a missing method.
    """
    parent = super(SockDiagTest, self)
    compare = getattr(parent, "assertItemsEqual", None)
    if compare is None:
      compare = parent.assertCountEqual
    compare(expected, actual)

  def testFindsAllMySocketsTcp(self):
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  def testFindsAllMySocketsUdp(self):
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    """Checks bytecode packing against known bytes and its filtering effect."""
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        b"0208500000000000"
        b"050848000000ffff"
        b"071c20000a800000ffffffff00000000000000000000000000000001"
        b"01041c00"
        b"0718200002200000ffffffff7f000001"
        b"0508100000006566"
        b"00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertEqual(expected, binascii.hexlify(bytecode))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                       states=states))

    unused_pair4 = self._CloseOnCleanup(
        net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1"))
    unused_pair6 = self._CloseOnCleanup(
        net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1"))

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = self._CloseOnCleanup(
        net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::ffff:127.0.0.1"))
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    for d in diag_msgs:
      self.assertIn(d, v4socks)
      self.assertIn(d, v6socks)

    for sock in unused_pair4:
      sock.close()

    for sock in unused_pair6:
      sock.close()

    for sock in pair5:
      sock.close()

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that an unknown sock_diag message type fails with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets.

    Args:
      inet: address family of the socket pair to create.
      addr: loopback address to use for the pair.
    """
    socketpair = self._CloseOnCleanup(
        net_test.CreateSocketPair(inet, SOCK_STREAM, addr))
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      self.assertEqual(diag_msg.id.cookie, cookie)

    for sock in socketpair:
      sock.close()

  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly calls
    # __udp[46]_lib_lookup passing the socket's source address as the
    # lookup's source-address argument. Those functions are intended to match
    # local sockets against received packets, so the argument that ends up
    # being compared with e.g. sk_daddr is actually saddr, not daddr.
    # udp_diag_destroy does not have this bug. Upstream has confirmed that
    # this will not be fixed:
    # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      self.addCleanup(s.close)
      self.SelectInterface(s, self.RandomNetid(), "mark")
      s.connect((self.GetRemoteSocketAddress(version), 53))

      # Create a fully-specified diag req from our socket, including cookie if
      # we can get it.
      req = self.sock_diag.DiagReqFromSocket(s)
      req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)

      # As is, this request does not find anything.
      with self.assertRaisesErrno(ENOENT):
        self.sock_diag.GetSockInfo(req)

      # But if we swap src and dst, the kernel finds our socket.
      req.id.sport, req.id.dport = req.id.dport, req.id.sport
      req.id.src, req.id.dst = req.id.dst, req.id.src

      self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))

      s.close()
422
423
class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Destroys one socket of each TCP pair and checks both ends close.

    For each pair a coin flip decides how: either by fd, or by diag request,
    in which case a request with a wrong cookie must first fail with ENOENT
    and leave the socket connected.
    """
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed. Regenerate in the (astronomically unlikely) case that
        # the random cookie happens to equal the real one.
        real_cookie = diag_msg.id.cookie
        bad_cookie = os.urandom(len(real_cookie))
        while bad_cookie == real_cookie:
          bad_cookie = os.urandom(len(real_cookie))
        diag_msg.id.cookie = bad_cookie
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
501
502
class SocketExceptionThread(threading.Thread):
  """Runs operation(sock) in a daemon thread and records what it raises.

  Attributes:
    sock: the argument passed to operation.
    operation: a callable taking one argument, sock.
    exception: the exception operation raised, or None if it returned.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    # Daemon, so a call that never unblocks cannot keep the process alive.
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    try:
      self.operation(self.sock)
    except Exception as e:  # pylint: disable=broad-except
      # Record every failure, not only IOError/AssertionError. Any other
      # exception would otherwise end the thread with exception still None,
      # and a caller expecting success would wrongly see the call pass.
      self.exception = e
517
518
class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """SOCK_DIAG tests that need the tcp_test connection helpers."""

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
      android-3.4:
        457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(list(self.tuns))
    # Version 5: IPv4 traffic on an AF_INET6 socket.
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)

    # Dump AF_INET6 SYN_RECV sockets whose source port is our listening port.
    wanted_id = self.sock_diag._EmptyInetDiagSockId()
    wanted_id.sport = self.port
    syn_recv_only = 1 << tcp_test.TCP_SYN_RECV
    request = sock_diag.InetDiagReqV2(
        (AF_INET6, IPPROTO_TCP, 0, syn_recv_only, wanted_id))
    found = self.sock_diag.Dump(request, NO_BYTECODE)

    self.assertTrue(found)
    for msg, _ in found:
      self.assertEqual(tcp_test.TCP_SYN_RECV, msg.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr),
                       msg.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr),
                       msg.id.src)
543
544
class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Checks that TCP starts out with a receive window above RWND_SIZE."""

  RWND_SIZE = 64000
  TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"

  def setUp(self):
    super(TcpRcvWindowTest, self).setUp()
    # The kernel under test must not provide this sysctl: opening it for
    # writing has to fail with ENOENT.
    self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w")

  def checkInitRwndSize(self, version, netid):
    """Checks that an accepted connection's tcpi_rcv_ssthresh > RWND_SIZE."""
    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    tcp_info = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP,
                                                net_test.TCP_INFO,
                                                len(TcpInfo)))
    self.assertLess(self.RWND_SIZE, tcp_info.tcpi_rcv_ssthresh,
                    "Tcp rwnd of netid=%d, version=%d is not enough. "
                    "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE,
                                                tcp_info.tcpi_rcv_ssthresh))
    self.CloseSockets()

  def checkSynPacketWindowSize(self, version, netid):
    """Checks that an outgoing SYN advertises a window above RWND_SIZE."""
    s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
    # Close the socket even when an assertion fails, so it does not leak.
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      desc, expected = packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)
      self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      syn = self.ExpectPacketOn(netid, msg, expected)
      self.assertLess(self.RWND_SIZE, syn.window)
    finally:
      s.close()

  def testTcpCwndSize(self):
    """Checks initial window sizes on every IP version and test network.

    Despite the test's name, both checks concern the receive window (rwnd).
    """
    for version in [4, 5, 6]:
      for netid in self.NETIDS:
        self.checkInitRwndSize(version, netid)
        self.checkSynPacketWindowSize(version, netid)
583
584
585class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
586
587  def setUp(self):
588    super(SockDestroyTcpTest, self).setUp()
589    self.netid = random.choice(list(self.tuns.keys()))
590
591  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
592    """Closes the socket and checks whether a RST is sent or not."""
593    if sock is not None:
594      self.assertIsNone(req, "Must specify sock or req, not both")
595      self.sock_diag.CloseSocketFromFd(sock)
596      self.assertRaisesErrno(EINVAL, sock.accept)
597    else:
598      self.assertIsNone(sock, "Must specify sock or req, not both")
599      self.sock_diag.CloseSocket(req)
600
601    if expect_reset:
602      desc, rst = self.RstPacket()
603      msg = "%s: expecting %s: " % (msg, desc)
604      self.ExpectPacketOn(self.netid, msg, rst)
605    else:
606      msg = "%s: " % msg
607      self.ExpectNoPacketsOn(self.netid, msg)
608
609    if sock is not None and do_close:
610      sock.close()
611
612  def CheckTcpReset(self, state, statename):
613    for version in [4, 5, 6]:
614      msg = "Closing incoming IPv%d %s socket" % (version, statename)
615      self.IncomingConnection(version, state, self.netid)
616      self.CheckRstOnClose(self.s, None, False, msg)
617      if state != tcp_test.TCP_LISTEN:
618        msg = "Closing accepted IPv%d %s socket" % (version, statename)
619        self.CheckRstOnClose(self.accepted, None, True, msg)
620        self.CloseSockets()
621
622  def testTcpResets(self):
623    """Checks that closing sockets in appropriate states sends a RST."""
624    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
625    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
626    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
627
  def testFinWait1Socket(self):
    """Checks that SOCK_DESTROY leaves an already-closed FIN_WAIT1 socket alone.

    For each IP version: close an accepted socket so that it enters FIN_WAIT1
    and sends a FIN, check that destroying it sends no RST and leaves it in
    FIN_WAIT1, then ACK the FIN and check that a FIN_WAIT2 socket is found.
    """
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)

      # Get the cookie so we can find this socket after we close it.
      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)

      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
      net_test.EnableFinWait(self.accepted)
      self.accepted.close()
      # NOTE(review): presumably deleted so that CloseSockets() below does not
      # reuse the closed socket object; confirm against tcp_test.
      del self.accepted
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
      desc, fin = self.FinPacket()
      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)

      # Destroy the socket and expect no RST.
      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)

      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
      # because userspace had already closed it.
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)

      # ACK the FIN so we don't trip over retransmits in future tests.
      # Version 5 is IPv4 traffic on an AF_INET6 socket, so the ACK is IPv4.
      finversion = 4 if version == 5 else version
      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
      # The result is unused; this lookup presumably just re-checks that the
      # socket still exists before the ACK is delivered.
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.ReceivePacketOn(self.netid, finack)

      # See if we can find the resulting FIN_WAIT2 socket.
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
      infos = self.sock_diag.Dump(diag_req, NO_BYTECODE)
      self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
                          for diag_msg, attrs in infos),
                      "Expected to find FIN_WAIT2 socket in %s" % infos)

      self.CloseSockets()
668
669  def FindChildSockets(self, s):
670    """Finds the SYN_RECV child sockets of a given listening socket."""
671    d = self.sock_diag.FindSockDiagFromFd(self.s)
672    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
673    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
674    req.id.cookie = b"\x00" * 8
675
676    bad_bytecode = self.PackAndCheckBytecode(
677        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
678    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))
679
680    bytecode = self.PackAndCheckBytecode(
681        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
682    children = self.sock_diag.Dump(req, bytecode)
683    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
684            for d, _ in children]
685
  def CheckChildSocket(self, version, statename, parent_first):
    """Checks SOCK_DESTROY on a listening socket and its child socket.

    Args:
      version: IP version passed to IncomingConnection (4, 5 or 6).
      statename: name of the tcp_test state constant for the child, e.g.
        "TCP_SYN_RECV" or "TCP_NOT_YET_ACCEPTED".
      parent_first: if True, destroy the listening parent before the child;
        otherwise destroy the child first.
    """
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEqual(1, len(children))

    # A TCP_NOT_YET_ACCEPTED child is reported by sock_diag as ESTABLISHED.
    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    for child in children:
      diag_msg, attrs = self.sock_diag.GetSockInfo(child)
      self.assertEqual(diag_msg.state, expected_state)
      self.assertMarkIs(self.netid, attrs)

    # Destroys the listener and checks that sock_diag can no longer find it.
    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    # Checks that sock_diag can no longer find any of the children.
    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    # Destroys each child by diag request; only established children are
    # expected to send a RST.
    def CloseChildren():
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        # The child must still exist before we destroy it.
        self.sock_diag.GetSockInfo(child)
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
      CheckChildrenClosed()

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      else:
        # Non-established children are expected to still exist here
        # (CloseChildren looks each one up first), so destroy them explicitly.
        CloseChildren()
        CheckChildrenClosed()
    else:
      CloseChildren()
      CloseParent(False)

    self.CloseSockets()
737
738  def testChildSockets(self):
739    for version in [4, 5, 6]:
740      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
741      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
742      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
743      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
744
745  def testAcceptInterrupted(self):
746    """Tests that accept() is interrupted by SOCK_DESTROY."""
747    for version in [4, 5, 6]:
748      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
749      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
750      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
751      self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo")
752      self.assertRaisesErrno(EINVAL, self.s.accept)
753      # TODO: this should really return an error such as ENOTCONN...
754      self.assertEqual(b"", self.s.recv(4096))
755      self.CloseSockets()
756
757  def testReadInterrupted(self):
758    """Tests that read() is interrupted by SOCK_DESTROY."""
759    for version in [4, 5, 6]:
760      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
761      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
762                                   ECONNABORTED)
763      # Writing returns EPIPE, and reading returns EOF.
764      self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
765      self.assertEqual(b"", self.accepted.recv(4096))
766      self.assertEqual(b"", self.accepted.recv(4096))
767      self.CloseSockets()
768
769  def testConnectInterrupted(self):
770    """Tests that connect() is interrupted by SOCK_DESTROY."""
771    for version in [4, 5, 6]:
772      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
773      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
774      self.SelectInterface(s, self.netid, "mark")
775
776      remotesockaddr = self.GetRemoteSocketAddress(version)
777      remoteaddr = self.GetRemoteAddress(version)
778      s.bind(("", 0))
779      _, sport = s.getsockname()[:2]
780      self.CloseDuringBlockingCall(
781          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
782      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
783                              remoteaddr, sport=sport, seq=None)
784      self.ExpectPacketOn(self.netid, desc, syn)
785      msg = "SOCK_DESTROY of socket in connect, expected no RST"
786      self.ExpectNoPacketsOn(self.netid, msg)
787      s.close()
788
789
class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.

  poll() does not behave quite as one might expect here. When only POLLIN is
  requested it returns POLLIN|POLLERR|POLLHUP, but when POLLOUT is requested
  (alone or together with POLLIN) it returns only POLLOUT.
  """

  # Composite event masks used both as poll() inputs and as expected results.
  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  # (event bit, label) pairs; PollResultToString emits labels in this order.
  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def setUp(self):
    super(PollOnCloseTest, self).setUp()
    # All tests in this class use one netid picked at random from self.tuns.
    netids = list(self.tuns.keys())
    self.netid = random.choice(netids)

  def PollResultToString(self, poll_events, ignoremask):
    """Renders poll() results as (fd, "LABEL|LABEL") pairs for comparison.

    Args:
      poll_events: iterable of (fd, eventmask) pairs, as returned by poll().
      ignoremask: event bits to drop before rendering.

    Returns:
      A list with one (fd, string) tuple per input pair. The string joins, with
      "|", the POLL_FLAGS labels of the bits that remain set.
    """
    keep = ~ignoremask
    return [(fd, "|".join(label for bit, label in self.POLL_FLAGS
                          if event & bit & keep))
            for fd, event in poll_events]

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """Polls sock for mask and asserts the reported events equal expected.

    Bits in ignoremask are disregarded on both sides of the comparison.
    """
    poller = select.poll()
    poller.register(sock, mask)
    want = self.PollResultToString([(sock.fileno(), expected)], ignoremask)
    # Bound the wait so a failure cannot hang continuous test runs; 5 seconds
    # should be long enough not to be flaky.
    got = self.PollResultToString(poller.poll(5000), ignoremask)
    self.assertEqual(want, got)

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Runs call(sock) and interrupts it by receiving a TCP RST on self.netid."""
    def InjectRst(_):
      return self.ReceiveRstPacketOn(self.netid)
    self._EventDuringBlockingCall(sock, call, expected_errno, InjectRst)

  def assertSocketErrors(self, errno):
    """Checks how self.accepted behaves after it has been interrupted.

    The first recv() must fail with errno. After that, send() fails with
    EPIPE and recv() keeps returning an empty string.
    """
    self.assertRaisesErrno(errno, self.accepted.recv, 4096)

    self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
    for _ in range(2):
      self.assertEqual(b"", self.accepted.recv(4096))

  def _CheckPollInterrupted(self, interrupter, mask, expected, ignoremask,
                            expected_errno):
    """Shared loop: interrupts a blocking poll() on self.accepted per version.

    interrupter is one of CloseDuringBlockingCall / RstDuringBlockingCall.
    """
    for ipver in (4, 5, 6):
      self.IncomingConnection(ipver, tcp_test.TCP_ESTABLISHED, self.netid)
      interrupter(
          self.accepted,
          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
          None)
      self.assertSocketErrors(expected_errno)
      self.CloseSockets()

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    self._CheckPollInterrupted(self.CloseDuringBlockingCall, mask, expected,
                               ignoremask, ECONNABORTED)

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    self._CheckPollInterrupted(self.RstDuringBlockingCall, mask, expected,
                               ignoremask, ECONNRESET)

  def testReadPollRst(self):
    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, 0)

  def testWritePollRst(self):
    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollRst(self):
    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)

  def testReadPollDestroy(self):
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    ignoremask = select.POLLIN | select.POLLHUP
    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)

  def testWritePollDestroy(self):
    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollDestroy(self):
    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
881
882
class SockDestroyUdpTest(SockDiagBaseTest):

  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    """Destroys each end of many connected UDP socket pairs in turn.

    The second socket is checked only after the first is destroyed, so this
    also verifies that destroying one end leaves its peer connected.
    """
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for s1, s2 in self.socketpairs.values():
      for s in (s1, s2):
        self.assertSocketConnected(s)
        self.sock_diag.CloseSocketFromFd(s)
        self.assertSocketClosed(s)

  def BindToRandomPort(self, s, addr):
    """Binds s to addr on a random port in [1024, 65535) and returns the port.

    A new port is tried on EADDRINUSE, up to ATTEMPTS times in total.

    Raises:
      ValueError: if every attempt hit EADDRINUSE.
      socket.error: for any bind() failure other than EADDRINUSE.
    """
    ATTEMPTS = 20
    for _ in range(ATTEMPTS):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, ATTEMPTS))

  def testSocketAddressesAfterClose(self):
    """Checks getsockname() after SOCK_DESTROY for each way of binding."""
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to an IP address leaves the address as is.
      # Its port came from connect(), so it is cleared.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])
      s.close()

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteSocketAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send(b"foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo")
      self.assertEqual(3, s.sendto(b"foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      s.close()
984
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Checks that uid 12345 gets EPERM from SOCK_DESTROY but the test uid doesn't."""

  def CheckPermissions(self, socktype):
    """Runs the SOCK_DESTROY permission check on one IPv6 socket of socktype.

    The socket is destroyed successfully by the test's own uid. A second
    destroy attempt on it must then raise ValueError.
    """
    s = socket(AF_INET6, socktype, 0)
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      # Move the socket out of its initial state (listening TCP / connected
      # UDP), presumably so SOCK_DIAG can look it up.
      if socktype == SOCK_STREAM:
        s.listen(1)
      else:
        s.connect((self.GetRemoteAddress(6), 53))

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      self.sock_diag.CloseSocketFromFd(s)
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      # Close even when an assertion fails so the fd is not leaked.
      s.close()

  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
1012
1013
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps established TCP sockets selected by a mark/mask bytecode filter.

    The filter is a single INET_DIAG_BC_MARK_COND instruction. Judging by
    testMarkBytecode, it matches sockets whose (mark & mask) == mark.
    """
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts the dumped source ports equal ports, ignoring order."""
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns whether assertSockInfoMatchesSocket(s, info) passes."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    """Returns "local -> peer" for use in failure messages."""
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts infos contains exactly one entry per socket in sockets.

    Fails if a socket matches no entry or more than one entry, or if infos
    contains an entry that matches none of the sockets.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember this hit so a second matching entry is reported above.
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    # Build the list of matched entries once instead of once per entry.
    matched_infos = list(matches.values())
    for i in infos:
      if i not in matched_infos:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEqual(0, len(infos))

    # Mark filters are refused for an unprivileged uid.
    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

    s1.close()
    s2.close()

  @staticmethod
  def SetRandomMark(s):
    """Sets a random SO_MARK on s and returns it."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts SOCK_DIAG reports mark for s, but reports no mark to uid 12345."""
    _, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      _, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets. An accepted socket starts with the listener's mark.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      accepted.close()
      server.close()
      client.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
1160
1161
# When run as a script, let unittest discover and run this module's tests.
if __name__ == "__main__":
  unittest.main()
1164