#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
from errno import *  # pylint: disable=wildcard-import
import os
import random
import select
from socket import *  # pylint: disable=wildcard-import
import struct
import threading
import time
import unittest

import cstruct
import multinetwork_base
import net_test
import packets
import sock_diag
import tcp_test

# Mostly empty structure definition containing only the fields we currently use.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")

NUM_SOCKETS = 30
NO_BYTECODE = ""
LINUX_4_9_OR_ABOVE = net_test.LINUX_VERSION >= (4, 9, 0)
LINUX_4_19_OR_ABOVE = net_test.LINUX_VERSION >= (4, 19, 0)

IPPROTO_SCTP = 132


def HaveUdpDiag():
  """Checks if the current kernel has config CONFIG_INET_UDP_DIAG enabled.

  This config is required for devices running a 4.9 (or later) kernel that
  shipped with P. In this case, always assume the config is present and rely
  on the tests themselves to check that it is enabled as required.

  For all other kernel versions, there is no way to tell whether a dump
  succeeded: if the appropriate handler wasn't found, __inet_diag_dump just
  returns an empty result instead of an error. So, open a connected UDP socket
  and check whether a UDP dump returns at least one socket. If it returns
  none, the UDP diag tests will be skipped.

  Returns:
    True if the kernel is 4.9 or above, or CONFIG_INET_UDP_DIAG is enabled.
    False otherwise.
  """
  if LINUX_4_9_OR_ABOVE:
    return True
  s = socket(AF_INET6, SOCK_DGRAM, 0)
  try:
    # Connect the socket to itself so that there is at least one UDP socket
    # a working UDP diag handler must report.
    s.bind(("::", 0))
    s.connect(s.getsockname())
    sd = sock_diag.SockDiag()
    return len(sd.DumpAllInetSockets(IPPROTO_UDP, "")) > 0
  finally:
    # Close the probe socket even if bind/connect/dump raises.
    s.close()


def HaveSctp():
  """Returns True if SCTP sockets can be created (never before Linux 4.7)."""
  if net_test.LINUX_VERSION < (4, 7, 0):
    return False
  try:
    s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP)
    s.close()
    return True
  except IOError:
    return False


HAVE_UDP_DIAG = HaveUdpDiag()
HAVE_SCTP = HaveSctp()


class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS connected socket pairs on random loopback addresses.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Checks that a (diag_msg, attrs) tuple describes socket s."""
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs bytecode and checks that it decodes back to every instruction."""
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    self.assertFalse("???" in decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
      self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock with SOCK_DESTROY while call(sock) is blocked."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in list(self.socketpairs.values()):
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()


class SockDiagTest(SockDiagBaseTest):

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works."""
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if (addr, sport, dport) in self.socketpairs:
        cookies[(addr, sport, dport)] = diag_msg.id.cookie
      elif (addr, dport, sport) in self.socketpairs:
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

  def testFindsAllMySocketsTcp(self):
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testFindsAllMySocketsUdp(self):
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "",
                                                       states=states))

    unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")

    bytecode4 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.PackAndCheckBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                states=states)
    self.assertTrue(v4socks)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                states=states)
    self.assertTrue(v6socks)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode4,
                                                               states=states)]
    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                               bytecode6,
                                                               states=states)]
    self.assertTrue(all(d in v4socks for d in diag_msgs))
    self.assertTrue(all(d in v6socks for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                     self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
      self.assertEqual(diag_msg.id.cookie, cookie)

  @unittest.skipUnless(LINUX_4_9_OR_ABOVE, "SO_COOKIE not supported")
  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
    # by passing the source address as the source address argument.
    # Unfortunately those functions are intended to match local sockets based
    # on received packets, and the argument that ends up being compared with
    # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
    # have this bug. Upstream has confirmed that this will not be fixed:
    # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      try:
        self.SelectInterface(s, self.RandomNetid(), "mark")
        s.connect((self.GetRemoteSocketAddress(version), 53))

        # Create a fully-specified diag req from our socket, including cookie
        # if we can get it.
        req = self.sock_diag.DiagReqFromSocket(s)
        if LINUX_4_9_OR_ABOVE:
          req.id.cookie = s.getsockopt(net_test.SOL_SOCKET,
                                       net_test.SO_COOKIE, 8)
        else:
          # INET_DIAG_NOCOOKIE in both halves of the 8-byte cookie field
          # (the same length SO_COOKIE returns above).
          req.id.cookie = "\xff" * 8

        # As is, this request does not find anything.
        with self.assertRaisesErrno(ENOENT):
          self.sock_diag.GetSockInfo(req)

        # But if we swap src and dst, the kernel finds our socket.
        req.id.sport, req.id.dport = req.id.dport, req.id.sport
        req.id.src, req.id.dst = req.id.dst, req.id.src

        self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
      finally:
        # Don't leak one UDP socket per iteration.
        s.close()


class SockDestroyTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY works correctly.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Closes one side of each pair, by fd or by cookie, and checks both die."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
510 511 512class SocketExceptionThread(threading.Thread): 513 514 def __init__(self, sock, operation): 515 self.exception = None 516 super(SocketExceptionThread, self).__init__() 517 self.daemon = True 518 self.sock = sock 519 self.operation = operation 520 521 def run(self): 522 try: 523 self.operation(self.sock) 524 except (IOError, AssertionError) as e: 525 self.exception = e 526 527 528class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 529 530 def testIpv4MappedSynRecvSocket(self): 531 """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets. 532 533 Relevant kernel commits: 534 android-3.4: 535 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state 536 """ 537 netid = random.choice(list(self.tuns.keys())) 538 self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) 539 sock_id = self.sock_diag._EmptyInetDiagSockId() 540 sock_id.sport = self.port 541 states = 1 << tcp_test.TCP_SYN_RECV 542 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) 543 children = self.sock_diag.Dump(req, NO_BYTECODE) 544 545 self.assertTrue(children) 546 for child, unused_args in children: 547 self.assertEqual(tcp_test.TCP_SYN_RECV, child.state) 548 self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr), 549 child.id.dst) 550 self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr), 551 child.id.src) 552 553 554class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 555 556 RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000 557 TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd" 558 559 def setUp(self): 560 super(TcpRcvWindowTest, self).setUp() 561 if LINUX_4_19_OR_ABOVE: 562 self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w") 563 return 564 565 try: 566 f = open(self.TCP_DEFAULT_INIT_RWND, "w") 567 except IOError as e: 568 # sysctl was namespace-ified on May 25, 2020 in android-4.14-stable [R] 569 # just after 4.14.181 by: 570 # 
https://android-review.googlesource.com/c/kernel/common/+/1312623 571 # ANDROID: namespace'ify tcp_default_init_rwnd implementation 572 # But that commit might be missing in Q era kernels even when > 4.14.181 573 # when running T vts. 574 if net_test.LINUX_VERSION >= (4, 15, 0): 575 raise 576 if e.errno != ENOENT: 577 raise 578 # we rely on the network namespace creation code 579 # modifying the root netns sysctl before the namespace is even created 580 return 581 582 f.write("60") 583 584 def checkInitRwndSize(self, version, netid): 585 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid) 586 tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP, 587 net_test.TCP_INFO, len(TcpInfo))) 588 self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh, 589 "Tcp rwnd of netid=%d, version=%d is not enough. " 590 "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE, 591 tcpInfo.tcpi_rcv_ssthresh)) 592 593 def checkSynPacketWindowSize(self, version, netid): 594 s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark") 595 myaddr = self.MyAddress(version, netid) 596 dstaddr = self.GetRemoteAddress(version) 597 dstsockaddr = self.GetRemoteSocketAddress(version) 598 desc, expected = packets.SYN(53, version, myaddr, dstaddr, 599 sport=None, seq=None) 600 self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53)) 601 msg = "IPv%s TCP connect: expected %s on %s" % ( 602 version, desc, self.GetInterfaceName(netid)) 603 syn = self.ExpectPacketOn(netid, msg, expected) 604 self.assertLess(self.RWND_SIZE, syn.window) 605 s.close() 606 607 def testTcpCwndSize(self): 608 for version in [4, 5, 6]: 609 for netid in self.NETIDS: 610 self.checkInitRwndSize(version, netid) 611 self.checkSynPacketWindowSize(version, netid) 612 613 614class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 615 616 def setUp(self): 617 super(SockDestroyTcpTest, self).setUp() 618 self.netid = random.choice(list(self.tuns.keys())) 619 620 def CheckRstOnClose(self, 
sock, req, expect_reset, msg, do_close=True): 621 """Closes the socket and checks whether a RST is sent or not.""" 622 if sock is not None: 623 self.assertIsNone(req, "Must specify sock or req, not both") 624 self.sock_diag.CloseSocketFromFd(sock) 625 self.assertRaisesErrno(EINVAL, sock.accept) 626 else: 627 self.assertIsNone(sock, "Must specify sock or req, not both") 628 self.sock_diag.CloseSocket(req) 629 630 if expect_reset: 631 desc, rst = self.RstPacket() 632 msg = "%s: expecting %s: " % (msg, desc) 633 self.ExpectPacketOn(self.netid, msg, rst) 634 else: 635 msg = "%s: " % msg 636 self.ExpectNoPacketsOn(self.netid, msg) 637 638 if sock is not None and do_close: 639 sock.close() 640 641 def CheckTcpReset(self, state, statename): 642 for version in [4, 5, 6]: 643 msg = "Closing incoming IPv%d %s socket" % (version, statename) 644 self.IncomingConnection(version, state, self.netid) 645 self.CheckRstOnClose(self.s, None, False, msg) 646 if state != tcp_test.TCP_LISTEN: 647 msg = "Closing accepted IPv%d %s socket" % (version, statename) 648 self.CheckRstOnClose(self.accepted, None, True, msg) 649 650 def testTcpResets(self): 651 """Checks that closing sockets in appropriate states sends a RST.""" 652 self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN") 653 self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED") 654 self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") 655 656 def testFinWait1Socket(self): 657 for version in [4, 5, 6]: 658 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 659 660 # Get the cookie so we can find this socket after we close it. 661 diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted) 662 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 663 664 # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN. 
665 net_test.EnableFinWait(self.accepted) 666 self.accepted.close() 667 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1 668 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 669 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 670 desc, fin = self.FinPacket() 671 self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin) 672 673 # Destroy the socket and expect no RST. 674 self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket") 675 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 676 677 # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing 678 # because userspace had already closed it. 679 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 680 681 # ACK the FIN so we don't trip over retransmits in future tests. 682 finversion = 4 if version == 5 else version 683 desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin) 684 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 685 self.ReceivePacketOn(self.netid, finack) 686 687 # See if we can find the resulting FIN_WAIT2 socket. This does not appear 688 # to work on 3.10. 
689 if net_test.LINUX_VERSION >= (3, 18): 690 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2 691 infos = self.sock_diag.Dump(diag_req, "") 692 self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2 693 for diag_msg, attrs in infos), 694 "Expected to find FIN_WAIT2 socket in %s" % infos) 695 696 def FindChildSockets(self, s): 697 """Finds the SYN_RECV child sockets of a given listening socket.""" 698 d = self.sock_diag.FindSockDiagFromFd(self.s) 699 req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 700 req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED 701 req.id.cookie = "\x00" * 8 702 703 bad_bytecode = self.PackAndCheckBytecode( 704 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))]) 705 self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode)) 706 707 bytecode = self.PackAndCheckBytecode( 708 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))]) 709 children = self.sock_diag.Dump(req, bytecode) 710 return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 711 for d, _ in children] 712 713 def CheckChildSocket(self, version, statename, parent_first): 714 state = getattr(tcp_test, statename) 715 716 self.IncomingConnection(version, state, self.netid) 717 718 d = self.sock_diag.FindSockDiagFromFd(self.s) 719 parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 720 children = self.FindChildSockets(self.s) 721 self.assertEqual(1, len(children)) 722 723 is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) 724 expected_state = tcp_test.TCP_ESTABLISHED if is_established else state 725 726 # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the 727 # regular TCP hash tables, and inet_diag_find_one_icsk can find them. 728 # Before 4.4, we can see those sockets in dumps, but we can't fetch 729 # or close them. 
730 can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4) 731 732 for child in children: 733 if can_close_children: 734 diag_msg, attrs = self.sock_diag.GetSockInfo(child) 735 self.assertEqual(diag_msg.state, expected_state) 736 self.assertMarkIs(self.netid, attrs) 737 else: 738 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 739 740 def CloseParent(expect_reset): 741 msg = "Closing parent IPv%d %s socket %s child" % ( 742 version, statename, "before" if parent_first else "after") 743 self.CheckRstOnClose(self.s, None, expect_reset, msg) 744 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent) 745 746 def CheckChildrenClosed(): 747 for child in children: 748 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 749 750 def CloseChildren(): 751 for child in children: 752 msg = "Closing child IPv%d %s socket %s parent" % ( 753 version, statename, "after" if parent_first else "before") 754 self.sock_diag.GetSockInfo(child) 755 self.CheckRstOnClose(None, child, is_established, msg) 756 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 757 CheckChildrenClosed() 758 759 if parent_first: 760 # Closing the parent will close child sockets, which will send a RST, 761 # iff they are already established. 
762 CloseParent(is_established) 763 if is_established: 764 CheckChildrenClosed() 765 elif can_close_children: 766 CloseChildren() 767 CheckChildrenClosed() 768 self.s.close() 769 else: 770 if can_close_children: 771 CloseChildren() 772 CloseParent(False) 773 self.s.close() 774 775 def testChildSockets(self): 776 for version in [4, 5, 6]: 777 self.CheckChildSocket(version, "TCP_SYN_RECV", False) 778 self.CheckChildSocket(version, "TCP_SYN_RECV", True) 779 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False) 780 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True) 781 782 def testAcceptInterrupted(self): 783 """Tests that accept() is interrupted by SOCK_DESTROY.""" 784 for version in [4, 5, 6]: 785 self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid) 786 self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096) 787 self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) 788 self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo") 789 self.assertRaisesErrno(EINVAL, self.s.accept) 790 # TODO: this should really return an error such as ENOTCONN... 791 self.assertEqual("", self.s.recv(4096)) 792 793 def testReadInterrupted(self): 794 """Tests that read() is interrupted by SOCK_DESTROY.""" 795 for version in [4, 5, 6]: 796 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 797 self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), 798 ECONNABORTED) 799 # Writing returns EPIPE, and reading returns EOF. 
800 self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") 801 self.assertEqual("", self.accepted.recv(4096)) 802 self.assertEqual("", self.accepted.recv(4096)) 803 804 def testConnectInterrupted(self): 805 """Tests that connect() is interrupted by SOCK_DESTROY.""" 806 for version in [4, 5, 6]: 807 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 808 s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) 809 self.SelectInterface(s, self.netid, "mark") 810 811 remotesockaddr = self.GetRemoteSocketAddress(version) 812 remoteaddr = self.GetRemoteAddress(version) 813 s.bind(("", 0)) 814 _, sport = s.getsockname()[:2] 815 self.CloseDuringBlockingCall( 816 s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED) 817 desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), 818 remoteaddr, sport=sport, seq=None) 819 self.ExpectPacketOn(self.netid, desc, syn) 820 msg = "SOCK_DESTROY of socket in connect, expected no RST" 821 self.ExpectNoPacketsOn(self.netid, msg) 822 823 824class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 825 """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs. 826 827 The behaviour of poll() in these cases is not what we might expect: if only 828 POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT 829 is (also) specified, it will only return POLLOUT. 
830 """ 831 832 POLLIN_OUT = select.POLLIN | select.POLLOUT 833 POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP 834 835 def setUp(self): 836 super(PollOnCloseTest, self).setUp() 837 self.netid = random.choice(list(self.tuns.keys())) 838 839 POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"), 840 (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")] 841 842 def PollResultToString(self, poll_events, ignoremask): 843 out = [] 844 for fd, event in poll_events: 845 flags = [name for (flag, name) in self.POLL_FLAGS 846 if event & flag & ~ignoremask != 0] 847 out.append((fd, "|".join(flags))) 848 return out 849 850 def BlockingPoll(self, sock, mask, expected, ignoremask): 851 p = select.poll() 852 p.register(sock, mask) 853 expected_fds = [(sock.fileno(), expected)] 854 # Don't block forever or we'll hang continuous test runs on failure. 855 # A 5-second timeout should be long enough not to be flaky. 856 actual_fds = p.poll(5000) 857 self.assertEqual(self.PollResultToString(expected_fds, ignoremask), 858 self.PollResultToString(actual_fds, ignoremask)) 859 860 def RstDuringBlockingCall(self, sock, call, expected_errno): 861 self._EventDuringBlockingCall( 862 sock, call, expected_errno, 863 lambda _: self.ReceiveRstPacketOn(self.netid)) 864 865 def assertSocketErrors(self, errno): 866 # The first operation returns the expected errno. 867 self.assertRaisesErrno(errno, self.accepted.recv, 4096) 868 869 # Subsequent operations behave as normal. 
870 self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") 871 self.assertEqual("", self.accepted.recv(4096)) 872 self.assertEqual("", self.accepted.recv(4096)) 873 874 def CheckPollDestroy(self, mask, expected, ignoremask): 875 """Interrupts a poll() with SOCK_DESTROY.""" 876 for version in [4, 5, 6]: 877 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 878 self.CloseDuringBlockingCall( 879 self.accepted, 880 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 881 None) 882 self.assertSocketErrors(ECONNABORTED) 883 884 def CheckPollRst(self, mask, expected, ignoremask): 885 """Interrupts a poll() by receiving a TCP RST.""" 886 for version in [4, 5, 6]: 887 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 888 self.RstDuringBlockingCall( 889 self.accepted, 890 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 891 None) 892 self.assertSocketErrors(ECONNRESET) 893 894 def testReadPollRst(self): 895 # Until 3d4762639d ("tcp: remove poll() flakes when receiving RST"), poll() 896 # would sometimes return POLLERR and sometimes POLLIN|POLLERR|POLLHUP. This 897 # is due to a race inside the kernel and thus is not visible on the VM, only 898 # on physical hardware. 899 if net_test.LINUX_VERSION < (4, 14, 0): 900 ignoremask = select.POLLIN | select.POLLHUP 901 else: 902 ignoremask = 0 903 self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) 904 905 def testWritePollRst(self): 906 self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0) 907 908 def testReadWritePollRst(self): 909 self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0) 910 911 def testReadPollDestroy(self): 912 # tcp_abort has the same race that tcp_reset has, but it's not fixed yet. 
913 ignoremask = select.POLLIN | select.POLLHUP 914 self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) 915 916 def testWritePollDestroy(self): 917 self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0) 918 919 def testReadWritePollDestroy(self): 920 self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0) 921 922 923@unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled") 924class SockDestroyUdpTest(SockDiagBaseTest): 925 926 """Tests SOCK_DESTROY on UDP sockets. 927 928 Relevant kernel commits: 929 upstream net-next: 930 5d77dca net: diag: support SOCK_DESTROY for UDP sockets 931 f95bf34 net: diag: make udp_diag_destroy work for mapped addresses. 932 """ 933 934 def testClosesUdpSockets(self): 935 self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM) 936 for _, socketpair in self.socketpairs.items(): 937 s1, s2 = socketpair 938 939 self.assertSocketConnected(s1) 940 self.sock_diag.CloseSocketFromFd(s1) 941 self.assertSocketClosed(s1) 942 943 self.assertSocketConnected(s2) 944 self.sock_diag.CloseSocketFromFd(s2) 945 self.assertSocketClosed(s2) 946 947 def BindToRandomPort(self, s, addr): 948 ATTEMPTS = 20 949 for i in range(20): 950 port = random.randrange(1024, 65535) 951 try: 952 s.bind((addr, port)) 953 return port 954 except error as e: 955 if e.errno != EADDRINUSE: 956 raise e 957 raise ValueError("Could not find a free port on %s after %d attempts" % 958 (addr, ATTEMPTS)) 959 960 def testSocketAddressesAfterClose(self): 961 for version in 4, 5, 6: 962 netid = random.choice(self.NETIDS) 963 dst = self.GetRemoteSocketAddress(version) 964 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 965 unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version] 966 967 # Closing a socket that was not explicitly bound (i.e., bound via 968 # connect(), not bind()) clears the source address and port. 
969 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 970 self.SelectInterface(s, netid, "mark") 971 s.connect((dst, 53)) 972 self.sock_diag.CloseSocketFromFd(s) 973 self.assertEqual((unspec, 0), s.getsockname()[:2]) 974 975 # Closing a socket bound to an IP address leaves the address as is. 976 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 977 src = self.MySocketAddress(version, netid) 978 s.bind((src, 0)) 979 s.connect((dst, 53)) 980 port = s.getsockname()[1] 981 self.sock_diag.CloseSocketFromFd(s) 982 self.assertEqual((src, 0), s.getsockname()[:2]) 983 984 # Closing a socket bound to a port leaves the port as is. 985 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 986 port = self.BindToRandomPort(s, "") 987 s.connect((dst, 53)) 988 self.sock_diag.CloseSocketFromFd(s) 989 self.assertEqual((unspec, port), s.getsockname()[:2]) 990 991 # Closing a socket bound to IP address and port leaves both as is. 992 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 993 src = self.MySocketAddress(version, netid) 994 port = self.BindToRandomPort(s, src) 995 self.sock_diag.CloseSocketFromFd(s) 996 self.assertEqual((src, port), s.getsockname()[:2]) 997 998 def testReadInterrupted(self): 999 """Tests that read() is interrupted by SOCK_DESTROY.""" 1000 for version in [4, 5, 6]: 1001 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 1002 s = net_test.UDPSocket(family) 1003 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 1004 addr = self.GetRemoteSocketAddress(version) 1005 1006 # Check that reads on connected sockets are interrupted. 1007 s.connect((addr, 53)) 1008 self.assertEqual(3, s.send("foo")) 1009 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 1010 ECONNABORTED) 1011 1012 # A destroyed socket is no longer connected, but still usable. 
1013 self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo") 1014 self.assertEqual(3, s.sendto("foo", (addr, 53))) 1015 1016 # Check that reads on unconnected sockets are also interrupted. 1017 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 1018 ECONNABORTED) 1019 1020class SockDestroyPermissionTest(SockDiagBaseTest): 1021 1022 def CheckPermissions(self, socktype): 1023 s = socket(AF_INET6, socktype, 0) 1024 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 1025 if socktype == SOCK_STREAM: 1026 s.listen(1) 1027 expectedstate = tcp_test.TCP_LISTEN 1028 else: 1029 s.connect((self.GetRemoteAddress(6), 53)) 1030 expectedstate = tcp_test.TCP_ESTABLISHED 1031 1032 with net_test.RunAsUid(12345): 1033 self.assertRaisesErrno( 1034 EPERM, self.sock_diag.CloseSocketFromFd, s) 1035 1036 self.sock_diag.CloseSocketFromFd(s) 1037 self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s) 1038 1039 1040 @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled") 1041 def testUdp(self): 1042 self.CheckPermissions(SOCK_DGRAM) 1043 1044 def testTcp(self): 1045 self.CheckPermissions(SOCK_STREAM) 1046 1047 1048class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 1049 1050 """Tests SOCK_DIAG bytecode filters that use marks. 1051 1052 Relevant kernel commits: 1053 upstream net-next: 1054 627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks. 1055 a52e95a net: diag: allow socket bytecode filters to match socket marks 1056 d545cac net: inet: diag: expose the socket mark to privileged processes. 
1057 """ 1058 1059 def FilterEstablishedSockets(self, mark, mask): 1060 instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))] 1061 bytecode = self.sock_diag.PackBytecode(instructions) 1062 return self.sock_diag.DumpAllInetSockets( 1063 IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED)) 1064 1065 def assertSamePorts(self, ports, diag_msgs): 1066 expected = sorted(ports) 1067 actual = sorted([msg[0].id.sport for msg in diag_msgs]) 1068 self.assertEqual(expected, actual) 1069 1070 def SockInfoMatchesSocket(self, s, info): 1071 try: 1072 self.assertSockInfoMatchesSocket(s, info) 1073 return True 1074 except AssertionError: 1075 return False 1076 1077 @staticmethod 1078 def SocketDescription(s): 1079 return "%s -> %s" % (str(s.getsockname()), str(s.getpeername())) 1080 1081 def assertFoundSockets(self, infos, sockets): 1082 matches = {} 1083 for s in sockets: 1084 match = None 1085 for info in infos: 1086 if self.SockInfoMatchesSocket(s, info): 1087 if match: 1088 self.fail("Socket %s matched both %s and %s" % 1089 (self.SocketDescription(s), match, info)) 1090 matches[s] = info 1091 self.assertTrue(s in matches, "Did not find socket %s in dump" % 1092 self.SocketDescription(s)) 1093 1094 for i in infos: 1095 if i not in list(matches.values()): 1096 self.fail("Too many sockets in dump, first unexpected: %s" % str(i)) 1097 1098 def testMarkBytecode(self): 1099 family, addr = random.choice([ 1100 (AF_INET, "127.0.0.1"), 1101 (AF_INET6, "::1"), 1102 (AF_INET6, "::ffff:127.0.0.1")]) 1103 s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr) 1104 s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234) 1105 s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235) 1106 1107 infos = self.FilterEstablishedSockets(0x1234, 0xffff) 1108 self.assertFoundSockets(infos, [s1]) 1109 1110 infos = self.FilterEstablishedSockets(0x1234, 0xfffe) 1111 self.assertFoundSockets(infos, [s1, s2]) 1112 1113 infos = self.FilterEstablishedSockets(0x1235, 0xffff) 
1114 self.assertFoundSockets(infos, [s2]) 1115 1116 infos = self.FilterEstablishedSockets(0x0, 0x0) 1117 self.assertFoundSockets(infos, [s1, s2]) 1118 1119 infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00) 1120 self.assertEqual(0, len(infos)) 1121 1122 with net_test.RunAsUid(12345): 1123 self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets, 1124 0xfff0000, 0xf0fed00) 1125 1126 @staticmethod 1127 def SetRandomMark(s): 1128 # Python doesn't like marks that don't fit into a signed int. 1129 mark = random.randrange(0, 2**31 - 1) 1130 s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark) 1131 return mark 1132 1133 def assertSocketMarkIs(self, s, mark): 1134 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1135 self.assertMarkIs(mark, attrs) 1136 with net_test.RunAsUid(12345): 1137 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1138 self.assertMarkIs(None, attrs) 1139 1140 def testMarkInAttributes(self): 1141 testcases = [(AF_INET, "127.0.0.1"), 1142 (AF_INET6, "::1"), 1143 (AF_INET6, "::ffff:127.0.0.1")] 1144 for family, addr in testcases: 1145 # TCP listen sockets. 1146 server = socket(family, SOCK_STREAM, 0) 1147 server.bind((addr, 0)) 1148 port = server.getsockname()[1] 1149 server.listen(1) # Or the socket won't be in the hashtables. 1150 server_mark = self.SetRandomMark(server) 1151 self.assertSocketMarkIs(server, server_mark) 1152 1153 # TCP client sockets. 1154 client = socket(family, SOCK_STREAM, 0) 1155 client_mark = self.SetRandomMark(client) 1156 client.connect((addr, port)) 1157 self.assertSocketMarkIs(client, client_mark) 1158 1159 # TCP server sockets. 1160 accepted, _ = server.accept() 1161 self.assertSocketMarkIs(accepted, server_mark) 1162 1163 accepted_mark = self.SetRandomMark(accepted) 1164 self.assertSocketMarkIs(accepted, accepted_mark) 1165 self.assertSocketMarkIs(server, server_mark) 1166 1167 server.close() 1168 client.close() 1169 1170 # Other TCP states are tested in SockDestroyTcpTest. 1171 1172 # UDP sockets. 
1173 if HAVE_UDP_DIAG: 1174 s = socket(family, SOCK_DGRAM, 0) 1175 mark = self.SetRandomMark(s) 1176 s.connect(("", 53)) 1177 self.assertSocketMarkIs(s, mark) 1178 s.close() 1179 1180 # Basic test for SCTP. sctp_diag was only added in 4.7. 1181 if HAVE_SCTP: 1182 s = socket(family, SOCK_STREAM, IPPROTO_SCTP) 1183 s.bind((addr, 0)) 1184 s.listen(1) 1185 mark = self.SetRandomMark(s) 1186 self.assertSocketMarkIs(s, mark) 1187 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE) 1188 self.assertEqual(1, len(sockets)) 1189 self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None)) 1190 s.close() 1191 1192 1193if __name__ == "__main__": 1194 unittest.main() 1195