#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
from errno import *  # pylint: disable=wildcard-import
import os
import random
import select
from socket import *  # pylint: disable=wildcard-import
import struct
import threading
import time
import unittest

import cstruct
import multinetwork_base
import net_test
import packets
import sock_diag
import tcp_test

# Mostly empty structure definition containing only the fields we currently use.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")

NUM_SOCKETS = 30
NO_BYTECODE = ""
LINUX_4_9_OR_ABOVE = net_test.LINUX_VERSION >= (4, 9, 0)
LINUX_4_19_OR_ABOVE = net_test.LINUX_VERSION >= (4, 19, 0)

IPPROTO_SCTP = 132


def HaveUdpDiag():
  """Checks if the current kernel has config CONFIG_INET_UDP_DIAG enabled.

  This config is required for devices running a 4.9 kernel that ship with P.
  In that case, always assume the config is there and rely on the tests to
  check that it is enabled as required.

  For all other kernel versions, there is no way to tell whether a dump
  succeeded: if the appropriate handler wasn't found, __inet_diag_dump just
  returns an empty result instead of an error. So, just check to see if a UDP
  dump returns no sockets when we know it should return one. If not, some
  tests will be skipped.

  Returns:
    True if the kernel is 4.9 or above, or CONFIG_INET_UDP_DIAG is enabled.
    False otherwise.
  """
  if LINUX_4_9_OR_ABOVE:
    return True
  s = socket(AF_INET6, SOCK_DGRAM, 0)
  try:
    s.bind(("::", 0))
    # Connect the socket to its own address so that there is a connected UDP
    # socket that a working UDP dump must report.
    s.connect(s.getsockname())
    sd = sock_diag.SockDiag()
    return len(sd.DumpAllInetSockets(IPPROTO_UDP, NO_BYTECODE)) > 0
  finally:
    # Always release the probe socket, even if bind/connect/dump raised.
    s.close()


def HaveSctp():
  """Checks whether SCTP sockets can be created on this kernel.

  Returns:
    False on kernels older than 4.7, or if creating an AF_INET SCTP socket
    fails with IOError. True otherwise.
  """
  if net_test.LINUX_VERSION < (4, 7, 0):
    return False
  try:
    s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP)
  except IOError:
    return False
  s.close()
  return True


HAVE_UDP_DIAG = HaveUdpDiag()
HAVE_SCTP = HaveSctp()


class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS connected socket pairs over loopback.

    Each pair uses a randomly chosen loopback address: IPv4, IPv6, or
    IPv4-mapped IPv6.

    Args:
      socktype: The socket type to pass to net_test.CreateSocketPair.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      is the local port of the first socket and dport the local port of the
      second socket.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that sock is no longer connected (getpeername -> ENOTCONN)."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is still connected."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute in attrs equals mark."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Asserts that a (diag_msg, attrs) tuple describes socket s.

    Checks family, source address/port, destination address/port (or that s
    is unconnected if the diag message has a wildcard destination), and mark.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs inet_diag bytecode and checks that it decodes cleanly.

    Args:
      instructions: A list of bytecode instruction tuples.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    # "???" is the decoder's marker for an instruction it could not decode.
    self.assertFalse("???" in decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock via SOCK_DESTROY while call(sock) is blocked on it."""
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed in tearDown.
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in list(self.socketpairs.values()):
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()


class SockDiagTest(SockDiagBaseTest):

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    # Register the pair so tearDown closes it even if an assertion fails;
    # other tests require that no stray established sockets remain.
    self.socketpairs["mapped"] = socketpair
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockInfo(diag_req)
      # No errors? Good.
212 213 def CheckFindsAllMySockets(self, socktype, proto): 214 """Tests that basic socket dumping works.""" 215 self.socketpairs = self._CreateLotsOfSockets(socktype) 216 sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE) 217 self.assertGreaterEqual(len(sockets), NUM_SOCKETS) 218 219 # Find the cookies for all of our sockets. 220 cookies = {} 221 for diag_msg, unused_attrs in sockets: 222 addr = self.sock_diag.GetSourceAddress(diag_msg) 223 sport = diag_msg.id.sport 224 dport = diag_msg.id.dport 225 if (addr, sport, dport) in self.socketpairs: 226 cookies[(addr, sport, dport)] = diag_msg.id.cookie 227 elif (addr, dport, sport) in self.socketpairs: 228 cookies[(addr, sport, dport)] = diag_msg.id.cookie 229 230 # Did we find all the cookies? 231 self.assertEqual(2 * NUM_SOCKETS, len(cookies)) 232 233 socketpairs = list(self.socketpairs.values()) 234 random.shuffle(socketpairs) 235 for socketpair in socketpairs: 236 for sock in socketpair: 237 # Check that we can find a diag_msg by scanning a dump. 238 self.assertSockInfoMatchesSocket( 239 sock, 240 self.sock_diag.FindSockInfoFromFd(sock)) 241 cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie 242 243 # Check that we can find a diag_msg once we know the cookie. 244 req = self.sock_diag.DiagReqFromSocket(sock) 245 req.id.cookie = cookie 246 if proto == IPPROTO_UDP: 247 # Kernel bug: for UDP sockets, the order of arguments must be swapped. 248 # See testDemonstrateUdpGetSockIdBug. 
249 req.id.sport, req.id.dport = req.id.dport, req.id.sport 250 req.id.src, req.id.dst = req.id.dst, req.id.src 251 info = self.sock_diag.GetSockInfo(req) 252 self.assertSockInfoMatchesSocket(sock, info) 253 254 def testFindsAllMySocketsTcp(self): 255 self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP) 256 257 @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled") 258 def testFindsAllMySocketsUdp(self): 259 self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP) 260 261 def testBytecodeCompilation(self): 262 # pylint: disable=bad-whitespace 263 instructions = [ 264 (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0 265 (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8 266 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16 267 (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44 268 (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48 269 (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64 270 (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72 271 # 76 acc 272 # 80 rej 273 ] 274 # pylint: enable=bad-whitespace 275 bytecode = self.PackAndCheckBytecode(instructions) 276 expected = ( 277 "0208500000000000" 278 "050848000000ffff" 279 "071c20000a800000ffffffff00000000000000000000000000000001" 280 "01041c00" 281 "0718200002200000ffffffff7f000001" 282 "0508100000006566" 283 "00040400" 284 ) 285 states = 1 << tcp_test.TCP_ESTABLISHED 286 self.assertMultiLineEqual(expected, bytecode.encode("hex")) 287 self.assertEqual(76, len(bytecode)) 288 self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) 289 filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode, 290 states=states) 291 allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, 292 states=states) 293 self.assertItemsEqual(allsockets, filteredsockets) 294 295 # Pick a few sockets in hash table order, and check that the bytecode we 296 # compiled selects them properly. 
297 for socketpair in list(self.socketpairs.values())[:20]: 298 for s in socketpair: 299 diag_msg = self.sock_diag.FindSockDiagFromFd(s) 300 instructions = [ 301 (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport), 302 (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport), 303 (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport), 304 (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), 305 ] 306 bytecode = self.PackAndCheckBytecode(instructions) 307 self.assertEqual(32, len(bytecode)) 308 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) 309 self.assertEqual(1, len(sockets)) 310 311 # TODO: why doesn't comparing the cstructs work? 312 self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack()) 313 314 def testCrossFamilyBytecode(self): 315 """Checks for a cross-family bug in inet_diag_hostcond matching. 316 317 Relevant kernel commits: 318 android-3.4: 319 f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run() 320 """ 321 # TODO: this is only here because the test fails if there are any open 322 # sockets other than the ones it creates itself. Make the bytecode more 323 # specific and remove it. 324 states = 1 << tcp_test.TCP_ESTABLISHED 325 self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "", 326 states=states)) 327 328 unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1") 329 unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1") 330 331 bytecode4 = self.PackAndCheckBytecode([ 332 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))]) 333 bytecode6 = self.PackAndCheckBytecode([ 334 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))]) 335 336 # IPv4/v6 filters must never match IPv6/IPv4 sockets... 
337 v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4, 338 states=states) 339 self.assertTrue(v4socks) 340 self.assertTrue(all(d.family == AF_INET for d, _ in v4socks)) 341 342 v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6, 343 states=states) 344 self.assertTrue(v6socks) 345 self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks)) 346 347 # Except for mapped addresses, which match both IPv4 and IPv6. 348 pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, 349 "::ffff:127.0.0.1") 350 diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5] 351 v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, 352 bytecode4, 353 states=states)] 354 v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, 355 bytecode6, 356 states=states)] 357 self.assertTrue(all(d in v4socks for d in diag_msgs)) 358 self.assertTrue(all(d in v6socks for d in diag_msgs)) 359 360 def testPortComparisonValidation(self): 361 """Checks for a bug in validating port comparison bytecode. 362 363 Relevant kernel commits: 364 android-3.4: 365 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads 366 """ 367 bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8)) 368 self.assertEqual("???", 369 self.sock_diag.DecodeBytecode(bytecode)) 370 self.assertRaisesErrno( 371 EINVAL, 372 self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack()) 373 374 def testNonSockDiagCommand(self): 375 def DiagDump(code): 376 sock_id = self.sock_diag._EmptyInetDiagSockId() 377 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, 378 sock_id)) 379 self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "") 380 381 op = sock_diag.SOCK_DIAG_BY_FAMILY 382 DiagDump(op) # No errors? Good. 
383 self.assertRaisesErrno(EINVAL, DiagDump, op + 17) 384 385 def CheckSocketCookie(self, inet, addr): 386 """Tests that getsockopt SO_COOKIE can get cookie for all sockets.""" 387 socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr) 388 for sock in socketpair: 389 diag_msg = self.sock_diag.FindSockDiagFromFd(sock) 390 cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) 391 self.assertEqual(diag_msg.id.cookie, cookie) 392 393 @unittest.skipUnless(LINUX_4_9_OR_ABOVE, "SO_COOKIE not supported") 394 def testGetsockoptcookie(self): 395 self.CheckSocketCookie(AF_INET, "127.0.0.1") 396 self.CheckSocketCookie(AF_INET6, "::1") 397 398 @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled") 399 def testDemonstrateUdpGetSockIdBug(self): 400 # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup 401 # by passing the source address as the source address argument. 402 # Unfortunately those functions are intended to match local sockets based 403 # on received packets, and the argument that ends up being compared with 404 # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not 405 # have this bug. Upstream has confirmed that this will not be fixed: 406 # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html 407 """Documents a bug: getting UDP sockets requires swapping src and dst.""" 408 for version in [4, 5, 6]: 409 family = net_test.GetAddressFamily(version) 410 s = socket(family, SOCK_DGRAM, 0) 411 self.SelectInterface(s, self.RandomNetid(), "mark") 412 s.connect((self.GetRemoteSocketAddress(version), 53)) 413 414 # Create a fully-specified diag req from our socket, including cookie if 415 # we can get it. 416 req = self.sock_diag.DiagReqFromSocket(s) 417 if LINUX_4_9_OR_ABOVE: 418 req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) 419 else: 420 req.id.cookie = "\xff" * 16 # INET_DIAG_NOCOOKIE[2] 421 422 # As is, this request does not find anything. 
423 with self.assertRaisesErrno(ENOENT): 424 self.sock_diag.GetSockInfo(req) 425 426 # But if we swap src and dst, the kernel finds our socket. 427 req.id.sport, req.id.dport = req.id.dport, req.id.sport 428 req.id.src, req.id.dst = req.id.dst, req.id.src 429 430 self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req)) 431 432 433class SockDestroyTest(SockDiagBaseTest): 434 """Tests that SOCK_DESTROY works correctly. 435 436 Relevant kernel commits: 437 net-next: 438 b613f56 net: diag: split inet_diag_dump_one_icsk into two 439 64be0ae net: diag: Add the ability to destroy a socket. 440 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets. 441 c1e64e2 net: diag: Support destroying TCP sockets. 442 2010b93 net: tcp: deal with listen sockets properly in tcp_abort. 443 444 android-3.4: 445 d48ec88 net: diag: split inet_diag_dump_one_icsk into two 446 2438189 net: diag: Add the ability to destroy a socket. 447 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets. 448 44047b2 net: diag: Support destroying TCP sockets. 449 200dae7 net: tcp: deal with listen sockets properly in tcp_abort. 450 451 android-3.10: 452 9eaff90 net: diag: split inet_diag_dump_one_icsk into two 453 d60326c net: diag: Add the ability to destroy a socket. 454 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets. 455 529dfc6 net: diag: Support destroying TCP sockets. 456 9c712fe net: tcp: deal with listen sockets properly in tcp_abort. 457 458 android-3.18: 459 100263d net: diag: split inet_diag_dump_one_icsk into two 460 194c5f3 net: diag: Add the ability to destroy a socket. 461 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets. 462 b80585a net: diag: Support destroying TCP sockets. 463 476c6ce net: tcp: deal with listen sockets properly in tcp_abort. 464 465 android-4.1: 466 56eebf8 net: diag: split inet_diag_dump_one_icsk into two 467 fb486c9 net: diag: Add the ability to destroy a socket. 468 0c02b7e net: diag: Support SOCK_DESTROY for inet sockets. 
469 67c71d8 net: diag: Support destroying TCP sockets. 470 a76e0ec net: tcp: deal with listen sockets properly in tcp_abort. 471 e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 472 473 android-4.4: 474 76c83a9 net: diag: split inet_diag_dump_one_icsk into two 475 f7cf791 net: diag: Add the ability to destroy a socket. 476 1c42248 net: diag: Support SOCK_DESTROY for inet sockets. 477 c9e8440d net: diag: Support destroying TCP sockets. 478 3d9502c tcp: diag: add support for request sockets to tcp_abort() 479 001cf75 net: tcp: deal with listen sockets properly in tcp_abort. 480 """ 481 482 def testClosesSockets(self): 483 self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) 484 for _, socketpair in self.socketpairs.items(): 485 # Close one of the sockets. 486 # This will send a RST that will close the other side as well. 487 s = random.choice(socketpair) 488 if random.randrange(0, 2) == 1: 489 self.sock_diag.CloseSocketFromFd(s) 490 else: 491 diag_msg = self.sock_diag.FindSockDiagFromFd(s) 492 493 # Get the cookie wrong and ensure that we get an error and the socket 494 # is not closed. 495 real_cookie = diag_msg.id.cookie 496 diag_msg.id.cookie = os.urandom(len(real_cookie)) 497 req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 498 self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) 499 self.assertSocketConnected(s) 500 501 # Now close it with the correct cookie. 502 req.id.cookie = real_cookie 503 self.sock_diag.CloseSocket(req) 504 505 # Check that both sockets in the pair are closed. 506 self.assertSocketsClosed(socketpair) 507 508 # TODO: 509 # Test that killing unix sockets returns EOPNOTSUPP. 
510 511 512class SocketExceptionThread(threading.Thread): 513 514 def __init__(self, sock, operation): 515 self.exception = None 516 super(SocketExceptionThread, self).__init__() 517 self.daemon = True 518 self.sock = sock 519 self.operation = operation 520 521 def run(self): 522 try: 523 self.operation(self.sock) 524 except (IOError, AssertionError) as e: 525 self.exception = e 526 527 528class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 529 530 def testIpv4MappedSynRecvSocket(self): 531 """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets. 532 533 Relevant kernel commits: 534 android-3.4: 535 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state 536 """ 537 netid = random.choice(list(self.tuns.keys())) 538 self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) 539 sock_id = self.sock_diag._EmptyInetDiagSockId() 540 sock_id.sport = self.port 541 states = 1 << tcp_test.TCP_SYN_RECV 542 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) 543 children = self.sock_diag.Dump(req, NO_BYTECODE) 544 545 self.assertTrue(children) 546 for child, unused_args in children: 547 self.assertEqual(tcp_test.TCP_SYN_RECV, child.state) 548 self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr), 549 child.id.dst) 550 self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr), 551 child.id.src) 552 553 554class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 555 556 RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000 557 TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd" 558 559 def setUp(self): 560 super(TcpRcvWindowTest, self).setUp() 561 if LINUX_4_19_OR_ABOVE: 562 self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w") 563 return 564 565 f = open(self.TCP_DEFAULT_INIT_RWND, "w") 566 f.write("60") 567 568 def checkInitRwndSize(self, version, netid): 569 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid) 570 tcpInfo = 
TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP, 571 net_test.TCP_INFO, len(TcpInfo))) 572 self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh, 573 "Tcp rwnd of netid=%d, version=%d is not enough. " 574 "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE, 575 tcpInfo.tcpi_rcv_ssthresh)) 576 577 def checkSynPacketWindowSize(self, version, netid): 578 s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark") 579 myaddr = self.MyAddress(version, netid) 580 dstaddr = self.GetRemoteAddress(version) 581 dstsockaddr = self.GetRemoteSocketAddress(version) 582 desc, expected = packets.SYN(53, version, myaddr, dstaddr, 583 sport=None, seq=None) 584 self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53)) 585 msg = "IPv%s TCP connect: expected %s on %s" % ( 586 version, desc, self.GetInterfaceName(netid)) 587 syn = self.ExpectPacketOn(netid, msg, expected) 588 self.assertLess(self.RWND_SIZE, syn.window) 589 s.close() 590 591 def testTcpCwndSize(self): 592 for version in [4, 5, 6]: 593 for netid in self.NETIDS: 594 self.checkInitRwndSize(version, netid) 595 self.checkSynPacketWindowSize(version, netid) 596 597 598class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 599 600 def setUp(self): 601 super(SockDestroyTcpTest, self).setUp() 602 self.netid = random.choice(list(self.tuns.keys())) 603 604 def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): 605 """Closes the socket and checks whether a RST is sent or not.""" 606 if sock is not None: 607 self.assertIsNone(req, "Must specify sock or req, not both") 608 self.sock_diag.CloseSocketFromFd(sock) 609 self.assertRaisesErrno(EINVAL, sock.accept) 610 else: 611 self.assertIsNone(sock, "Must specify sock or req, not both") 612 self.sock_diag.CloseSocket(req) 613 614 if expect_reset: 615 desc, rst = self.RstPacket() 616 msg = "%s: expecting %s: " % (msg, desc) 617 self.ExpectPacketOn(self.netid, msg, rst) 618 else: 619 msg = "%s: " % msg 620 
self.ExpectNoPacketsOn(self.netid, msg) 621 622 if sock is not None and do_close: 623 sock.close() 624 625 def CheckTcpReset(self, state, statename): 626 for version in [4, 5, 6]: 627 msg = "Closing incoming IPv%d %s socket" % (version, statename) 628 self.IncomingConnection(version, state, self.netid) 629 self.CheckRstOnClose(self.s, None, False, msg) 630 if state != tcp_test.TCP_LISTEN: 631 msg = "Closing accepted IPv%d %s socket" % (version, statename) 632 self.CheckRstOnClose(self.accepted, None, True, msg) 633 634 def testTcpResets(self): 635 """Checks that closing sockets in appropriate states sends a RST.""" 636 self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN") 637 self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED") 638 self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") 639 640 def testFinWait1Socket(self): 641 for version in [4, 5, 6]: 642 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 643 644 # Get the cookie so we can find this socket after we close it. 645 diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted) 646 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 647 648 # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN. 649 net_test.EnableFinWait(self.accepted) 650 self.accepted.close() 651 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1 652 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 653 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 654 desc, fin = self.FinPacket() 655 self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin) 656 657 # Destroy the socket and expect no RST. 658 self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket") 659 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 660 661 # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing 662 # because userspace had already closed it. 
663 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 664 665 # ACK the FIN so we don't trip over retransmits in future tests. 666 finversion = 4 if version == 5 else version 667 desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin) 668 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 669 self.ReceivePacketOn(self.netid, finack) 670 671 # See if we can find the resulting FIN_WAIT2 socket. This does not appear 672 # to work on 3.10. 673 if net_test.LINUX_VERSION >= (3, 18): 674 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2 675 infos = self.sock_diag.Dump(diag_req, "") 676 self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2 677 for diag_msg, attrs in infos), 678 "Expected to find FIN_WAIT2 socket in %s" % infos) 679 680 def FindChildSockets(self, s): 681 """Finds the SYN_RECV child sockets of a given listening socket.""" 682 d = self.sock_diag.FindSockDiagFromFd(self.s) 683 req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 684 req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED 685 req.id.cookie = "\x00" * 8 686 687 bad_bytecode = self.PackAndCheckBytecode( 688 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))]) 689 self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode)) 690 691 bytecode = self.PackAndCheckBytecode( 692 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))]) 693 children = self.sock_diag.Dump(req, bytecode) 694 return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 695 for d, _ in children] 696 697 def CheckChildSocket(self, version, statename, parent_first): 698 state = getattr(tcp_test, statename) 699 700 self.IncomingConnection(version, state, self.netid) 701 702 d = self.sock_diag.FindSockDiagFromFd(self.s) 703 parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 704 children = self.FindChildSockets(self.s) 705 self.assertEqual(1, len(children)) 706 707 is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) 708 expected_state = 
tcp_test.TCP_ESTABLISHED if is_established else state 709 710 # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the 711 # regular TCP hash tables, and inet_diag_find_one_icsk can find them. 712 # Before 4.4, we can see those sockets in dumps, but we can't fetch 713 # or close them. 714 can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4) 715 716 for child in children: 717 if can_close_children: 718 diag_msg, attrs = self.sock_diag.GetSockInfo(child) 719 self.assertEqual(diag_msg.state, expected_state) 720 self.assertMarkIs(self.netid, attrs) 721 else: 722 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 723 724 def CloseParent(expect_reset): 725 msg = "Closing parent IPv%d %s socket %s child" % ( 726 version, statename, "before" if parent_first else "after") 727 self.CheckRstOnClose(self.s, None, expect_reset, msg) 728 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent) 729 730 def CheckChildrenClosed(): 731 for child in children: 732 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 733 734 def CloseChildren(): 735 for child in children: 736 msg = "Closing child IPv%d %s socket %s parent" % ( 737 version, statename, "after" if parent_first else "before") 738 self.sock_diag.GetSockInfo(child) 739 self.CheckRstOnClose(None, child, is_established, msg) 740 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 741 CheckChildrenClosed() 742 743 if parent_first: 744 # Closing the parent will close child sockets, which will send a RST, 745 # iff they are already established. 
746 CloseParent(is_established) 747 if is_established: 748 CheckChildrenClosed() 749 elif can_close_children: 750 CloseChildren() 751 CheckChildrenClosed() 752 self.s.close() 753 else: 754 if can_close_children: 755 CloseChildren() 756 CloseParent(False) 757 self.s.close() 758 759 def testChildSockets(self): 760 for version in [4, 5, 6]: 761 self.CheckChildSocket(version, "TCP_SYN_RECV", False) 762 self.CheckChildSocket(version, "TCP_SYN_RECV", True) 763 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False) 764 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True) 765 766 def testAcceptInterrupted(self): 767 """Tests that accept() is interrupted by SOCK_DESTROY.""" 768 for version in [4, 5, 6]: 769 self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid) 770 self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096) 771 self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) 772 self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo") 773 self.assertRaisesErrno(EINVAL, self.s.accept) 774 # TODO: this should really return an error such as ENOTCONN... 775 self.assertEqual("", self.s.recv(4096)) 776 777 def testReadInterrupted(self): 778 """Tests that read() is interrupted by SOCK_DESTROY.""" 779 for version in [4, 5, 6]: 780 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 781 self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), 782 ECONNABORTED) 783 # Writing returns EPIPE, and reading returns EOF. 
784 self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") 785 self.assertEqual("", self.accepted.recv(4096)) 786 self.assertEqual("", self.accepted.recv(4096)) 787 788 def testConnectInterrupted(self): 789 """Tests that connect() is interrupted by SOCK_DESTROY.""" 790 for version in [4, 5, 6]: 791 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 792 s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) 793 self.SelectInterface(s, self.netid, "mark") 794 795 remotesockaddr = self.GetRemoteSocketAddress(version) 796 remoteaddr = self.GetRemoteAddress(version) 797 s.bind(("", 0)) 798 _, sport = s.getsockname()[:2] 799 self.CloseDuringBlockingCall( 800 s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED) 801 desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), 802 remoteaddr, sport=sport, seq=None) 803 self.ExpectPacketOn(self.netid, desc, syn) 804 msg = "SOCK_DESTROY of socket in connect, expected no RST" 805 self.ExpectNoPacketsOn(self.netid, msg) 806 807 808class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 809 """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs. 810 811 The behaviour of poll() in these cases is not what we might expect: if only 812 POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT 813 is (also) specified, it will only return POLLOUT. 
814 """ 815 816 POLLIN_OUT = select.POLLIN | select.POLLOUT 817 POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP 818 819 def setUp(self): 820 super(PollOnCloseTest, self).setUp() 821 self.netid = random.choice(list(self.tuns.keys())) 822 823 POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"), 824 (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")] 825 826 def PollResultToString(self, poll_events, ignoremask): 827 out = [] 828 for fd, event in poll_events: 829 flags = [name for (flag, name) in self.POLL_FLAGS 830 if event & flag & ~ignoremask != 0] 831 out.append((fd, "|".join(flags))) 832 return out 833 834 def BlockingPoll(self, sock, mask, expected, ignoremask): 835 p = select.poll() 836 p.register(sock, mask) 837 expected_fds = [(sock.fileno(), expected)] 838 # Don't block forever or we'll hang continuous test runs on failure. 839 # A 5-second timeout should be long enough not to be flaky. 840 actual_fds = p.poll(5000) 841 self.assertEqual(self.PollResultToString(expected_fds, ignoremask), 842 self.PollResultToString(actual_fds, ignoremask)) 843 844 def RstDuringBlockingCall(self, sock, call, expected_errno): 845 self._EventDuringBlockingCall( 846 sock, call, expected_errno, 847 lambda _: self.ReceiveRstPacketOn(self.netid)) 848 849 def assertSocketErrors(self, errno): 850 # The first operation returns the expected errno. 851 self.assertRaisesErrno(errno, self.accepted.recv, 4096) 852 853 # Subsequent operations behave as normal. 
854 self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") 855 self.assertEqual("", self.accepted.recv(4096)) 856 self.assertEqual("", self.accepted.recv(4096)) 857 858 def CheckPollDestroy(self, mask, expected, ignoremask): 859 """Interrupts a poll() with SOCK_DESTROY.""" 860 for version in [4, 5, 6]: 861 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 862 self.CloseDuringBlockingCall( 863 self.accepted, 864 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 865 None) 866 self.assertSocketErrors(ECONNABORTED) 867 868 def CheckPollRst(self, mask, expected, ignoremask): 869 """Interrupts a poll() by receiving a TCP RST.""" 870 for version in [4, 5, 6]: 871 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 872 self.RstDuringBlockingCall( 873 self.accepted, 874 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 875 None) 876 self.assertSocketErrors(ECONNRESET) 877 878 def testReadPollRst(self): 879 # Until 3d4762639d ("tcp: remove poll() flakes when receiving RST"), poll() 880 # would sometimes return POLLERR and sometimes POLLIN|POLLERR|POLLHUP. This 881 # is due to a race inside the kernel and thus is not visible on the VM, only 882 # on physical hardware. 883 if net_test.LINUX_VERSION < (4, 14, 0): 884 ignoremask = select.POLLIN | select.POLLHUP 885 else: 886 ignoremask = 0 887 self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) 888 889 def testWritePollRst(self): 890 self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0) 891 892 def testReadWritePollRst(self): 893 self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0) 894 895 def testReadPollDestroy(self): 896 # tcp_abort has the same race that tcp_reset has, but it's not fixed yet. 
897 ignoremask = select.POLLIN | select.POLLHUP 898 self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) 899 900 def testWritePollDestroy(self): 901 self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0) 902 903 def testReadWritePollDestroy(self): 904 self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0) 905 906 907@unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled") 908class SockDestroyUdpTest(SockDiagBaseTest): 909 910 """Tests SOCK_DESTROY on UDP sockets. 911 912 Relevant kernel commits: 913 upstream net-next: 914 5d77dca net: diag: support SOCK_DESTROY for UDP sockets 915 f95bf34 net: diag: make udp_diag_destroy work for mapped addresses. 916 """ 917 918 def testClosesUdpSockets(self): 919 self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM) 920 for _, socketpair in self.socketpairs.items(): 921 s1, s2 = socketpair 922 923 self.assertSocketConnected(s1) 924 self.sock_diag.CloseSocketFromFd(s1) 925 self.assertSocketClosed(s1) 926 927 self.assertSocketConnected(s2) 928 self.sock_diag.CloseSocketFromFd(s2) 929 self.assertSocketClosed(s2) 930 931 def BindToRandomPort(self, s, addr): 932 ATTEMPTS = 20 933 for i in range(20): 934 port = random.randrange(1024, 65535) 935 try: 936 s.bind((addr, port)) 937 return port 938 except error as e: 939 if e.errno != EADDRINUSE: 940 raise e 941 raise ValueError("Could not find a free port on %s after %d attempts" % 942 (addr, ATTEMPTS)) 943 944 def testSocketAddressesAfterClose(self): 945 for version in 4, 5, 6: 946 netid = random.choice(self.NETIDS) 947 dst = self.GetRemoteSocketAddress(version) 948 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 949 unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version] 950 951 # Closing a socket that was not explicitly bound (i.e., bound via 952 # connect(), not bind()) clears the source address and port. 
953 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 954 self.SelectInterface(s, netid, "mark") 955 s.connect((dst, 53)) 956 self.sock_diag.CloseSocketFromFd(s) 957 self.assertEqual((unspec, 0), s.getsockname()[:2]) 958 959 # Closing a socket bound to an IP address leaves the address as is. 960 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 961 src = self.MySocketAddress(version, netid) 962 s.bind((src, 0)) 963 s.connect((dst, 53)) 964 port = s.getsockname()[1] 965 self.sock_diag.CloseSocketFromFd(s) 966 self.assertEqual((src, 0), s.getsockname()[:2]) 967 968 # Closing a socket bound to a port leaves the port as is. 969 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 970 port = self.BindToRandomPort(s, "") 971 s.connect((dst, 53)) 972 self.sock_diag.CloseSocketFromFd(s) 973 self.assertEqual((unspec, port), s.getsockname()[:2]) 974 975 # Closing a socket bound to IP address and port leaves both as is. 976 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 977 src = self.MySocketAddress(version, netid) 978 port = self.BindToRandomPort(s, src) 979 self.sock_diag.CloseSocketFromFd(s) 980 self.assertEqual((src, port), s.getsockname()[:2]) 981 982 def testReadInterrupted(self): 983 """Tests that read() is interrupted by SOCK_DESTROY.""" 984 for version in [4, 5, 6]: 985 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 986 s = net_test.UDPSocket(family) 987 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 988 addr = self.GetRemoteSocketAddress(version) 989 990 # Check that reads on connected sockets are interrupted. 991 s.connect((addr, 53)) 992 self.assertEqual(3, s.send("foo")) 993 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 994 ECONNABORTED) 995 996 # A destroyed socket is no longer connected, but still usable. 
997 self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo") 998 self.assertEqual(3, s.sendto("foo", (addr, 53))) 999 1000 # Check that reads on unconnected sockets are also interrupted. 1001 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 1002 ECONNABORTED) 1003 1004class SockDestroyPermissionTest(SockDiagBaseTest): 1005 1006 def CheckPermissions(self, socktype): 1007 s = socket(AF_INET6, socktype, 0) 1008 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 1009 if socktype == SOCK_STREAM: 1010 s.listen(1) 1011 expectedstate = tcp_test.TCP_LISTEN 1012 else: 1013 s.connect((self.GetRemoteAddress(6), 53)) 1014 expectedstate = tcp_test.TCP_ESTABLISHED 1015 1016 with net_test.RunAsUid(12345): 1017 self.assertRaisesErrno( 1018 EPERM, self.sock_diag.CloseSocketFromFd, s) 1019 1020 self.sock_diag.CloseSocketFromFd(s) 1021 self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s) 1022 1023 1024 @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled") 1025 def testUdp(self): 1026 self.CheckPermissions(SOCK_DGRAM) 1027 1028 def testTcp(self): 1029 self.CheckPermissions(SOCK_STREAM) 1030 1031 1032class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 1033 1034 """Tests SOCK_DIAG bytecode filters that use marks. 1035 1036 Relevant kernel commits: 1037 upstream net-next: 1038 627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks. 1039 a52e95a net: diag: allow socket bytecode filters to match socket marks 1040 d545cac net: inet: diag: expose the socket mark to privileged processes. 
1041 """ 1042 1043 def FilterEstablishedSockets(self, mark, mask): 1044 instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))] 1045 bytecode = self.sock_diag.PackBytecode(instructions) 1046 return self.sock_diag.DumpAllInetSockets( 1047 IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED)) 1048 1049 def assertSamePorts(self, ports, diag_msgs): 1050 expected = sorted(ports) 1051 actual = sorted([msg[0].id.sport for msg in diag_msgs]) 1052 self.assertEqual(expected, actual) 1053 1054 def SockInfoMatchesSocket(self, s, info): 1055 try: 1056 self.assertSockInfoMatchesSocket(s, info) 1057 return True 1058 except AssertionError: 1059 return False 1060 1061 @staticmethod 1062 def SocketDescription(s): 1063 return "%s -> %s" % (str(s.getsockname()), str(s.getpeername())) 1064 1065 def assertFoundSockets(self, infos, sockets): 1066 matches = {} 1067 for s in sockets: 1068 match = None 1069 for info in infos: 1070 if self.SockInfoMatchesSocket(s, info): 1071 if match: 1072 self.fail("Socket %s matched both %s and %s" % 1073 (self.SocketDescription(s), match, info)) 1074 matches[s] = info 1075 self.assertTrue(s in matches, "Did not find socket %s in dump" % 1076 self.SocketDescription(s)) 1077 1078 for i in infos: 1079 if i not in list(matches.values()): 1080 self.fail("Too many sockets in dump, first unexpected: %s" % str(i)) 1081 1082 def testMarkBytecode(self): 1083 family, addr = random.choice([ 1084 (AF_INET, "127.0.0.1"), 1085 (AF_INET6, "::1"), 1086 (AF_INET6, "::ffff:127.0.0.1")]) 1087 s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr) 1088 s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234) 1089 s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235) 1090 1091 infos = self.FilterEstablishedSockets(0x1234, 0xffff) 1092 self.assertFoundSockets(infos, [s1]) 1093 1094 infos = self.FilterEstablishedSockets(0x1234, 0xfffe) 1095 self.assertFoundSockets(infos, [s1, s2]) 1096 1097 infos = self.FilterEstablishedSockets(0x1235, 0xffff) 
1098 self.assertFoundSockets(infos, [s2]) 1099 1100 infos = self.FilterEstablishedSockets(0x0, 0x0) 1101 self.assertFoundSockets(infos, [s1, s2]) 1102 1103 infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00) 1104 self.assertEqual(0, len(infos)) 1105 1106 with net_test.RunAsUid(12345): 1107 self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets, 1108 0xfff0000, 0xf0fed00) 1109 1110 @staticmethod 1111 def SetRandomMark(s): 1112 # Python doesn't like marks that don't fit into a signed int. 1113 mark = random.randrange(0, 2**31 - 1) 1114 s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark) 1115 return mark 1116 1117 def assertSocketMarkIs(self, s, mark): 1118 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1119 self.assertMarkIs(mark, attrs) 1120 with net_test.RunAsUid(12345): 1121 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1122 self.assertMarkIs(None, attrs) 1123 1124 def testMarkInAttributes(self): 1125 testcases = [(AF_INET, "127.0.0.1"), 1126 (AF_INET6, "::1"), 1127 (AF_INET6, "::ffff:127.0.0.1")] 1128 for family, addr in testcases: 1129 # TCP listen sockets. 1130 server = socket(family, SOCK_STREAM, 0) 1131 server.bind((addr, 0)) 1132 port = server.getsockname()[1] 1133 server.listen(1) # Or the socket won't be in the hashtables. 1134 server_mark = self.SetRandomMark(server) 1135 self.assertSocketMarkIs(server, server_mark) 1136 1137 # TCP client sockets. 1138 client = socket(family, SOCK_STREAM, 0) 1139 client_mark = self.SetRandomMark(client) 1140 client.connect((addr, port)) 1141 self.assertSocketMarkIs(client, client_mark) 1142 1143 # TCP server sockets. 1144 accepted, _ = server.accept() 1145 self.assertSocketMarkIs(accepted, server_mark) 1146 1147 accepted_mark = self.SetRandomMark(accepted) 1148 self.assertSocketMarkIs(accepted, accepted_mark) 1149 self.assertSocketMarkIs(server, server_mark) 1150 1151 server.close() 1152 client.close() 1153 1154 # Other TCP states are tested in SockDestroyTcpTest. 1155 1156 # UDP sockets. 
1157 if HAVE_UDP_DIAG: 1158 s = socket(family, SOCK_DGRAM, 0) 1159 mark = self.SetRandomMark(s) 1160 s.connect(("", 53)) 1161 self.assertSocketMarkIs(s, mark) 1162 s.close() 1163 1164 # Basic test for SCTP. sctp_diag was only added in 4.7. 1165 if HAVE_SCTP: 1166 s = socket(family, SOCK_STREAM, IPPROTO_SCTP) 1167 s.bind((addr, 0)) 1168 s.listen(1) 1169 mark = self.SetRandomMark(s) 1170 self.assertSocketMarkIs(s, mark) 1171 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE) 1172 self.assertEqual(1, len(sockets)) 1173 self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None)) 1174 s.close() 1175 1176 1177if __name__ == "__main__": 1178 unittest.main() 1179