#!/usr/bin/python3 # # Copyright 2015 The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import from errno import * # pylint: disable=wildcard-import import binascii import os import random import select from socket import * # pylint: disable=wildcard-import import struct import threading import time import unittest import cstruct import multinetwork_base import net_test import packets import sock_diag import tcp_test # Mostly empty structure definition containing only the fields we currently use. TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh") NUM_SOCKETS = 30 NO_BYTECODE = b"" LINUX_4_19_OR_ABOVE = net_test.LINUX_VERSION >= (4, 19, 0) IPPROTO_SCTP = 132 def HaveSctp(): try: s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP) s.close() return True except IOError: return False HAVE_SCTP = HaveSctp() class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): """Basic tests for SOCK_DIAG functionality. Relevant kernel commits: android-3.4: ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields 99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() android-3.10: 3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() android-3.18: e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() android-4.4: 525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() """ @staticmethod def _CreateLotsOfSockets(socktype): # Dict mapping (addr, sport, dport) tuples to socketpairs. socketpairs = {} for _ in range(NUM_SOCKETS): family, addr = random.choice([ (AF_INET, "127.0.0.1"), (AF_INET6, "::1"), (AF_INET6, "::ffff:127.0.0.1")]) socketpair = net_test.CreateSocketPair(family, socktype, addr) sport, dport = (socketpair[0].getsockname()[1], socketpair[1].getsockname()[1]) socketpairs[(addr, sport, dport)] = socketpair return socketpairs def assertSocketClosed(self, sock): self.assertRaisesErrno(ENOTCONN, sock.getpeername) def assertSocketConnected(self, sock): sock.getpeername() # No errors? Socket is alive and connected. def assertSocketsClosed(self, socketpair): for sock in socketpair: self.assertSocketClosed(sock) def assertMarkIs(self, mark, attrs): self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None)) def assertSockInfoMatchesSocket(self, s, info): diag_msg, attrs = info family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) self.assertEqual(diag_msg.family, family) src, sport = s.getsockname()[0:2] self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src)) self.assertEqual(diag_msg.id.sport, sport) if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]: dst, dport = s.getpeername()[0:2] self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst)) self.assertEqual(diag_msg.id.dport, dport) else: self.assertRaisesErrno(ENOTCONN, s.getpeername) mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK) self.assertMarkIs(mark, attrs) def PackAndCheckBytecode(self, instructions): bytecode = self.sock_diag.PackBytecode(instructions) decoded = self.sock_diag.DecodeBytecode(bytecode) self.assertEqual(len(instructions), len(decoded)) self.assertFalse("???" in decoded) return bytecode def _EventDuringBlockingCall(self, sock, call, expected_errno, event): """Simulates an external event during a blocking call on sock. Args: sock: The socket to use. call: A function, the call to make. Takes one parameter, sock. expected_errno: The value that call is expected to fail with, or None if call is expected to succeed. event: A function, the event that will happen during the blocking call. Takes one parameter, sock. """ thread = SocketExceptionThread(sock, call) thread.start() time.sleep(0.1) event(sock) thread.join(1) self.assertFalse(thread.is_alive()) if expected_errno is not None: self.assertIsNotNone(thread.exception) self.assertTrue(isinstance(thread.exception, IOError), "Expected IOError, got %s" % thread.exception) self.assertEqual(expected_errno, thread.exception.errno) else: self.assertIsNone(thread.exception) self.assertSocketClosed(sock) def CloseDuringBlockingCall(self, sock, call, expected_errno): self._EventDuringBlockingCall( sock, call, expected_errno, lambda sock: self.sock_diag.CloseSocketFromFd(sock)) def setUp(self): super(SockDiagBaseTest, self).setUp() self.sock_diag = sock_diag.SockDiag() self.socketpairs = {} def tearDown(self): for socketpair in list(self.socketpairs.values()): for s in socketpair: s.close() super(SockDiagBaseTest, self).tearDown() class SockDiagTest(SockDiagBaseTest): def testFindsMappedSockets(self): """Tests that inet_diag_find_one_icsk can find mapped sockets.""" socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::ffff:127.0.0.1") for sock in socketpair: diag_msg = self.sock_diag.FindSockDiagFromFd(sock) diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) self.sock_diag.GetSockInfo(diag_req) # No errors? Good. def CheckFindsAllMySockets(self, socktype, proto): """Tests that basic socket dumping works.""" self.socketpairs = self._CreateLotsOfSockets(socktype) sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE) self.assertGreaterEqual(len(sockets), NUM_SOCKETS) # Find the cookies for all of our sockets. cookies = {} for diag_msg, unused_attrs in sockets: addr = self.sock_diag.GetSourceAddress(diag_msg) sport = diag_msg.id.sport dport = diag_msg.id.dport if (addr, sport, dport) in self.socketpairs: cookies[(addr, sport, dport)] = diag_msg.id.cookie elif (addr, dport, sport) in self.socketpairs: cookies[(addr, sport, dport)] = diag_msg.id.cookie # Did we find all the cookies? self.assertEqual(2 * NUM_SOCKETS, len(cookies)) socketpairs = list(self.socketpairs.values()) random.shuffle(socketpairs) for socketpair in socketpairs: for sock in socketpair: # Check that we can find a diag_msg by scanning a dump. self.assertSockInfoMatchesSocket( sock, self.sock_diag.FindSockInfoFromFd(sock)) cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie # Check that we can find a diag_msg once we know the cookie. req = self.sock_diag.DiagReqFromSocket(sock) req.id.cookie = cookie if proto == IPPROTO_UDP: # Kernel bug: for UDP sockets, the order of arguments must be swapped. # See testDemonstrateUdpGetSockIdBug. req.id.sport, req.id.dport = req.id.dport, req.id.sport req.id.src, req.id.dst = req.id.dst, req.id.src info = self.sock_diag.GetSockInfo(req) self.assertSockInfoMatchesSocket(sock, info) def assertItemsEqual(self, expected, actual): try: super(SockDiagTest, self).assertItemsEqual(expected, actual) except AttributeError: # This was renamed in python3 but has the same behaviour. super(SockDiagTest, self).assertCountEqual(expected, actual) def testFindsAllMySocketsTcp(self): self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP) def testFindsAllMySocketsUdp(self): self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP) def testBytecodeCompilation(self): # pylint: disable=bad-whitespace instructions = [ (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0 (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16 (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44 (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48 (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64 (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72 # 76 acc # 80 rej ] # pylint: enable=bad-whitespace bytecode = self.PackAndCheckBytecode(instructions) expected = ( b"0208500000000000" b"050848000000ffff" b"071c20000a800000ffffffff00000000000000000000000000000001" b"01041c00" b"0718200002200000ffffffff7f000001" b"0508100000006566" b"00040400" ) states = 1 << tcp_test.TCP_ESTABLISHED self.assertEqual(expected, binascii.hexlify(bytecode)) self.assertEqual(76, len(bytecode)) self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode, states=states) allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, states=states) self.assertItemsEqual(allsockets, filteredsockets) # Pick a few sockets in hash table order, and check that the bytecode we # compiled selects them properly. for socketpair in list(self.socketpairs.values())[:20]: for s in socketpair: diag_msg = self.sock_diag.FindSockDiagFromFd(s) instructions = [ (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport), (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport), (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport), (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), ] bytecode = self.PackAndCheckBytecode(instructions) self.assertEqual(32, len(bytecode)) sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) self.assertEqual(1, len(sockets)) # TODO: why doesn't comparing the cstructs work? self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack()) def testCrossFamilyBytecode(self): """Checks for a cross-family bug in inet_diag_hostcond matching. Relevant kernel commits: android-3.4: f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run() """ # TODO: this is only here because the test fails if there are any open # sockets other than the ones it creates itself. Make the bytecode more # specific and remove it. states = 1 << tcp_test.TCP_ESTABLISHED self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, states=states)) unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1") unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1") bytecode4 = self.PackAndCheckBytecode([ (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))]) bytecode6 = self.PackAndCheckBytecode([ (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))]) # IPv4/v6 filters must never match IPv6/IPv4 sockets... v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4, states=states) self.assertTrue(v4socks) self.assertTrue(all(d.family == AF_INET for d, _ in v4socks)) v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6, states=states) self.assertTrue(v6socks) self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks)) # Except for mapped addresses, which match both IPv4 and IPv6. pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::ffff:127.0.0.1") diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5] v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4, states=states)] v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6, states=states)] self.assertTrue(all(d in v4socks for d in diag_msgs)) self.assertTrue(all(d in v6socks for d in diag_msgs)) def testPortComparisonValidation(self): """Checks for a bug in validating port comparison bytecode. Relevant kernel commits: android-3.4: 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads """ bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8)) self.assertEqual("???", self.sock_diag.DecodeBytecode(bytecode)) self.assertRaisesErrno( EINVAL, self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack()) def testNonSockDiagCommand(self): def DiagDump(code): sock_id = self.sock_diag._EmptyInetDiagSockId() req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, sock_id)) self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg) op = sock_diag.SOCK_DIAG_BY_FAMILY DiagDump(op) # No errors? Good. self.assertRaisesErrno(EINVAL, DiagDump, op + 17) def CheckSocketCookie(self, inet, addr): """Tests that getsockopt SO_COOKIE can get cookie for all sockets.""" socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr) for sock in socketpair: diag_msg = self.sock_diag.FindSockDiagFromFd(sock) cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) self.assertEqual(diag_msg.id.cookie, cookie) def testGetsockoptcookie(self): self.CheckSocketCookie(AF_INET, "127.0.0.1") self.CheckSocketCookie(AF_INET6, "::1") def testDemonstrateUdpGetSockIdBug(self): # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup # by passing the source address as the source address argument. # Unfortunately those functions are intended to match local sockets based # on received packets, and the argument that ends up being compared with # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not # have this bug. Upstream has confirmed that this will not be fixed: # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html """Documents a bug: getting UDP sockets requires swapping src and dst.""" for version in [4, 5, 6]: family = net_test.GetAddressFamily(version) s = socket(family, SOCK_DGRAM, 0) self.SelectInterface(s, self.RandomNetid(), "mark") s.connect((self.GetRemoteSocketAddress(version), 53)) # Create a fully-specified diag req from our socket, including cookie if # we can get it. req = self.sock_diag.DiagReqFromSocket(s) req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) # As is, this request does not find anything. with self.assertRaisesErrno(ENOENT): self.sock_diag.GetSockInfo(req) # But if we swap src and dst, the kernel finds our socket. req.id.sport, req.id.dport = req.id.dport, req.id.sport req.id.src, req.id.dst = req.id.dst, req.id.src self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req)) class SockDestroyTest(SockDiagBaseTest): """Tests that SOCK_DESTROY works correctly. Relevant kernel commits: net-next: b613f56 net: diag: split inet_diag_dump_one_icsk into two 64be0ae net: diag: Add the ability to destroy a socket. 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets. c1e64e2 net: diag: Support destroying TCP sockets. 2010b93 net: tcp: deal with listen sockets properly in tcp_abort. android-3.4: d48ec88 net: diag: split inet_diag_dump_one_icsk into two 2438189 net: diag: Add the ability to destroy a socket. 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets. 44047b2 net: diag: Support destroying TCP sockets. 200dae7 net: tcp: deal with listen sockets properly in tcp_abort. android-3.10: 9eaff90 net: diag: split inet_diag_dump_one_icsk into two d60326c net: diag: Add the ability to destroy a socket. 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets. 529dfc6 net: diag: Support destroying TCP sockets. 9c712fe net: tcp: deal with listen sockets properly in tcp_abort. android-3.18: 100263d net: diag: split inet_diag_dump_one_icsk into two 194c5f3 net: diag: Add the ability to destroy a socket. 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets. b80585a net: diag: Support destroying TCP sockets. 476c6ce net: tcp: deal with listen sockets properly in tcp_abort. android-4.1: 56eebf8 net: diag: split inet_diag_dump_one_icsk into two fb486c9 net: diag: Add the ability to destroy a socket. 0c02b7e net: diag: Support SOCK_DESTROY for inet sockets. 67c71d8 net: diag: Support destroying TCP sockets. a76e0ec net: tcp: deal with listen sockets properly in tcp_abort. e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk() android-4.4: 76c83a9 net: diag: split inet_diag_dump_one_icsk into two f7cf791 net: diag: Add the ability to destroy a socket. 1c42248 net: diag: Support SOCK_DESTROY for inet sockets. c9e8440d net: diag: Support destroying TCP sockets. 3d9502c tcp: diag: add support for request sockets to tcp_abort() 001cf75 net: tcp: deal with listen sockets properly in tcp_abort. """ def testClosesSockets(self): self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) for _, socketpair in self.socketpairs.items(): # Close one of the sockets. # This will send a RST that will close the other side as well. s = random.choice(socketpair) if random.randrange(0, 2) == 1: self.sock_diag.CloseSocketFromFd(s) else: diag_msg = self.sock_diag.FindSockDiagFromFd(s) # Get the cookie wrong and ensure that we get an error and the socket # is not closed. real_cookie = diag_msg.id.cookie diag_msg.id.cookie = os.urandom(len(real_cookie)) req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) self.assertSocketConnected(s) # Now close it with the correct cookie. req.id.cookie = real_cookie self.sock_diag.CloseSocket(req) # Check that both sockets in the pair are closed. self.assertSocketsClosed(socketpair) # TODO: # Test that killing unix sockets returns EOPNOTSUPP. class SocketExceptionThread(threading.Thread): def __init__(self, sock, operation): self.exception = None super(SocketExceptionThread, self).__init__() self.daemon = True self.sock = sock self.operation = operation def run(self): try: self.operation(self.sock) except (IOError, AssertionError) as e: self.exception = e class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def testIpv4MappedSynRecvSocket(self): """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets. Relevant kernel commits: android-3.4: 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state """ netid = random.choice(list(self.tuns.keys())) self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) sock_id = self.sock_diag._EmptyInetDiagSockId() sock_id.sport = self.port states = 1 << tcp_test.TCP_SYN_RECV req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) children = self.sock_diag.Dump(req, NO_BYTECODE) self.assertTrue(children) for child, unused_args in children: self.assertEqual(tcp_test.TCP_SYN_RECV, child.state) self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr), child.id.dst) self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr), child.id.src) class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest): RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000 TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd" def setUp(self): super(TcpRcvWindowTest, self).setUp() if LINUX_4_19_OR_ABOVE: self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w") return try: f = open(self.TCP_DEFAULT_INIT_RWND, "w") except IOError as e: # sysctl was namespace-ified on May 25, 2020 in android-4.14-stable [R] # just after 4.14.181 by: # https://android-review.googlesource.com/c/kernel/common/+/1312623 # ANDROID: namespace'ify tcp_default_init_rwnd implementation # But that commit might be missing in Q era kernels even when > 4.14.181 # when running T vts. if net_test.LINUX_VERSION >= (4, 15, 0): raise if e.errno != ENOENT: raise # we rely on the network namespace creation code # modifying the root netns sysctl before the namespace is even created return f.write("60") f.close() def checkInitRwndSize(self, version, netid): self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid) tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP, net_test.TCP_INFO, len(TcpInfo))) self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh, "Tcp rwnd of netid=%d, version=%d is not enough. " "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh)) def checkSynPacketWindowSize(self, version, netid): s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark") myaddr = self.MyAddress(version, netid) dstaddr = self.GetRemoteAddress(version) dstsockaddr = self.GetRemoteSocketAddress(version) desc, expected = packets.SYN(53, version, myaddr, dstaddr, sport=None, seq=None) self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53)) msg = "IPv%s TCP connect: expected %s on %s" % ( version, desc, self.GetInterfaceName(netid)) syn = self.ExpectPacketOn(netid, msg, expected) self.assertLess(self.RWND_SIZE, syn.window) s.close() def testTcpCwndSize(self): for version in [4, 5, 6]: for netid in self.NETIDS: self.checkInitRwndSize(version, netid) self.checkSynPacketWindowSize(version, netid) class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def setUp(self): super(SockDestroyTcpTest, self).setUp() self.netid = random.choice(list(self.tuns.keys())) def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): """Closes the socket and checks whether a RST is sent or not.""" if sock is not None: self.assertIsNone(req, "Must specify sock or req, not both") self.sock_diag.CloseSocketFromFd(sock) self.assertRaisesErrno(EINVAL, sock.accept) else: self.assertIsNone(sock, "Must specify sock or req, not both") self.sock_diag.CloseSocket(req) if expect_reset: desc, rst = self.RstPacket() msg = "%s: expecting %s: " % (msg, desc) self.ExpectPacketOn(self.netid, msg, rst) else: msg = "%s: " % msg self.ExpectNoPacketsOn(self.netid, msg) if sock is not None and do_close: sock.close() def CheckTcpReset(self, state, statename): for version in [4, 5, 6]: msg = "Closing incoming IPv%d %s socket" % (version, statename) self.IncomingConnection(version, state, self.netid) self.CheckRstOnClose(self.s, None, False, msg) if state != tcp_test.TCP_LISTEN: msg = "Closing accepted IPv%d %s socket" % (version, statename) self.CheckRstOnClose(self.accepted, None, True, msg) def testTcpResets(self): """Checks that closing sockets in appropriate states sends a RST.""" self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN") self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED") self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") def testFinWait1Socket(self): for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) # Get the cookie so we can find this socket after we close it. diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted) diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN. net_test.EnableFinWait(self.accepted) self.accepted.close() diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) desc, fin = self.FinPacket() self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin) # Destroy the socket and expect no RST. self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket") diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing # because userspace had already closed it. self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) # ACK the FIN so we don't trip over retransmits in future tests. finversion = 4 if version == 5 else version desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin) diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) self.ReceivePacketOn(self.netid, finack) # See if we can find the resulting FIN_WAIT2 socket. diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2 infos = self.sock_diag.Dump(diag_req, NO_BYTECODE) self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2 for diag_msg, attrs in infos), "Expected to find FIN_WAIT2 socket in %s" % infos) def FindChildSockets(self, s): """Finds the SYN_RECV child sockets of a given listening socket.""" d = self.sock_diag.FindSockDiagFromFd(self.s) req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED req.id.cookie = b"\x00" * 8 bad_bytecode = self.PackAndCheckBytecode( [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))]) self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode)) bytecode = self.PackAndCheckBytecode( [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))]) children = self.sock_diag.Dump(req, bytecode) return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) for d, _ in children] def CheckChildSocket(self, version, statename, parent_first): state = getattr(tcp_test, statename) self.IncomingConnection(version, state, self.netid) d = self.sock_diag.FindSockDiagFromFd(self.s) parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) children = self.FindChildSockets(self.s) self.assertEqual(1, len(children)) is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) expected_state = tcp_test.TCP_ESTABLISHED if is_established else state for child in children: diag_msg, attrs = self.sock_diag.GetSockInfo(child) self.assertEqual(diag_msg.state, expected_state) self.assertMarkIs(self.netid, attrs) def CloseParent(expect_reset): msg = "Closing parent IPv%d %s socket %s child" % ( version, statename, "before" if parent_first else "after") self.CheckRstOnClose(self.s, None, expect_reset, msg) self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent) def CheckChildrenClosed(): for child in children: self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) def CloseChildren(): for child in children: msg = "Closing child IPv%d %s socket %s parent" % ( version, statename, "after" if parent_first else "before") self.sock_diag.GetSockInfo(child) self.CheckRstOnClose(None, child, is_established, msg) self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) CheckChildrenClosed() if parent_first: # Closing the parent will close child sockets, which will send a RST, # iff they are already established. CloseParent(is_established) if is_established: CheckChildrenClosed() else: CloseChildren() CheckChildrenClosed() self.s.close() else: CloseChildren() CloseParent(False) self.s.close() def testChildSockets(self): for version in [4, 5, 6]: self.CheckChildSocket(version, "TCP_SYN_RECV", False) self.CheckChildSocket(version, "TCP_SYN_RECV", True) self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False) self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True) def testAcceptInterrupted(self): """Tests that accept() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid) self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096) self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo") self.assertRaisesErrno(EINVAL, self.s.accept) # TODO: this should really return an error such as ENOTCONN... self.assertEqual(b"", self.s.recv(4096)) def testReadInterrupted(self): """Tests that read() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), ECONNABORTED) # Writing returns EPIPE, and reading returns EOF. self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo") self.assertEqual(b"", self.accepted.recv(4096)) self.assertEqual(b"", self.accepted.recv(4096)) def testConnectInterrupted(self): """Tests that connect() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) self.SelectInterface(s, self.netid, "mark") remotesockaddr = self.GetRemoteSocketAddress(version) remoteaddr = self.GetRemoteAddress(version) s.bind(("", 0)) _, sport = s.getsockname()[:2] self.CloseDuringBlockingCall( s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED) desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), remoteaddr, sport=sport, seq=None) self.ExpectPacketOn(self.netid, desc, syn) msg = "SOCK_DESTROY of socket in connect, expected no RST" self.ExpectNoPacketsOn(self.netid, msg) class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs. The behaviour of poll() in these cases is not what we might expect: if only POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT is (also) specified, it will only return POLLOUT. """ POLLIN_OUT = select.POLLIN | select.POLLOUT POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP def setUp(self): super(PollOnCloseTest, self).setUp() self.netid = random.choice(list(self.tuns.keys())) POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"), (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")] def PollResultToString(self, poll_events, ignoremask): out = [] for fd, event in poll_events: flags = [name for (flag, name) in self.POLL_FLAGS if event & flag & ~ignoremask != 0] out.append((fd, "|".join(flags))) return out def BlockingPoll(self, sock, mask, expected, ignoremask): p = select.poll() p.register(sock, mask) expected_fds = [(sock.fileno(), expected)] # Don't block forever or we'll hang continuous test runs on failure. # A 5-second timeout should be long enough not to be flaky. actual_fds = p.poll(5000) self.assertEqual(self.PollResultToString(expected_fds, ignoremask), self.PollResultToString(actual_fds, ignoremask)) def RstDuringBlockingCall(self, sock, call, expected_errno): self._EventDuringBlockingCall( sock, call, expected_errno, lambda _: self.ReceiveRstPacketOn(self.netid)) def assertSocketErrors(self, errno): # The first operation returns the expected errno. self.assertRaisesErrno(errno, self.accepted.recv, 4096) # Subsequent operations behave as normal. self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo") self.assertEqual(b"", self.accepted.recv(4096)) self.assertEqual(b"", self.accepted.recv(4096)) def CheckPollDestroy(self, mask, expected, ignoremask): """Interrupts a poll() with SOCK_DESTROY.""" for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) self.CloseDuringBlockingCall( self.accepted, lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), None) self.assertSocketErrors(ECONNABORTED) def CheckPollRst(self, mask, expected, ignoremask): """Interrupts a poll() by receiving a TCP RST.""" for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) self.RstDuringBlockingCall( self.accepted, lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), None) self.assertSocketErrors(ECONNRESET) def testReadPollRst(self): self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, 0) def testWritePollRst(self): self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0) def testReadWritePollRst(self): self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0) def testReadPollDestroy(self): # tcp_abort has the same race that tcp_reset has, but it's not fixed yet. ignoremask = select.POLLIN | select.POLLHUP self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) def testWritePollDestroy(self): self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0) def testReadWritePollDestroy(self): self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0) class SockDestroyUdpTest(SockDiagBaseTest): """Tests SOCK_DESTROY on UDP sockets. Relevant kernel commits: upstream net-next: 5d77dca net: diag: support SOCK_DESTROY for UDP sockets f95bf34 net: diag: make udp_diag_destroy work for mapped addresses. """ def testClosesUdpSockets(self): self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM) for _, socketpair in self.socketpairs.items(): s1, s2 = socketpair self.assertSocketConnected(s1) self.sock_diag.CloseSocketFromFd(s1) self.assertSocketClosed(s1) self.assertSocketConnected(s2) self.sock_diag.CloseSocketFromFd(s2) self.assertSocketClosed(s2) def BindToRandomPort(self, s, addr): ATTEMPTS = 20 for i in range(20): port = random.randrange(1024, 65535) try: s.bind((addr, port)) return port except error as e: if e.errno != EADDRINUSE: raise e raise ValueError("Could not find a free port on %s after %d attempts" % (addr, ATTEMPTS)) def testSocketAddressesAfterClose(self): for version in 4, 5, 6: netid = random.choice(self.NETIDS) dst = self.GetRemoteSocketAddress(version) family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version] # Closing a socket that was not explicitly bound (i.e., bound via # connect(), not bind()) clears the source address and port. s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") self.SelectInterface(s, netid, "mark") s.connect((dst, 53)) self.sock_diag.CloseSocketFromFd(s) self.assertEqual((unspec, 0), s.getsockname()[:2]) # Closing a socket bound to an IP address leaves the address as is. s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") src = self.MySocketAddress(version, netid) s.bind((src, 0)) s.connect((dst, 53)) port = s.getsockname()[1] self.sock_diag.CloseSocketFromFd(s) self.assertEqual((src, 0), s.getsockname()[:2]) # Closing a socket bound to a port leaves the port as is. s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") port = self.BindToRandomPort(s, "") s.connect((dst, 53)) self.sock_diag.CloseSocketFromFd(s) self.assertEqual((unspec, port), s.getsockname()[:2]) # Closing a socket bound to IP address and port leaves both as is. s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") src = self.MySocketAddress(version, netid) port = self.BindToRandomPort(s, src) self.sock_diag.CloseSocketFromFd(s) self.assertEqual((src, port), s.getsockname()[:2]) def testReadInterrupted(self): """Tests that read() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] s = net_test.UDPSocket(family) self.SelectInterface(s, random.choice(self.NETIDS), "mark") addr = self.GetRemoteSocketAddress(version) # Check that reads on connected sockets are interrupted. s.connect((addr, 53)) self.assertEqual(3, s.send(b"foo")) self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), ECONNABORTED) # A destroyed socket is no longer connected, but still usable. self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo") self.assertEqual(3, s.sendto(b"foo", (addr, 53))) # Check that reads on unconnected sockets are also interrupted. self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), ECONNABORTED) class SockDestroyPermissionTest(SockDiagBaseTest): def CheckPermissions(self, socktype): s = socket(AF_INET6, socktype, 0) self.SelectInterface(s, random.choice(self.NETIDS), "mark") if socktype == SOCK_STREAM: s.listen(1) expectedstate = tcp_test.TCP_LISTEN else: s.connect((self.GetRemoteAddress(6), 53)) expectedstate = tcp_test.TCP_ESTABLISHED with net_test.RunAsUid(12345): self.assertRaisesErrno( EPERM, self.sock_diag.CloseSocketFromFd, s) self.sock_diag.CloseSocketFromFd(s) self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s) def testUdp(self): self.CheckPermissions(SOCK_DGRAM) def testTcp(self): self.CheckPermissions(SOCK_STREAM) class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): """Tests SOCK_DIAG bytecode filters that use marks. Relevant kernel commits: upstream net-next: 627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks. a52e95a net: diag: allow socket bytecode filters to match socket marks d545cac net: inet: diag: expose the socket mark to privileged processes. """ def FilterEstablishedSockets(self, mark, mask): instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))] bytecode = self.sock_diag.PackBytecode(instructions) return self.sock_diag.DumpAllInetSockets( IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED)) def assertSamePorts(self, ports, diag_msgs): expected = sorted(ports) actual = sorted([msg[0].id.sport for msg in diag_msgs]) self.assertEqual(expected, actual) def SockInfoMatchesSocket(self, s, info): try: self.assertSockInfoMatchesSocket(s, info) return True except AssertionError: return False @staticmethod def SocketDescription(s): return "%s -> %s" % (str(s.getsockname()), str(s.getpeername())) def assertFoundSockets(self, infos, sockets): matches = {} for s in sockets: match = None for info in infos: if self.SockInfoMatchesSocket(s, info): if match: self.fail("Socket %s matched both %s and %s" % (self.SocketDescription(s), match, info)) matches[s] = info self.assertTrue(s in matches, "Did not find socket %s in dump" % self.SocketDescription(s)) for i in infos: if i not in list(matches.values()): self.fail("Too many sockets in dump, first unexpected: %s" % str(i)) def testMarkBytecode(self): family, addr = random.choice([ (AF_INET, "127.0.0.1"), (AF_INET6, "::1"), (AF_INET6, "::ffff:127.0.0.1")]) s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr) s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234) s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235) infos = self.FilterEstablishedSockets(0x1234, 0xffff) self.assertFoundSockets(infos, [s1]) infos = self.FilterEstablishedSockets(0x1234, 0xfffe) self.assertFoundSockets(infos, [s1, s2]) infos = self.FilterEstablishedSockets(0x1235, 0xffff) self.assertFoundSockets(infos, [s2]) infos = self.FilterEstablishedSockets(0x0, 0x0) self.assertFoundSockets(infos, [s1, s2]) infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00) self.assertEqual(0, len(infos)) with net_test.RunAsUid(12345): self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets, 0xfff0000, 0xf0fed00) @staticmethod def SetRandomMark(s): # Python doesn't like marks that don't fit into a signed int. mark = random.randrange(0, 2**31 - 1) s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark) return mark def assertSocketMarkIs(self, s, mark): diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) self.assertMarkIs(mark, attrs) with net_test.RunAsUid(12345): diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) self.assertMarkIs(None, attrs) def testMarkInAttributes(self): testcases = [(AF_INET, "127.0.0.1"), (AF_INET6, "::1"), (AF_INET6, "::ffff:127.0.0.1")] for family, addr in testcases: # TCP listen sockets. server = socket(family, SOCK_STREAM, 0) server.bind((addr, 0)) port = server.getsockname()[1] server.listen(1) # Or the socket won't be in the hashtables. server_mark = self.SetRandomMark(server) self.assertSocketMarkIs(server, server_mark) # TCP client sockets. client = socket(family, SOCK_STREAM, 0) client_mark = self.SetRandomMark(client) client.connect((addr, port)) self.assertSocketMarkIs(client, client_mark) # TCP server sockets. accepted, _ = server.accept() self.assertSocketMarkIs(accepted, server_mark) accepted_mark = self.SetRandomMark(accepted) self.assertSocketMarkIs(accepted, accepted_mark) self.assertSocketMarkIs(server, server_mark) server.close() client.close() # Other TCP states are tested in SockDestroyTcpTest. # UDP sockets. s = socket(family, SOCK_DGRAM, 0) mark = self.SetRandomMark(s) s.connect(("", 53)) self.assertSocketMarkIs(s, mark) s.close() # Basic test for SCTP. sctp_diag was only added in 4.7. if HAVE_SCTP: s = socket(family, SOCK_STREAM, IPPROTO_SCTP) s.bind((addr, 0)) s.listen(1) mark = self.SetRandomMark(s) self.assertSocketMarkIs(s, mark) sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE) self.assertEqual(1, len(sockets)) self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None)) s.close() if __name__ == "__main__": unittest.main()