• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import os
20import random
21import select
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import threading
25import time
26import unittest
27
28import cstruct
29import multinetwork_base
30import net_test
31import packets
32import sock_diag
33import tcp_test
34
# Mostly empty structure definition containing only the fields we currently
# use: 64 bytes of padding followed by tcpi_rcv_ssthresh ("64xI", assuming
# cstruct uses struct-module format strings).
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")

# Number of socket pairs created by SockDiagBaseTest._CreateLotsOfSockets.
NUM_SOCKETS = 30
# Empty bytecode: dump sockets without any inet_diag filter.
NO_BYTECODE = ""
# Kernel version gates used by tests below.
LINUX_4_9_OR_ABOVE = (4, 9, 0) <= net_test.LINUX_VERSION
LINUX_4_19_OR_ABOVE = (4, 19, 0) <= net_test.LINUX_VERSION

# IP protocol number for SCTP, defined locally (presumably because the socket
# module does not always export it).
IPPROTO_SCTP = 132
44
def HaveUdpDiag():
  """Checks if the current kernel has config CONFIG_INET_UDP_DIAG enabled.

  This config is required for devices running 4.9 kernels that ship with P.
  In that case, always assume the config is present and rely on the tests to
  check that it is enabled as required.

  For all other kernel versions, there is no way to tell whether a dump
  succeeded: if the appropriate handler wasn't found, __inet_diag_dump just
  returns an empty result instead of an error. So, create a UDP socket and
  check whether a UDP dump returns any sockets at all. If it returns none,
  UDP diag is assumed to be missing and some tests will be skipped.

  Returns:
    True if the kernel is 4.9 or above, or CONFIG_INET_UDP_DIAG is enabled.
    False otherwise.
  """
  if LINUX_4_9_OR_ABOVE:
    return True
  s = socket(AF_INET6, SOCK_DGRAM, 0)
  try:
    s.bind(("::", 0))
    # Connect the socket to its own bound address, so that the dump below has
    # at least one UDP socket that we know exists.
    s.connect(s.getsockname())
    sd = sock_diag.SockDiag()
    return len(sd.DumpAllInetSockets(IPPROTO_UDP, NO_BYTECODE)) > 0
  finally:
    # Always release the probe socket, even if bind/connect/dump raises.
    s.close()
71
def HaveSctp():
  """Returns whether SCTP sockets can be created on this kernel.

  Kernels older than 4.7 are reported as lacking SCTP without probing.
  Otherwise, an AF_INET SCTP socket is created and immediately closed; any
  IOError while doing so means SCTP is unavailable.
  """
  if net_test.LINUX_VERSION >= (4, 7, 0):
    try:
      probe = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP)
      probe.close()
    except IOError:
      return False
    return True
  return False
81
# Kernel feature probes, evaluated once at import time. HAVE_UDP_DIAG gates
# the UDP tests below via unittest.skipUnless.
HAVE_UDP_DIAG = HaveUdpDiag()
HAVE_SCTP = HaveSctp()
84
85
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS socket pairs on randomly chosen loopback addresses.

    Each pair is created by net_test.CreateSocketPair on one of 127.0.0.1
    (AF_INET), ::1 (AF_INET6) or ::ffff:127.0.0.1 (AF_INET6, v4-mapped).

    Args:
      socktype: The socket type to create, e.g., SOCK_STREAM or SOCK_DGRAM.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socket pairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is connected (getpeername() does not raise)."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Checks that a (diag_msg, attrs) tuple describes socket s.

    Compares the address family, source address and port, the destination
    address and port (or, if the diag message has a wildcard destination,
    that s has no peer), and the socket mark.

    Args:
      s: The socket to compare against.
      info: A (diag_msg, attrs) tuple as returned by sock_diag lookups.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs inet_diag bytecode and checks that it decodes back cleanly.

    Args:
      instructions: A list of instruction tuples accepted by PackBytecode.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    # DecodeBytecode presumably uses "???" for instructions it can't decode.
    self.assertNotIn("???", decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock via SOCK_DESTROY while call(sock) is blocked."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed in tearDown.
    self.socketpairs = {}

  def tearDown(self):
    try:
      for socketpair in list(self.socketpairs.values()):
        for s in socketpair:
          s.close()
    finally:
      # Run the base class teardown even if closing a socket fails.
      super(SockDiagBaseTest, self).tearDown()
199
200
class SockDiagTest(SockDiagBaseTest):
  """Tests for finding and dumping inet sockets with SOCK_DIAG."""

  def _CloseOnCleanup(self, socks):
    """Arranges for each socket in socks to be closed when the test ends.

    Uses addCleanup, so the sockets are closed even if the test fails.
    """
    for sock in socks:
      self.addCleanup(sock.close)

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    self._CloseOnCleanup(socketpair)
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works.

    Creates NUM_SOCKETS socket pairs, checks that a dump contains all of their
    sockets, then checks that every socket can be found both by scanning a
    dump and by a targeted lookup using its cookie.

    Args:
      socktype: The socket type to create, e.g., SOCK_STREAM.
      proto: The protocol to dump, e.g., IPPROTO_TCP.
    """
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. A socket belongs to one of our
    # pairs if its ports match the pair's key in either order.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies? Each pair contributes two sockets.
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def testFindsAllMySocketsTcp(self):
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testFindsAllMySocketsUdp(self):
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    """Checks bytecode packing and that compiled filters select sockets."""
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    # NOTE(review): str.encode("hex") and assertItemsEqual are Python 2 only.
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "",
                                                       states=states))

    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    self._CloseOnCleanup(pair4)
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
    self._CloseOnCleanup(pair6)

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    self._CloseOnCleanup(pair5)
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that a dump with a non-SOCK_DIAG message type fails EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    # Any other message type (here, an arbitrary offset from the valid one)
    # must be rejected.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    self._CloseOnCleanup(socketpair)
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      self.assertEqual(diag_msg.id.cookie, cookie)

  @unittest.skipUnless(LINUX_4_9_OR_ABOVE, "SO_COOKIE not supported")
  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
    # by passing the source address as the source address argument.
    # Unfortunately those functions are intended to match local sockets based
    # on received packets, and the argument that ends up being compared with
    # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
    # have this bug.  Upstream has confirmed that this will not be fixed:
    # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      self.addCleanup(s.close)
      self.SelectInterface(s, self.RandomNetid(), "mark")
      s.connect((self.GetRemoteSocketAddress(version), 53))

      # Create a fully-specified diag req from our socket, including cookie if
      # we can get it.
      req = self.sock_diag.DiagReqFromSocket(s)
      if LINUX_4_9_OR_ABOVE:
        req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      else:
        # INET_DIAG_NOCOOKIE in both 32-bit halves of the 8-byte cookie.
        req.id.cookie = "\xff" * 8

      # As is, this request does not find anything.
      with self.assertRaisesErrno(ENOENT):
        self.sock_diag.GetSockInfo(req)

      # But if we swap src and dst, the kernel finds our socket.
      req.id.sport, req.id.dport = req.id.dport, req.id.sport
      req.id.src, req.id.dst = req.id.dst, req.id.src

      self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
431
432
class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Destroys one socket of each TCP pair and checks both sides close.

    For each pair, a random socket is destroyed either directly by fd or via
    a diag request. In the latter case, a request with a wrong cookie is first
    checked to fail with ENOENT and leave the socket connected.
    """
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
510
511
class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records expected failures.

  An IOError or AssertionError raised by the operation is stored in
  self.exception rather than propagating; if the operation succeeds,
  self.exception stays None. Any other exception propagates out of run().
  """

  def __init__(self, sock, operation):
    super(SocketExceptionThread, self).__init__()
    self.exception = None
    self.sock = sock
    self.operation = operation
    self.daemon = True

  def run(self):
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as caught:
      self.exception = caught
526
527
class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    netid = random.choice(list(self.tuns.keys()))
    # Version 5 is what this test uses for the IPv4-mapped case (per its name).
    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)

    # Dump AF_INET6 TCP sockets in SYN_RECV, with sport set to self.port.
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    syn_recv_only = 1 << tcp_test.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2(
        (AF_INET6, IPPROTO_TCP, 0, syn_recv_only, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)
    self.assertTrue(children)

    want_dst = self.sock_diag.PaddedAddress(self.remotesockaddr)
    want_src = self.sock_diag.PaddedAddress(self.mysockaddr)
    for child, _ in children:
      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
      self.assertEqual(want_dst, child.id.dst)
      self.assertEqual(want_src, child.id.src)
552
553
class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Checks the initial TCP receive window on incoming and outgoing sockets."""

  RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000
  TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"

  def setUp(self):
    """Sets tcp_default_init_rwnd to 60 where that sysctl is expected.

    On 4.19 and above, the sysctl must not exist. Below that, failing to open
    it is tolerated only with ENOENT on kernels older than 4.15; any other
    failure is re-raised.
    """
    super(TcpRcvWindowTest, self).setUp()
    if LINUX_4_19_OR_ABOVE:
      self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w")
      return

    try:
      f = open(self.TCP_DEFAULT_INIT_RWND, "w")
    except IOError as e:
      # sysctl was namespace-ified on May 25, 2020 in android-4.14-stable [R]
      # just after 4.14.181 by:
      #   https://android-review.googlesource.com/c/kernel/common/+/1312623
      #   ANDROID: namespace'ify tcp_default_init_rwnd implementation
      # But that commit might be missing in Q era kernels even when > 4.14.181
      # when running T vts.
      if net_test.LINUX_VERSION >= (4, 15, 0):
        raise
      if e.errno != ENOENT:
        raise
      # we rely on the network namespace creation code
      # modifying the root netns sysctl before the namespace is even created
      return

    # Close the file explicitly so the value is written to the sysctl before
    # the test runs, rather than whenever the file object is collected.
    with f:
      f.write("60")

  def checkInitRwndSize(self, version, netid):
    """Checks that an accepted connection's tcpi_rcv_ssthresh > RWND_SIZE."""
    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    tcp_info = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP,
                                                net_test.TCP_INFO,
                                                len(TcpInfo)))
    self.assertLess(self.RWND_SIZE, tcp_info.tcpi_rcv_ssthresh,
                    "Tcp rwnd of netid=%d, version=%d is not enough. "
                    "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE,
                                                tcp_info.tcpi_rcv_ssthresh))

  def checkSynPacketWindowSize(self, version, netid):
    """Checks that an outgoing SYN advertises a window larger than RWND_SIZE."""
    s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
    try:
      myaddr = self.MyAddress(version, netid)
      dstaddr = self.GetRemoteAddress(version)
      dstsockaddr = self.GetRemoteSocketAddress(version)
      desc, expected = packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)
      self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      syn = self.ExpectPacketOn(netid, msg, expected)
      self.assertLess(self.RWND_SIZE, syn.window)
    finally:
      # Release the socket even if one of the checks above fails.
      s.close()

  def testTcpCwndSize(self):
    """Checks receive window sizes (despite the name) on all netids."""
    for version in [4, 5, 6]:
      for netid in self.NETIDS:
        self.checkInitRwndSize(version, netid)
        self.checkSynPacketWindowSize(version, netid)
612
613
614class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
615
616  def setUp(self):
617    super(SockDestroyTcpTest, self).setUp()
618    self.netid = random.choice(list(self.tuns.keys()))
619
620  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
621    """Closes the socket and checks whether a RST is sent or not."""
622    if sock is not None:
623      self.assertIsNone(req, "Must specify sock or req, not both")
624      self.sock_diag.CloseSocketFromFd(sock)
625      self.assertRaisesErrno(EINVAL, sock.accept)
626    else:
627      self.assertIsNone(sock, "Must specify sock or req, not both")
628      self.sock_diag.CloseSocket(req)
629
630    if expect_reset:
631      desc, rst = self.RstPacket()
632      msg = "%s: expecting %s: " % (msg, desc)
633      self.ExpectPacketOn(self.netid, msg, rst)
634    else:
635      msg = "%s: " % msg
636      self.ExpectNoPacketsOn(self.netid, msg)
637
638    if sock is not None and do_close:
639      sock.close()
640
641  def CheckTcpReset(self, state, statename):
642    for version in [4, 5, 6]:
643      msg = "Closing incoming IPv%d %s socket" % (version, statename)
644      self.IncomingConnection(version, state, self.netid)
645      self.CheckRstOnClose(self.s, None, False, msg)
646      if state != tcp_test.TCP_LISTEN:
647        msg = "Closing accepted IPv%d %s socket" % (version, statename)
648        self.CheckRstOnClose(self.accepted, None, True, msg)
649
650  def testTcpResets(self):
651    """Checks that closing sockets in appropriate states sends a RST."""
652    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
653    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
654    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
655
656  def testFinWait1Socket(self):
657    for version in [4, 5, 6]:
658      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
659
660      # Get the cookie so we can find this socket after we close it.
661      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
662      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
663
664      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
665      net_test.EnableFinWait(self.accepted)
666      self.accepted.close()
667      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
668      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
669      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
670      desc, fin = self.FinPacket()
671      self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)
672
673      # Destroy the socket and expect no RST.
674      self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
675      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
676
677      # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
678      # because userspace had already closed it.
679      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
680
681      # ACK the FIN so we don't trip over retransmits in future tests.
682      finversion = 4 if version == 5 else version
683      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
684      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
685      self.ReceivePacketOn(self.netid, finack)
686
687      # See if we can find the resulting FIN_WAIT2 socket. This does not appear
688      # to work on 3.10.
689      if net_test.LINUX_VERSION >= (3, 18):
690        diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
691        infos = self.sock_diag.Dump(diag_req, "")
692        self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
693                            for diag_msg, attrs in infos),
694                        "Expected to find FIN_WAIT2 socket in %s" % infos)
695
696  def FindChildSockets(self, s):
697    """Finds the SYN_RECV child sockets of a given listening socket."""
698    d = self.sock_diag.FindSockDiagFromFd(self.s)
699    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
700    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
701    req.id.cookie = "\x00" * 8
702
703    bad_bytecode = self.PackAndCheckBytecode(
704        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
705    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))
706
707    bytecode = self.PackAndCheckBytecode(
708        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
709    children = self.sock_diag.Dump(req, bytecode)
710    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
711            for d, _ in children]
712
  def CheckChildSocket(self, version, statename, parent_first):
    """Checks SOCK_DESTROY of a listening socket and its child socket.

    Creates an incoming connection in the given state and checks that exactly
    one child socket is found. It then closes the listener (self.s) and the
    child in the order selected by parent_first. The closes go through
    CheckRstOnClose, which checks whether a RST is sent. Afterwards, lookups
    of the closed sockets must fail with ENOENT.

    Args:
      version: IP version passed to IncomingConnection (4, 5 or 6).
      statename: name of a tcp_test state constant, e.g. "TCP_SYN_RECV" or
        "TCP_NOT_YET_ACCEPTED".
      parent_first: if True, close the listening socket before the child;
        otherwise close the child first.
    """
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEqual(1, len(children))

    # A not-yet-accepted connection is expected to be reported as ESTABLISHED.
    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
    # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
    # Before 4.4, we can see those sockets in dumps, but we can't fetch
    # or close them.
    can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

    # Children that can be looked up must have the expected state and mark;
    # otherwise a lookup must fail with ENOENT.
    for child in children:
      if can_close_children:
        diag_msg, attrs = self.sock_diag.GetSockInfo(child)
        self.assertEqual(diag_msg.state, expected_state)
        self.assertMarkIs(self.netid, attrs)
      else:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    # Closes the listener, checks whether a RST was sent, and checks that the
    # listener can no longer be looked up.
    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    # Checks that none of the children can be looked up any more.
    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    # Closes each child individually. Only established children are expected
    # to send a RST.
    def CloseChildren():
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        # Look the child up before closing it (presumably raises if it is
        # already gone).
        self.sock_diag.GetSockInfo(child)
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
      CheckChildrenClosed()

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      elif can_close_children:
        CloseChildren()
        CheckChildrenClosed()
      self.s.close()
    else:
      # The children are closed first, so closing the parent afterwards is
      # not expected to send a RST.
      if can_close_children:
        CloseChildren()
      CloseParent(False)
      self.s.close()
774
775  def testChildSockets(self):
776    for version in [4, 5, 6]:
777      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
778      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
779      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
780      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
781
782  def testAcceptInterrupted(self):
783    """Tests that accept() is interrupted by SOCK_DESTROY."""
784    for version in [4, 5, 6]:
785      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
786      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
787      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
788      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
789      self.assertRaisesErrno(EINVAL, self.s.accept)
790      # TODO: this should really return an error such as ENOTCONN...
791      self.assertEqual("", self.s.recv(4096))
792
793  def testReadInterrupted(self):
794    """Tests that read() is interrupted by SOCK_DESTROY."""
795    for version in [4, 5, 6]:
796      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
797      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
798                                   ECONNABORTED)
799      # Writing returns EPIPE, and reading returns EOF.
800      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
801      self.assertEqual("", self.accepted.recv(4096))
802      self.assertEqual("", self.accepted.recv(4096))
803
804  def testConnectInterrupted(self):
805    """Tests that connect() is interrupted by SOCK_DESTROY."""
806    for version in [4, 5, 6]:
807      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
808      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
809      self.SelectInterface(s, self.netid, "mark")
810
811      remotesockaddr = self.GetRemoteSocketAddress(version)
812      remoteaddr = self.GetRemoteAddress(version)
813      s.bind(("", 0))
814      _, sport = s.getsockname()[:2]
815      self.CloseDuringBlockingCall(
816          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
817      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
818                              remoteaddr, sport=sport, seq=None)
819      self.ExpectPacketOn(self.netid, desc, syn)
820      msg = "SOCK_DESTROY of socket in connect, expected no RST"
821      self.ExpectNoPacketsOn(self.netid, msg)
822
823
class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Checks that SOCK_DESTROY affects poll() in the same way as a TCP RST.

  Note that poll() does not behave as one might expect in these cases: when
  only POLLIN is requested it returns POLLIN|POLLERR|POLLHUP, but when POLLOUT
  is requested (alone or together with POLLIN) it returns only POLLOUT.
  """

  # Event masks used by the tests below.
  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  # (flag, name) pairs used to render poll() results readably.
  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def setUp(self):
    super(PollOnCloseTest, self).setUp()
    # Each test runs on a randomly chosen network.
    self.netid = random.choice(list(self.tuns.keys()))

  def PollResultToString(self, poll_events, ignoremask):
    """Renders (fd, events) pairs as (fd, "IN|ERR|...") pairs.

    Flags that are set in ignoremask are left out of the rendered string.
    """
    shown = ~ignoremask
    return [(fd, "|".join([name for flag, name in self.POLL_FLAGS
                           if events & flag & shown]))
            for fd, events in poll_events]

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """Polls sock for mask and asserts that only sock reports expected.

    Bits in ignoremask are disregarded on both sides of the comparison.
    """
    poller = select.poll()
    poller.register(sock, mask)
    wanted = [(sock.fileno(), expected)]
    # Use a timeout so that a failure cannot hang continuous test runs; five
    # seconds is long enough to avoid flakiness.
    actual = poller.poll(5000)
    self.assertEqual(self.PollResultToString(wanted, ignoremask),
                     self.PollResultToString(actual, ignoremask))

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Runs call on sock, interrupted by receiving a RST on self.netid."""
    def ReceiveRst(unused_sock):
      return self.ReceiveRstPacketOn(self.netid)
    self._EventDuringBlockingCall(sock, call, expected_errno, ReceiveRst)

  def assertSocketErrors(self, errno):
    """Checks how self.accepted behaves after being reset or destroyed."""
    # Only the first operation reports the pending error.
    self.assertRaisesErrno(errno, self.accepted.recv, 4096)

    # From then on writes fail with EPIPE and reads return EOF.
    self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
    for _ in range(2):
      self.assertEqual("", self.accepted.recv(4096))

  def _CheckPollInterrupted(self, interrupt, mask, expected, ignoremask,
                            expected_errno):
    """Interrupts a blocking poll() on each IP version using interrupt."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      interrupt(
          self.accepted,
          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
          None)
      self.assertSocketErrors(expected_errno)

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    self._CheckPollInterrupted(self.CloseDuringBlockingCall, mask, expected,
                               ignoremask, ECONNABORTED)

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    self._CheckPollInterrupted(self.RstDuringBlockingCall, mask, expected,
                               ignoremask, ECONNRESET)

  def testReadPollRst(self):
    # Before 3d4762639d ("tcp: remove poll() flakes when receiving RST"),
    # poll() could return either POLLERR or POLLIN|POLLERR|POLLHUP because of
    # a race in the kernel. The race is only visible on physical hardware, not
    # on the VM.
    if net_test.LINUX_VERSION >= (4, 14, 0):
      ignoremask = 0
    else:
      ignoremask = select.POLLIN | select.POLLHUP
    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)

  def testWritePollRst(self):
    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollRst(self):
    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)

  def testReadPollDestroy(self):
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    racy_flags = select.POLLIN | select.POLLHUP
    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, racy_flags)

  def testWritePollDestroy(self):
    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollDestroy(self):
    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
921
922
@unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
class SockDestroyUdpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    """Destroys both ends of many connected UDP socket pairs."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for s1, s2 in self.socketpairs.values():
      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr, attempts=20):
    """Binds s to a randomly chosen port in [1024, 65535) on addr.

    Args:
      s: the socket to bind.
      addr: the local address to bind to ("" for the wildcard address).
      attempts: how many random ports to try before giving up.

    Returns:
      The port that s was bound to.

    Raises:
      ValueError: if every port tried was already in use.
      socket.error: if bind() fails for a reason other than EADDRINUSE.
    """
    for _ in range(attempts):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        # Only a port collision is worth retrying with another port.
        if e.errno != EADDRINUSE:
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, attempts))

  def testSocketAddressesAfterClose(self):
    """Checks which local address/port survive SOCK_DESTROY of a UDP socket."""
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])

      # Closing a socket bound to an IP address leaves the address as is,
      # but the automatically chosen port is cleared.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteSocketAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send("foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
      self.assertEqual(3, s.sendto("foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
1019
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY fails with EPERM for an unprivileged uid."""

  def CheckPermissions(self, socktype):
    """Checks that uid 12345 cannot destroy a socket but the test's uid can.

    Args:
      socktype: SOCK_STREAM (a listening TCP socket is checked) or
        SOCK_DGRAM (a connected UDP socket is checked).
    """
    s = socket(AF_INET6, socktype, 0)
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      if socktype == SOCK_STREAM:
        s.listen(1)
        expectedstate = tcp_test.TCP_LISTEN
      else:
        s.connect((self.GetRemoteAddress(6), 53))
        expectedstate = tcp_test.TCP_ESTABLISHED

      # Make sure sock_diag sees the socket in the expected state before
      # trying to destroy it.
      diag_msg, _ = self.sock_diag.FindSockInfoFromFd(s)
      self.assertEqual(expectedstate, diag_msg.state)

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      self.sock_diag.CloseSocketFromFd(s)
      # A second attempt fails, presumably because the destroyed socket can
      # no longer be found.
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
1046
1047
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps ESTABLISHED TCP sockets whose mark, masked with mask, is mark."""
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts that the source ports in diag_msgs are ports, in any order."""
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns True if info describes socket s, False otherwise."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    """Returns a "local -> remote" description of a connected socket."""
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos contains exactly one entry per socket in sockets.

    Fails if a socket matches no entry or more than one entry, or if infos
    contains an entry that matches none of the sockets.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    matched_infos = list(matches.values())
    for i in infos:
      if i not in matched_infos:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    """Checks that mark filters select sockets by their masked mark."""
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEqual(0, len(infos))

    # An unprivileged uid may not filter on marks.
    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

    s1.close()
    s2.close()

  @staticmethod
  def SetRandomMark(s):
    """Sets a random mark on s and returns it."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Checks that sock_diag reports mark for s, but not to uid 12345."""
    diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    """Checks that sock_diag reports socket marks in dump attributes.

    Covers TCP listening, client and accepted sockets. UDP is also covered
    when HAVE_UDP_DIAG is set, and SCTP when HAVE_SCTP is set.
    """
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets. An accepted socket starts with the listener's
      # mark, and changing it leaves the listener's mark unchanged.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      accepted.close()
      server.close()
      client.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      if HAVE_UDP_DIAG:
        s = socket(family, SOCK_DGRAM, 0)
        mark = self.SetRandomMark(s)
        s.connect(("", 53))
        self.assertSocketMarkIs(s, mark)
        s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
1191
1192
if __name__ == "__main__":
  # When run as a script, run the test cases defined in this module.
  unittest.main()
1195