1#!/usr/bin/python3 2# 3# Copyright 2015 The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import 18from errno import * # pylint: disable=wildcard-import 19import binascii 20import os 21import random 22import select 23from socket import * # pylint: disable=wildcard-import 24import struct 25import threading 26import time 27import unittest 28 29import cstruct 30import multinetwork_base 31import net_test 32import packets 33import sock_diag 34import tcp_test 35 36# Mostly empty structure definition containing only the fields we currently use. 37TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh") 38 39NUM_SOCKETS = 30 40NO_BYTECODE = b"" 41LINUX_4_19_OR_ABOVE = net_test.LINUX_VERSION >= (4, 19, 0) 42 43IPPROTO_SCTP = 132 44 45def HaveSctp(): 46 try: 47 s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP) 48 s.close() 49 return True 50 except IOError: 51 return False 52 53HAVE_SCTP = HaveSctp() 54 55 56class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): 57 """Basic tests for SOCK_DIAG functionality. 58 59 Relevant kernel commits: 60 android-3.4: 61 ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields 62 99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 63 64 android-3.10: 65 3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields 66 f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 67 68 android-3.18: 69 e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 70 71 android-4.4: 72 525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 73 """ 74 @staticmethod 75 def _CreateLotsOfSockets(socktype): 76 # Dict mapping (addr, sport, dport) tuples to socketpairs. 77 socketpairs = {} 78 for _ in range(NUM_SOCKETS): 79 family, addr = random.choice([ 80 (AF_INET, "127.0.0.1"), 81 (AF_INET6, "::1"), 82 (AF_INET6, "::ffff:127.0.0.1")]) 83 socketpair = net_test.CreateSocketPair(family, socktype, addr) 84 sport, dport = (socketpair[0].getsockname()[1], 85 socketpair[1].getsockname()[1]) 86 socketpairs[(addr, sport, dport)] = socketpair 87 return socketpairs 88 89 def assertSocketClosed(self, sock): 90 self.assertRaisesErrno(ENOTCONN, sock.getpeername) 91 92 def assertSocketConnected(self, sock): 93 sock.getpeername() # No errors? Socket is alive and connected. 94 95 def assertSocketsClosed(self, socketpair): 96 for sock in socketpair: 97 self.assertSocketClosed(sock) 98 99 def assertMarkIs(self, mark, attrs): 100 self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None)) 101 102 def assertSockInfoMatchesSocket(self, s, info): 103 diag_msg, attrs = info 104 family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) 105 self.assertEqual(diag_msg.family, family) 106 107 src, sport = s.getsockname()[0:2] 108 self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src)) 109 self.assertEqual(diag_msg.id.sport, sport) 110 111 if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]: 112 dst, dport = s.getpeername()[0:2] 113 self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst)) 114 self.assertEqual(diag_msg.id.dport, dport) 115 else: 116 self.assertRaisesErrno(ENOTCONN, s.getpeername) 117 118 mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK) 119 self.assertMarkIs(mark, attrs) 120 121 def PackAndCheckBytecode(self, instructions): 122 bytecode = self.sock_diag.PackBytecode(instructions) 123 decoded = self.sock_diag.DecodeBytecode(bytecode) 124 self.assertEqual(len(instructions), len(decoded)) 125 self.assertFalse("???" in decoded) 126 return bytecode 127 128 def _EventDuringBlockingCall(self, sock, call, expected_errno, event): 129 """Simulates an external event during a blocking call on sock. 130 131 Args: 132 sock: The socket to use. 133 call: A function, the call to make. Takes one parameter, sock. 134 expected_errno: The value that call is expected to fail with, or None if 135 call is expected to succeed. 136 event: A function, the event that will happen during the blocking call. 137 Takes one parameter, sock. 138 """ 139 thread = SocketExceptionThread(sock, call) 140 thread.start() 141 time.sleep(0.1) 142 event(sock) 143 thread.join(1) 144 self.assertFalse(thread.is_alive()) 145 if expected_errno is not None: 146 self.assertIsNotNone(thread.exception) 147 self.assertTrue(isinstance(thread.exception, IOError), 148 "Expected IOError, got %s" % thread.exception) 149 self.assertEqual(expected_errno, thread.exception.errno) 150 else: 151 self.assertIsNone(thread.exception) 152 self.assertSocketClosed(sock) 153 154 def CloseDuringBlockingCall(self, sock, call, expected_errno): 155 self._EventDuringBlockingCall( 156 sock, call, expected_errno, 157 lambda sock: self.sock_diag.CloseSocketFromFd(sock)) 158 159 def setUp(self): 160 super(SockDiagBaseTest, self).setUp() 161 self.sock_diag = sock_diag.SockDiag() 162 self.socketpairs = {} 163 164 def tearDown(self): 165 for socketpair in list(self.socketpairs.values()): 166 for s in socketpair: 167 s.close() 168 super(SockDiagBaseTest, self).tearDown() 169 170 171class SockDiagTest(SockDiagBaseTest): 172 173 def testFindsMappedSockets(self): 174 """Tests that inet_diag_find_one_icsk can find mapped sockets.""" 175 socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, 176 "::ffff:127.0.0.1") 177 for sock in socketpair: 178 diag_msg = self.sock_diag.FindSockDiagFromFd(sock) 179 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 180 self.sock_diag.GetSockInfo(diag_req) 181 # No errors? Good. 182 183 def CheckFindsAllMySockets(self, socktype, proto): 184 """Tests that basic socket dumping works.""" 185 self.socketpairs = self._CreateLotsOfSockets(socktype) 186 sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE) 187 self.assertGreaterEqual(len(sockets), NUM_SOCKETS) 188 189 # Find the cookies for all of our sockets. 190 cookies = {} 191 for diag_msg, unused_attrs in sockets: 192 addr = self.sock_diag.GetSourceAddress(diag_msg) 193 sport = diag_msg.id.sport 194 dport = diag_msg.id.dport 195 if (addr, sport, dport) in self.socketpairs: 196 cookies[(addr, sport, dport)] = diag_msg.id.cookie 197 elif (addr, dport, sport) in self.socketpairs: 198 cookies[(addr, sport, dport)] = diag_msg.id.cookie 199 200 # Did we find all the cookies? 201 self.assertEqual(2 * NUM_SOCKETS, len(cookies)) 202 203 socketpairs = list(self.socketpairs.values()) 204 random.shuffle(socketpairs) 205 for socketpair in socketpairs: 206 for sock in socketpair: 207 # Check that we can find a diag_msg by scanning a dump. 208 self.assertSockInfoMatchesSocket( 209 sock, 210 self.sock_diag.FindSockInfoFromFd(sock)) 211 cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie 212 213 # Check that we can find a diag_msg once we know the cookie. 214 req = self.sock_diag.DiagReqFromSocket(sock) 215 req.id.cookie = cookie 216 if proto == IPPROTO_UDP: 217 # Kernel bug: for UDP sockets, the order of arguments must be swapped. 218 # See testDemonstrateUdpGetSockIdBug. 219 req.id.sport, req.id.dport = req.id.dport, req.id.sport 220 req.id.src, req.id.dst = req.id.dst, req.id.src 221 info = self.sock_diag.GetSockInfo(req) 222 self.assertSockInfoMatchesSocket(sock, info) 223 224 def assertItemsEqual(self, expected, actual): 225 try: 226 super(SockDiagTest, self).assertItemsEqual(expected, actual) 227 except AttributeError: 228 # This was renamed in python3 but has the same behaviour. 229 super(SockDiagTest, self).assertCountEqual(expected, actual) 230 231 def testFindsAllMySocketsTcp(self): 232 self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP) 233 234 def testFindsAllMySocketsUdp(self): 235 self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP) 236 237 def testBytecodeCompilation(self): 238 # pylint: disable=bad-whitespace 239 instructions = [ 240 (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0 241 (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8 242 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16 243 (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44 244 (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48 245 (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64 246 (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72 247 # 76 acc 248 # 80 rej 249 ] 250 # pylint: enable=bad-whitespace 251 bytecode = self.PackAndCheckBytecode(instructions) 252 expected = ( 253 b"0208500000000000" 254 b"050848000000ffff" 255 b"071c20000a800000ffffffff00000000000000000000000000000001" 256 b"01041c00" 257 b"0718200002200000ffffffff7f000001" 258 b"0508100000006566" 259 b"00040400" 260 ) 261 states = 1 << tcp_test.TCP_ESTABLISHED 262 self.assertEqual(expected, binascii.hexlify(bytecode)) 263 self.assertEqual(76, len(bytecode)) 264 self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) 265 filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode, 266 states=states) 267 allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, 268 states=states) 269 self.assertItemsEqual(allsockets, filteredsockets) 270 271 # Pick a few sockets in hash table order, and check that the bytecode we 272 # compiled selects them properly. 273 for socketpair in list(self.socketpairs.values())[:20]: 274 for s in socketpair: 275 diag_msg = self.sock_diag.FindSockDiagFromFd(s) 276 instructions = [ 277 (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport), 278 (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport), 279 (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport), 280 (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), 281 ] 282 bytecode = self.PackAndCheckBytecode(instructions) 283 self.assertEqual(32, len(bytecode)) 284 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) 285 self.assertEqual(1, len(sockets)) 286 287 # TODO: why doesn't comparing the cstructs work? 288 self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack()) 289 290 def testCrossFamilyBytecode(self): 291 """Checks for a cross-family bug in inet_diag_hostcond matching. 292 293 Relevant kernel commits: 294 android-3.4: 295 f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run() 296 """ 297 # TODO: this is only here because the test fails if there are any open 298 # sockets other than the ones it creates itself. Make the bytecode more 299 # specific and remove it. 300 states = 1 << tcp_test.TCP_ESTABLISHED 301 self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, 302 states=states)) 303 304 unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1") 305 unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1") 306 307 bytecode4 = self.PackAndCheckBytecode([ 308 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))]) 309 bytecode6 = self.PackAndCheckBytecode([ 310 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))]) 311 312 # IPv4/v6 filters must never match IPv6/IPv4 sockets... 313 v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4, 314 states=states) 315 self.assertTrue(v4socks) 316 self.assertTrue(all(d.family == AF_INET for d, _ in v4socks)) 317 318 v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6, 319 states=states) 320 self.assertTrue(v6socks) 321 self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks)) 322 323 # Except for mapped addresses, which match both IPv4 and IPv6. 324 pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, 325 "::ffff:127.0.0.1") 326 diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5] 327 v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, 328 bytecode4, 329 states=states)] 330 v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, 331 bytecode6, 332 states=states)] 333 self.assertTrue(all(d in v4socks for d in diag_msgs)) 334 self.assertTrue(all(d in v6socks for d in diag_msgs)) 335 336 def testPortComparisonValidation(self): 337 """Checks for a bug in validating port comparison bytecode. 338 339 Relevant kernel commits: 340 android-3.4: 341 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads 342 """ 343 bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8)) 344 self.assertEqual("???", 345 self.sock_diag.DecodeBytecode(bytecode)) 346 self.assertRaisesErrno( 347 EINVAL, 348 self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack()) 349 350 def testNonSockDiagCommand(self): 351 def DiagDump(code): 352 sock_id = self.sock_diag._EmptyInetDiagSockId() 353 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, 354 sock_id)) 355 self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg) 356 357 op = sock_diag.SOCK_DIAG_BY_FAMILY 358 DiagDump(op) # No errors? Good. 359 self.assertRaisesErrno(EINVAL, DiagDump, op + 17) 360 361 def CheckSocketCookie(self, inet, addr): 362 """Tests that getsockopt SO_COOKIE can get cookie for all sockets.""" 363 socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr) 364 for sock in socketpair: 365 diag_msg = self.sock_diag.FindSockDiagFromFd(sock) 366 cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) 367 self.assertEqual(diag_msg.id.cookie, cookie) 368 369 def testGetsockoptcookie(self): 370 self.CheckSocketCookie(AF_INET, "127.0.0.1") 371 self.CheckSocketCookie(AF_INET6, "::1") 372 373 def testDemonstrateUdpGetSockIdBug(self): 374 # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup 375 # by passing the source address as the source address argument. 376 # Unfortunately those functions are intended to match local sockets based 377 # on received packets, and the argument that ends up being compared with 378 # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not 379 # have this bug. Upstream has confirmed that this will not be fixed: 380 # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html 381 """Documents a bug: getting UDP sockets requires swapping src and dst.""" 382 for version in [4, 5, 6]: 383 family = net_test.GetAddressFamily(version) 384 s = socket(family, SOCK_DGRAM, 0) 385 self.SelectInterface(s, self.RandomNetid(), "mark") 386 s.connect((self.GetRemoteSocketAddress(version), 53)) 387 388 # Create a fully-specified diag req from our socket, including cookie if 389 # we can get it. 390 req = self.sock_diag.DiagReqFromSocket(s) 391 req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) 392 393 # As is, this request does not find anything. 394 with self.assertRaisesErrno(ENOENT): 395 self.sock_diag.GetSockInfo(req) 396 397 # But if we swap src and dst, the kernel finds our socket. 398 req.id.sport, req.id.dport = req.id.dport, req.id.sport 399 req.id.src, req.id.dst = req.id.dst, req.id.src 400 401 self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req)) 402 403 404class SockDestroyTest(SockDiagBaseTest): 405 """Tests that SOCK_DESTROY works correctly. 406 407 Relevant kernel commits: 408 net-next: 409 b613f56 net: diag: split inet_diag_dump_one_icsk into two 410 64be0ae net: diag: Add the ability to destroy a socket. 411 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets. 412 c1e64e2 net: diag: Support destroying TCP sockets. 413 2010b93 net: tcp: deal with listen sockets properly in tcp_abort. 414 415 android-3.4: 416 d48ec88 net: diag: split inet_diag_dump_one_icsk into two 417 2438189 net: diag: Add the ability to destroy a socket. 418 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets. 419 44047b2 net: diag: Support destroying TCP sockets. 420 200dae7 net: tcp: deal with listen sockets properly in tcp_abort. 421 422 android-3.10: 423 9eaff90 net: diag: split inet_diag_dump_one_icsk into two 424 d60326c net: diag: Add the ability to destroy a socket. 425 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets. 426 529dfc6 net: diag: Support destroying TCP sockets. 427 9c712fe net: tcp: deal with listen sockets properly in tcp_abort. 428 429 android-3.18: 430 100263d net: diag: split inet_diag_dump_one_icsk into two 431 194c5f3 net: diag: Add the ability to destroy a socket. 432 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets. 433 b80585a net: diag: Support destroying TCP sockets. 434 476c6ce net: tcp: deal with listen sockets properly in tcp_abort. 435 436 android-4.1: 437 56eebf8 net: diag: split inet_diag_dump_one_icsk into two 438 fb486c9 net: diag: Add the ability to destroy a socket. 439 0c02b7e net: diag: Support SOCK_DESTROY for inet sockets. 440 67c71d8 net: diag: Support destroying TCP sockets. 441 a76e0ec net: tcp: deal with listen sockets properly in tcp_abort. 442 e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 443 444 android-4.4: 445 76c83a9 net: diag: split inet_diag_dump_one_icsk into two 446 f7cf791 net: diag: Add the ability to destroy a socket. 447 1c42248 net: diag: Support SOCK_DESTROY for inet sockets. 448 c9e8440d net: diag: Support destroying TCP sockets. 449 3d9502c tcp: diag: add support for request sockets to tcp_abort() 450 001cf75 net: tcp: deal with listen sockets properly in tcp_abort. 451 """ 452 453 def testClosesSockets(self): 454 self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) 455 for _, socketpair in self.socketpairs.items(): 456 # Close one of the sockets. 457 # This will send a RST that will close the other side as well. 458 s = random.choice(socketpair) 459 if random.randrange(0, 2) == 1: 460 self.sock_diag.CloseSocketFromFd(s) 461 else: 462 diag_msg = self.sock_diag.FindSockDiagFromFd(s) 463 464 # Get the cookie wrong and ensure that we get an error and the socket 465 # is not closed. 466 real_cookie = diag_msg.id.cookie 467 diag_msg.id.cookie = os.urandom(len(real_cookie)) 468 req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 469 self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) 470 self.assertSocketConnected(s) 471 472 # Now close it with the correct cookie. 473 req.id.cookie = real_cookie 474 self.sock_diag.CloseSocket(req) 475 476 # Check that both sockets in the pair are closed. 477 self.assertSocketsClosed(socketpair) 478 479 # TODO: 480 # Test that killing unix sockets returns EOPNOTSUPP. 481 482 483class SocketExceptionThread(threading.Thread): 484 485 def __init__(self, sock, operation): 486 self.exception = None 487 super(SocketExceptionThread, self).__init__() 488 self.daemon = True 489 self.sock = sock 490 self.operation = operation 491 492 def run(self): 493 try: 494 self.operation(self.sock) 495 except (IOError, AssertionError) as e: 496 self.exception = e 497 498 499class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 500 501 def testIpv4MappedSynRecvSocket(self): 502 """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets. 503 504 Relevant kernel commits: 505 android-3.4: 506 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state 507 """ 508 netid = random.choice(list(self.tuns.keys())) 509 self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) 510 sock_id = self.sock_diag._EmptyInetDiagSockId() 511 sock_id.sport = self.port 512 states = 1 << tcp_test.TCP_SYN_RECV 513 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) 514 children = self.sock_diag.Dump(req, NO_BYTECODE) 515 516 self.assertTrue(children) 517 for child, unused_args in children: 518 self.assertEqual(tcp_test.TCP_SYN_RECV, child.state) 519 self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr), 520 child.id.dst) 521 self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr), 522 child.id.src) 523 524 525class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 526 527 RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000 528 TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd" 529 530 def setUp(self): 531 super(TcpRcvWindowTest, self).setUp() 532 if LINUX_4_19_OR_ABOVE: 533 self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w") 534 return 535 536 try: 537 f = open(self.TCP_DEFAULT_INIT_RWND, "w") 538 except IOError as e: 539 # sysctl was namespace-ified on May 25, 2020 in android-4.14-stable [R] 540 # just after 4.14.181 by: 541 # https://android-review.googlesource.com/c/kernel/common/+/1312623 542 # ANDROID: namespace'ify tcp_default_init_rwnd implementation 543 # But that commit might be missing in Q era kernels even when > 4.14.181 544 # when running T vts. 545 if net_test.LINUX_VERSION >= (4, 15, 0): 546 raise 547 if e.errno != ENOENT: 548 raise 549 # we rely on the network namespace creation code 550 # modifying the root netns sysctl before the namespace is even created 551 return 552 553 f.write("60") 554 f.close() 555 556 def checkInitRwndSize(self, version, netid): 557 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid) 558 tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP, 559 net_test.TCP_INFO, len(TcpInfo))) 560 self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh, 561 "Tcp rwnd of netid=%d, version=%d is not enough. " 562 "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE, 563 tcpInfo.tcpi_rcv_ssthresh)) 564 565 def checkSynPacketWindowSize(self, version, netid): 566 s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark") 567 myaddr = self.MyAddress(version, netid) 568 dstaddr = self.GetRemoteAddress(version) 569 dstsockaddr = self.GetRemoteSocketAddress(version) 570 desc, expected = packets.SYN(53, version, myaddr, dstaddr, 571 sport=None, seq=None) 572 self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53)) 573 msg = "IPv%s TCP connect: expected %s on %s" % ( 574 version, desc, self.GetInterfaceName(netid)) 575 syn = self.ExpectPacketOn(netid, msg, expected) 576 self.assertLess(self.RWND_SIZE, syn.window) 577 s.close() 578 579 def testTcpCwndSize(self): 580 for version in [4, 5, 6]: 581 for netid in self.NETIDS: 582 self.checkInitRwndSize(version, netid) 583 self.checkSynPacketWindowSize(version, netid) 584 585 586class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 587 588 def setUp(self): 589 super(SockDestroyTcpTest, self).setUp() 590 self.netid = random.choice(list(self.tuns.keys())) 591 592 def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): 593 """Closes the socket and checks whether a RST is sent or not.""" 594 if sock is not None: 595 self.assertIsNone(req, "Must specify sock or req, not both") 596 self.sock_diag.CloseSocketFromFd(sock) 597 self.assertRaisesErrno(EINVAL, sock.accept) 598 else: 599 self.assertIsNone(sock, "Must specify sock or req, not both") 600 self.sock_diag.CloseSocket(req) 601 602 if expect_reset: 603 desc, rst = self.RstPacket() 604 msg = "%s: expecting %s: " % (msg, desc) 605 self.ExpectPacketOn(self.netid, msg, rst) 606 else: 607 msg = "%s: " % msg 608 self.ExpectNoPacketsOn(self.netid, msg) 609 610 if sock is not None and do_close: 611 sock.close() 612 613 def CheckTcpReset(self, state, statename): 614 for version in [4, 5, 6]: 615 msg = "Closing incoming IPv%d %s socket" % (version, statename) 616 self.IncomingConnection(version, state, self.netid) 617 self.CheckRstOnClose(self.s, None, False, msg) 618 if state != tcp_test.TCP_LISTEN: 619 msg = "Closing accepted IPv%d %s socket" % (version, statename) 620 self.CheckRstOnClose(self.accepted, None, True, msg) 621 622 def testTcpResets(self): 623 """Checks that closing sockets in appropriate states sends a RST.""" 624 self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN") 625 self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED") 626 self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") 627 628 def testFinWait1Socket(self): 629 for version in [4, 5, 6]: 630 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 631 632 # Get the cookie so we can find this socket after we close it. 633 diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted) 634 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 635 636 # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN. 637 net_test.EnableFinWait(self.accepted) 638 self.accepted.close() 639 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1 640 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 641 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 642 desc, fin = self.FinPacket() 643 self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin) 644 645 # Destroy the socket and expect no RST. 646 self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket") 647 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 648 649 # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing 650 # because userspace had already closed it. 651 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 652 653 # ACK the FIN so we don't trip over retransmits in future tests. 654 finversion = 4 if version == 5 else version 655 desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin) 656 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 657 self.ReceivePacketOn(self.netid, finack) 658 659 # See if we can find the resulting FIN_WAIT2 socket. 660 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2 661 infos = self.sock_diag.Dump(diag_req, NO_BYTECODE) 662 self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2 663 for diag_msg, attrs in infos), 664 "Expected to find FIN_WAIT2 socket in %s" % infos) 665 666 def FindChildSockets(self, s): 667 """Finds the SYN_RECV child sockets of a given listening socket.""" 668 d = self.sock_diag.FindSockDiagFromFd(self.s) 669 req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 670 req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED 671 req.id.cookie = b"\x00" * 8 672 673 bad_bytecode = self.PackAndCheckBytecode( 674 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))]) 675 self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode)) 676 677 bytecode = self.PackAndCheckBytecode( 678 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))]) 679 children = self.sock_diag.Dump(req, bytecode) 680 return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 681 for d, _ in children] 682 683 def CheckChildSocket(self, version, statename, parent_first): 684 state = getattr(tcp_test, statename) 685 686 self.IncomingConnection(version, state, self.netid) 687 688 d = self.sock_diag.FindSockDiagFromFd(self.s) 689 parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 690 children = self.FindChildSockets(self.s) 691 self.assertEqual(1, len(children)) 692 693 is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) 694 expected_state = tcp_test.TCP_ESTABLISHED if is_established else state 695 696 for child in children: 697 diag_msg, attrs = self.sock_diag.GetSockInfo(child) 698 self.assertEqual(diag_msg.state, expected_state) 699 self.assertMarkIs(self.netid, attrs) 700 701 def CloseParent(expect_reset): 702 msg = "Closing parent IPv%d %s socket %s child" % ( 703 version, statename, "before" if parent_first else "after") 704 self.CheckRstOnClose(self.s, None, expect_reset, msg) 705 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent) 706 707 def CheckChildrenClosed(): 708 for child in children: 709 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 710 711 def CloseChildren(): 712 for child in children: 713 msg = "Closing child IPv%d %s socket %s parent" % ( 714 version, statename, "after" if parent_first else "before") 715 self.sock_diag.GetSockInfo(child) 716 self.CheckRstOnClose(None, child, is_established, msg) 717 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 718 CheckChildrenClosed() 719 720 if parent_first: 721 # Closing the parent will close child sockets, which will send a RST, 722 # iff they are already established. 723 CloseParent(is_established) 724 if is_established: 725 CheckChildrenClosed() 726 else: 727 CloseChildren() 728 CheckChildrenClosed() 729 self.s.close() 730 else: 731 CloseChildren() 732 CloseParent(False) 733 self.s.close() 734 735 def testChildSockets(self): 736 for version in [4, 5, 6]: 737 self.CheckChildSocket(version, "TCP_SYN_RECV", False) 738 self.CheckChildSocket(version, "TCP_SYN_RECV", True) 739 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False) 740 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True) 741 742 def testAcceptInterrupted(self): 743 """Tests that accept() is interrupted by SOCK_DESTROY.""" 744 for version in [4, 5, 6]: 745 self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid) 746 self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096) 747 self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) 748 self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo") 749 self.assertRaisesErrno(EINVAL, self.s.accept) 750 # TODO: this should really return an error such as ENOTCONN... 751 self.assertEqual(b"", self.s.recv(4096)) 752 753 def testReadInterrupted(self): 754 """Tests that read() is interrupted by SOCK_DESTROY.""" 755 for version in [4, 5, 6]: 756 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 757 self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), 758 ECONNABORTED) 759 # Writing returns EPIPE, and reading returns EOF. 760 self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo") 761 self.assertEqual(b"", self.accepted.recv(4096)) 762 self.assertEqual(b"", self.accepted.recv(4096)) 763 764 def testConnectInterrupted(self): 765 """Tests that connect() is interrupted by SOCK_DESTROY.""" 766 for version in [4, 5, 6]: 767 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 768 s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) 769 self.SelectInterface(s, self.netid, "mark") 770 771 remotesockaddr = self.GetRemoteSocketAddress(version) 772 remoteaddr = self.GetRemoteAddress(version) 773 s.bind(("", 0)) 774 _, sport = s.getsockname()[:2] 775 self.CloseDuringBlockingCall( 776 s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED) 777 desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), 778 remoteaddr, sport=sport, seq=None) 779 self.ExpectPacketOn(self.netid, desc, syn) 780 msg = "SOCK_DESTROY of socket in connect, expected no RST" 781 self.ExpectNoPacketsOn(self.netid, msg) 782 783 784class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 785 """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs. 786 787 The behaviour of poll() in these cases is not what we might expect: if only 788 POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT 789 is (also) specified, it will only return POLLOUT. 790 """ 791 792 POLLIN_OUT = select.POLLIN | select.POLLOUT 793 POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP 794 795 def setUp(self): 796 super(PollOnCloseTest, self).setUp() 797 self.netid = random.choice(list(self.tuns.keys())) 798 799 POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"), 800 (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")] 801 802 def PollResultToString(self, poll_events, ignoremask): 803 out = [] 804 for fd, event in poll_events: 805 flags = [name for (flag, name) in self.POLL_FLAGS 806 if event & flag & ~ignoremask != 0] 807 out.append((fd, "|".join(flags))) 808 return out 809 810 def BlockingPoll(self, sock, mask, expected, ignoremask): 811 p = select.poll() 812 p.register(sock, mask) 813 expected_fds = [(sock.fileno(), expected)] 814 # Don't block forever or we'll hang continuous test runs on failure. 815 # A 5-second timeout should be long enough not to be flaky. 816 actual_fds = p.poll(5000) 817 self.assertEqual(self.PollResultToString(expected_fds, ignoremask), 818 self.PollResultToString(actual_fds, ignoremask)) 819 820 def RstDuringBlockingCall(self, sock, call, expected_errno): 821 self._EventDuringBlockingCall( 822 sock, call, expected_errno, 823 lambda _: self.ReceiveRstPacketOn(self.netid)) 824 825 def assertSocketErrors(self, errno): 826 # The first operation returns the expected errno. 827 self.assertRaisesErrno(errno, self.accepted.recv, 4096) 828 829 # Subsequent operations behave as normal. 830 self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo") 831 self.assertEqual(b"", self.accepted.recv(4096)) 832 self.assertEqual(b"", self.accepted.recv(4096)) 833 834 def CheckPollDestroy(self, mask, expected, ignoremask): 835 """Interrupts a poll() with SOCK_DESTROY.""" 836 for version in [4, 5, 6]: 837 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 838 self.CloseDuringBlockingCall( 839 self.accepted, 840 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 841 None) 842 self.assertSocketErrors(ECONNABORTED) 843 844 def CheckPollRst(self, mask, expected, ignoremask): 845 """Interrupts a poll() by receiving a TCP RST.""" 846 for version in [4, 5, 6]: 847 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 848 self.RstDuringBlockingCall( 849 self.accepted, 850 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 851 None) 852 self.assertSocketErrors(ECONNRESET) 853 854 def testReadPollRst(self): 855 self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, 0) 856 857 def testWritePollRst(self): 858 self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0) 859 860 def testReadWritePollRst(self): 861 self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0) 862 863 def testReadPollDestroy(self): 864 # tcp_abort has the same race that tcp_reset has, but it's not fixed yet. 865 ignoremask = select.POLLIN | select.POLLHUP 866 self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) 867 868 def testWritePollDestroy(self): 869 self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0) 870 871 def testReadWritePollDestroy(self): 872 self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0) 873 874 875class SockDestroyUdpTest(SockDiagBaseTest): 876 877 """Tests SOCK_DESTROY on UDP sockets. 878 879 Relevant kernel commits: 880 upstream net-next: 881 5d77dca net: diag: support SOCK_DESTROY for UDP sockets 882 f95bf34 net: diag: make udp_diag_destroy work for mapped addresses. 883 """ 884 885 def testClosesUdpSockets(self): 886 self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM) 887 for _, socketpair in self.socketpairs.items(): 888 s1, s2 = socketpair 889 890 self.assertSocketConnected(s1) 891 self.sock_diag.CloseSocketFromFd(s1) 892 self.assertSocketClosed(s1) 893 894 self.assertSocketConnected(s2) 895 self.sock_diag.CloseSocketFromFd(s2) 896 self.assertSocketClosed(s2) 897 898 def BindToRandomPort(self, s, addr): 899 ATTEMPTS = 20 900 for i in range(20): 901 port = random.randrange(1024, 65535) 902 try: 903 s.bind((addr, port)) 904 return port 905 except error as e: 906 if e.errno != EADDRINUSE: 907 raise e 908 raise ValueError("Could not find a free port on %s after %d attempts" % 909 (addr, ATTEMPTS)) 910 911 def testSocketAddressesAfterClose(self): 912 for version in 4, 5, 6: 913 netid = random.choice(self.NETIDS) 914 dst = self.GetRemoteSocketAddress(version) 915 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 916 unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version] 917 918 # Closing a socket that was not explicitly bound (i.e., bound via 919 # connect(), not bind()) clears the source address and port. 920 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 921 self.SelectInterface(s, netid, "mark") 922 s.connect((dst, 53)) 923 self.sock_diag.CloseSocketFromFd(s) 924 self.assertEqual((unspec, 0), s.getsockname()[:2]) 925 926 # Closing a socket bound to an IP address leaves the address as is. 927 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 928 src = self.MySocketAddress(version, netid) 929 s.bind((src, 0)) 930 s.connect((dst, 53)) 931 port = s.getsockname()[1] 932 self.sock_diag.CloseSocketFromFd(s) 933 self.assertEqual((src, 0), s.getsockname()[:2]) 934 935 # Closing a socket bound to a port leaves the port as is. 936 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 937 port = self.BindToRandomPort(s, "") 938 s.connect((dst, 53)) 939 self.sock_diag.CloseSocketFromFd(s) 940 self.assertEqual((unspec, port), s.getsockname()[:2]) 941 942 # Closing a socket bound to IP address and port leaves both as is. 943 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 944 src = self.MySocketAddress(version, netid) 945 port = self.BindToRandomPort(s, src) 946 self.sock_diag.CloseSocketFromFd(s) 947 self.assertEqual((src, port), s.getsockname()[:2]) 948 949 def testReadInterrupted(self): 950 """Tests that read() is interrupted by SOCK_DESTROY.""" 951 for version in [4, 5, 6]: 952 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 953 s = net_test.UDPSocket(family) 954 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 955 addr = self.GetRemoteSocketAddress(version) 956 957 # Check that reads on connected sockets are interrupted. 958 s.connect((addr, 53)) 959 self.assertEqual(3, s.send(b"foo")) 960 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 961 ECONNABORTED) 962 963 # A destroyed socket is no longer connected, but still usable. 964 self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo") 965 self.assertEqual(3, s.sendto(b"foo", (addr, 53))) 966 967 # Check that reads on unconnected sockets are also interrupted. 968 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 969 ECONNABORTED) 970 971class SockDestroyPermissionTest(SockDiagBaseTest): 972 973 def CheckPermissions(self, socktype): 974 s = socket(AF_INET6, socktype, 0) 975 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 976 if socktype == SOCK_STREAM: 977 s.listen(1) 978 expectedstate = tcp_test.TCP_LISTEN 979 else: 980 s.connect((self.GetRemoteAddress(6), 53)) 981 expectedstate = tcp_test.TCP_ESTABLISHED 982 983 with net_test.RunAsUid(12345): 984 self.assertRaisesErrno( 985 EPERM, self.sock_diag.CloseSocketFromFd, s) 986 987 self.sock_diag.CloseSocketFromFd(s) 988 self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s) 989 990 991 def testUdp(self): 992 self.CheckPermissions(SOCK_DGRAM) 993 994 def testTcp(self): 995 self.CheckPermissions(SOCK_STREAM) 996 997 998class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 999 1000 """Tests SOCK_DIAG bytecode filters that use marks. 1001 1002 Relevant kernel commits: 1003 upstream net-next: 1004 627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks. 1005 a52e95a net: diag: allow socket bytecode filters to match socket marks 1006 d545cac net: inet: diag: expose the socket mark to privileged processes. 1007 """ 1008 1009 def FilterEstablishedSockets(self, mark, mask): 1010 instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))] 1011 bytecode = self.sock_diag.PackBytecode(instructions) 1012 return self.sock_diag.DumpAllInetSockets( 1013 IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED)) 1014 1015 def assertSamePorts(self, ports, diag_msgs): 1016 expected = sorted(ports) 1017 actual = sorted([msg[0].id.sport for msg in diag_msgs]) 1018 self.assertEqual(expected, actual) 1019 1020 def SockInfoMatchesSocket(self, s, info): 1021 try: 1022 self.assertSockInfoMatchesSocket(s, info) 1023 return True 1024 except AssertionError: 1025 return False 1026 1027 @staticmethod 1028 def SocketDescription(s): 1029 return "%s -> %s" % (str(s.getsockname()), str(s.getpeername())) 1030 1031 def assertFoundSockets(self, infos, sockets): 1032 matches = {} 1033 for s in sockets: 1034 match = None 1035 for info in infos: 1036 if self.SockInfoMatchesSocket(s, info): 1037 if match: 1038 self.fail("Socket %s matched both %s and %s" % 1039 (self.SocketDescription(s), match, info)) 1040 matches[s] = info 1041 self.assertTrue(s in matches, "Did not find socket %s in dump" % 1042 self.SocketDescription(s)) 1043 1044 for i in infos: 1045 if i not in list(matches.values()): 1046 self.fail("Too many sockets in dump, first unexpected: %s" % str(i)) 1047 1048 def testMarkBytecode(self): 1049 family, addr = random.choice([ 1050 (AF_INET, "127.0.0.1"), 1051 (AF_INET6, "::1"), 1052 (AF_INET6, "::ffff:127.0.0.1")]) 1053 s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr) 1054 s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234) 1055 s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235) 1056 1057 infos = self.FilterEstablishedSockets(0x1234, 0xffff) 1058 self.assertFoundSockets(infos, [s1]) 1059 1060 infos = self.FilterEstablishedSockets(0x1234, 0xfffe) 1061 self.assertFoundSockets(infos, [s1, s2]) 1062 1063 infos = self.FilterEstablishedSockets(0x1235, 0xffff) 1064 self.assertFoundSockets(infos, [s2]) 1065 1066 infos = self.FilterEstablishedSockets(0x0, 0x0) 1067 self.assertFoundSockets(infos, [s1, s2]) 1068 1069 infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00) 1070 self.assertEqual(0, len(infos)) 1071 1072 with net_test.RunAsUid(12345): 1073 self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets, 1074 0xfff0000, 0xf0fed00) 1075 1076 @staticmethod 1077 def SetRandomMark(s): 1078 # Python doesn't like marks that don't fit into a signed int. 1079 mark = random.randrange(0, 2**31 - 1) 1080 s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark) 1081 return mark 1082 1083 def assertSocketMarkIs(self, s, mark): 1084 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1085 self.assertMarkIs(mark, attrs) 1086 with net_test.RunAsUid(12345): 1087 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1088 self.assertMarkIs(None, attrs) 1089 1090 def testMarkInAttributes(self): 1091 testcases = [(AF_INET, "127.0.0.1"), 1092 (AF_INET6, "::1"), 1093 (AF_INET6, "::ffff:127.0.0.1")] 1094 for family, addr in testcases: 1095 # TCP listen sockets. 1096 server = socket(family, SOCK_STREAM, 0) 1097 server.bind((addr, 0)) 1098 port = server.getsockname()[1] 1099 server.listen(1) # Or the socket won't be in the hashtables. 1100 server_mark = self.SetRandomMark(server) 1101 self.assertSocketMarkIs(server, server_mark) 1102 1103 # TCP client sockets. 1104 client = socket(family, SOCK_STREAM, 0) 1105 client_mark = self.SetRandomMark(client) 1106 client.connect((addr, port)) 1107 self.assertSocketMarkIs(client, client_mark) 1108 1109 # TCP server sockets. 1110 accepted, _ = server.accept() 1111 self.assertSocketMarkIs(accepted, server_mark) 1112 1113 accepted_mark = self.SetRandomMark(accepted) 1114 self.assertSocketMarkIs(accepted, accepted_mark) 1115 self.assertSocketMarkIs(server, server_mark) 1116 1117 server.close() 1118 client.close() 1119 1120 # Other TCP states are tested in SockDestroyTcpTest. 1121 1122 # UDP sockets. 1123 s = socket(family, SOCK_DGRAM, 0) 1124 mark = self.SetRandomMark(s) 1125 s.connect(("", 53)) 1126 self.assertSocketMarkIs(s, mark) 1127 s.close() 1128 1129 # Basic test for SCTP. sctp_diag was only added in 4.7. 1130 if HAVE_SCTP: 1131 s = socket(family, SOCK_STREAM, IPPROTO_SCTP) 1132 s.bind((addr, 0)) 1133 s.listen(1) 1134 mark = self.SetRandomMark(s) 1135 self.assertSocketMarkIs(s, mark) 1136 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE) 1137 self.assertEqual(1, len(sockets)) 1138 self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None)) 1139 s.close() 1140 1141 1142if __name__ == "__main__": 1143 unittest.main() 1144