#!/usr/bin/python
#
# Copyright 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from errno import *  # pylint: disable=wildcard-import
from socket import *  # pylint: disable=wildcard-import
import ctypes
import fcntl
import os
import random
import select
import termios
import threading
import time
from scapy import all as scapy

import multinetwork_base
import net_test
import packets

# Aliases taken from net_test. They are assigned after the wildcard socket
# import, so these values win over any same-named constants from socket.
SOL_TCP = net_test.SOL_TCP
SHUT_RD, SHUT_WR, SHUT_RDWR = (
    net_test.SHUT_RD, net_test.SHUT_WR, net_test.SHUT_RDWR)

# ioctl requests reporting how many bytes sit in the receive / send queue.
SIOCINQ, SIOCOUTQ = termios.FIONREAD, termios.TIOCOUTQ

TEST_PORT = 5555

# The constants below are SOL_TCP level options and their arguments.
# They are defined in linux-kernel: include/uapi/linux/tcp.h

# SOL_TCP level option names.
TCP_REPAIR, TCP_REPAIR_QUEUE, TCP_QUEUE_SEQ = 19, 20, 21

# Arguments accepted by TCP_REPAIR.
TCP_REPAIR_OFF, TCP_REPAIR_ON = 0, 1

# Arguments accepted by TCP_REPAIR_QUEUE.
TCP_NO_QUEUE, TCP_RECV_QUEUE, TCP_SEND_QUEUE = 0, 1, 2

# This test aims to ensure that TCP keepalive offload works correctly when it
# fetches TCP information from the kernel via TCP repair mode.
class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests TCP_REPAIR socket options on a connection to a simulated peer.

  The remote peer is simulated: its packets are injected with ReceivePacketOn
  and the kernel's output is checked with ExpectPacketOn. The helpers below
  record two packets so that follow-up packets can be built from them:
    - self.last_sent: the last packet the socket sent.
    - self.last_received: the last packet injected from the peer.
  """

  @staticmethod
  def _Seq32(value):
    """Returns value reduced to an unsigned 32-bit TCP sequence number.

    getsockopt(TCP_QUEUE_SEQ) returns a C int, which may be negative, and TCP
    sequence arithmetic wraps at 2**32. Both sides of every sequence-number
    comparison are therefore normalized with this helper.
    """
    return value & 0xffffffff

  def assertSocketNotConnected(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock has a peer, i.e. getpeername() does not raise."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def createConnectedSocket(self, version, netid):
    """Returns a TCP socket connected to the remote test address on netid.

    The handshake is driven by hand:
      1. The kernel's SYN is expected on netid.
      2. A SYN+ACK with a random initial sequence number is injected.
      3. The kernel's final ACK is expected on netid.
    The ACK is recorded in self.last_sent and the SYN+ACK in
    self.last_received.
    """
    s = net_test.TCPSocket(net_test.GetAddressFamily(version))
    net_test.DisableFinWait(s)
    self.SelectInterface(s, netid, "mark")

    remotesockaddr = self.GetRemoteSocketAddress(version)
    remoteaddr = self.GetRemoteAddress(version)
    self.assertRaisesErrno(EINPROGRESS, s.connect, (remotesockaddr, TEST_PORT))
    self.assertSocketNotConnected(s)

    myaddr = self.MyAddress(version, netid)
    port = s.getsockname()[1]
    self.assertNotEqual(0, port)

    desc, expect_syn = packets.SYN(TEST_PORT, version, myaddr, remoteaddr, port,
                                   seq=None)
    msg = "socket connect: expected %s" % desc
    syn = self.ExpectPacketOn(netid, msg, expect_syn)
    _, synack = packets.SYNACK(version, remoteaddr, myaddr, syn)
    synack.getlayer("TCP").seq = random.getrandbits(32)
    synack.getlayer("TCP").window = 14400
    self.ReceivePacketOn(netid, synack)
    desc, ack = packets.ACK(version, myaddr, remoteaddr, synack)
    msg = "socket connect: got SYN+ACK, expected %s" % desc
    ack = self.ExpectPacketOn(netid, msg, ack)
    self.last_sent = ack
    self.last_received = synack
    return s

  def receiveFin(self, netid, version, sock):
    """Injects a FIN from the remote peer and records it as last received."""
    self.assertSocketConnected(sock)
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, fin = packets.FIN(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, fin)
    self.last_received = fin

  def sendData(self, netid, version, sock, payload):
    """Writes payload to sock and records the packet it corresponds to.

    The packet is only computed and stored in self.last_sent. It is not
    read back or verified here.
    """
    sock.send(payload)

    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, send = packets.ACK(version, myaddr, remoteaddr,
                          self.last_received, payload)
    self.last_sent = send

  def receiveData(self, netid, version, payload):
    """Injects payload from the remote peer and expects the kernel's ACK."""
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)

    _, received = packets.ACK(version, remoteaddr, myaddr,
                              self.last_sent, payload)
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, received)
    self.ReceivePacketOn(netid, received)
    # Presumably gives the kernel time to emit the ACK; TODO confirm needed.
    time.sleep(0.1)
    self.ExpectPacketOn(netid, "expecting %s" % ack_desc, ack)
    self.last_sent = ack
    self.last_received = received

  # Test the behavior of NO_QUEUE. Incoming data is expected to be stored in
  # the receive queue, but the socket can be neither read nor written while in
  # repair mode with NO_QUEUE.
  def testTcpRepairInNoQueue(self):
    for version in [4, 5, 6]:
      self.tcpRepairInNoQueueTest(version)

  def tcpRepairInNoQueueTest(self, version):
    """Checks send/recv errors in repair mode with NO_QUEUE for one version."""
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    # In repair mode with NO_QUEUE, writes fail. The payload is bytes so that
    # the kernel, not a Python str/bytes TypeError, produces the error.
    self.assertRaisesErrno(EINVAL, sock.send, b"write test")

    # Remote data arrives.
    TEST_RECEIVED = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_RECEIVED)

    # In repair mode with NO_QUEUE, reads fail.
    self.assertRaisesErrno(EPERM, sock.recv, 4096)

    # Once repair mode is off, the queued data can be read.
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
    readData = sock.recv(4096)
    self.assertEqual(readData, TEST_RECEIVED)
    sock.close()

  # Test whether the TCP read/write sequence numbers can be fetched correctly
  # with TCP_QUEUE_SEQ.
  def testGetSequenceNumber(self):
    for version in [4, 5, 6]:
      self.GetSequenceNumberTest(version)

  def GetSequenceNumberTest(self, version):
    """Checks TCP_QUEUE_SEQ for both queues before and after data transfer."""
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)

    # Test the write queue sequence number.
    sequence_before = self.GetWriteSequenceNumber(version, sock)
    expect_sequence = self.last_sent.getlayer("TCP").seq
    self.assertEqual(self._Seq32(sequence_before), self._Seq32(expect_sequence))
    TEST_SEND = net_test.UDP_PAYLOAD
    self.sendData(netid, version, sock, TEST_SEND)
    sequence_after = self.GetWriteSequenceNumber(version, sock)
    # Compare modulo 2**32 so a wrap of the sequence space is not a failure.
    self.assertEqual(self._Seq32(sequence_before + len(TEST_SEND)),
                     self._Seq32(sequence_after))

    # Test the read queue sequence number. The next expected byte follows the
    # SYN+ACK, whose random ISN may be 0xffffffff, so the +1 must wrap too.
    sequence_before = self.GetReadSequenceNumber(version, sock)
    expect_sequence = self.last_received.getlayer("TCP").seq + 1
    self.assertEqual(self._Seq32(sequence_before), self._Seq32(expect_sequence))
    TEST_READ = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_READ)
    sequence_after = self.GetReadSequenceNumber(version, sock)
    self.assertEqual(self._Seq32(sequence_before + len(TEST_READ)),
                     self._Seq32(sequence_after))
    sock.close()

  def _GetQueueSequenceNumber(self, sock, queue):
    """Returns getsockopt(TCP_QUEUE_SEQ) for the given repair queue.

    Repair mode is switched on only for the duration of the query. It is
    always switched back off, even if the query raises, so later socket
    operations are not affected.
    """
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    try:
      sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, queue)
      return sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ)
    finally:
      sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE)
      sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)

  def GetWriteSequenceNumber(self, version, sock):
    """Returns the send-queue sequence number of sock (version is unused)."""
    return self._GetQueueSequenceNumber(sock, TCP_SEND_QUEUE)

  def GetReadSequenceNumber(self, version, sock):
    """Returns the receive-queue sequence number of sock (version is unused)."""
    return self._GetQueueSequenceNumber(sock, TCP_RECV_QUEUE)

  # Test whether a TCP repair socket can be poll()'ed correctly from multiple
  # threads at the same time.
  def testMultiThreadedPoll(self):
    for version in [4, 5, 6]:
      self.PollWhenShutdownTest(version)
      self.PollWhenReceiveFinTest(version)

  def PollRepairSocketInMultipleThreads(self, netid, version, expected):
    """Connects a socket, enables repair mode and starts two poll threads.

    Each thread runs fdSelect on the socket and checks the returned event
    mask against expected. Returns (sock, threads).
    """
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    multiThreads = []
    for _ in range(2):
      # Poll the socket the thread hands us rather than closing over `sock`.
      thread = SocketExceptionThread(
          sock, lambda sk: self.fdSelect(sk, expected))
      thread.start()
      self.assertTrue(thread.is_alive())
      multiThreads.append(thread)

    return sock, multiThreads

  def assertThreadsStopped(self, multiThreads, msg):
    """Waits for the poll threads and turns any failure into AssertionError.

    A thread still alive after a one-second join is stopped and reported with
    msg. An exception a thread captured (e.g. a failed event-mask assertion in
    fdSelect) is re-raised with msg as context. Previously such exceptions were
    silently dropped.
    """
    for thread in multiThreads:
      if thread.is_alive():
        thread.join(1)
      if thread.is_alive():
        thread.stop()
        raise AssertionError(msg)
      if thread.exception is not None:
        raise AssertionError("%s: %r" % (msg, thread.exception))

  def PollWhenShutdownTest(self, version):
    """Checks poll() events after SHUT_RD, SHUT_WR and SHUT_RDWR."""
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    # Test shutdown RD.
    sock.shutdown(SHUT_RD)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RD")
    sock.close()

    # None: no event is expected, so any reported event fails the check.
    expected = None
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    # Test shutdown WR.
    sock.shutdown(SHUT_WR)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_WR")
    sock.close()

    expected = select.POLLIN | select.POLLHUP
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    # Test shutdown RDWR.
    sock.shutdown(SHUT_RDWR)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RDWR")
    sock.close()

  def PollWhenReceiveFinTest(self, version):
    """Checks that a FIN from the peer is reported by poll() as POLLIN."""
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    self.receiveFin(netid, version, sock)
    self.assertThreadsStopped(multiThreads, "poll fail during FIN")
    sock.close()

  # Test whether socket idleness can be detected with SIOCINQ and SIOCOUTQ.
  def testSocketIdle(self):
    for version in [4, 5, 6]:
      self.readQueueIdleTest(version)
      self.writeQueueIdleTest(version)

  def readQueueIdleTest(self, version):
    """Checks SIOCINQ is 0 when idle and equals the size of received data."""
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)

    buf = ctypes.c_int()
    fcntl.ioctl(sock, SIOCINQ, buf)
    self.assertEqual(buf.value, 0)

    TEST_RECV_PAYLOAD = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_RECV_PAYLOAD)
    fcntl.ioctl(sock, SIOCINQ, buf)
    self.assertEqual(buf.value, len(TEST_RECV_PAYLOAD))
    sock.close()

  def writeQueueIdleTest(self, version):
    """Checks SIOCOUTQ for repair-queued data and for unacknowledged data."""
    netid = self.RandomNetid()
    # Set up a connected socket; its write queue is empty.
    sock = self.createConnectedSocket(version, netid)
    buf = ctypes.c_int()
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, 0)
    # Switch to repair mode with SEND_QUEUE and write some data to the queue.
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    TEST_SEND_PAYLOAD = net_test.UDP_PAYLOAD
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)
    self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD))
    sock.close()

    # Set up a connected socket again.
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    # Send some data, but don't acknowledge it yet.
    self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD))
    # Inject the peer's ACK; the write queue should drain.
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, ack = packets.ACK(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, ack)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, 0)
    sock.close()

  def fdSelect(self, sock, expected):
    """Polls sock for up to 500 ms and asserts each event mask equals expected.

    Runs on a SocketExceptionThread, which captures the AssertionError raised
    here. If the poll times out with no events, nothing is asserted.
    """
    READ_ONLY = (select.POLLIN | select.POLLPRI | select.POLLHUP |
                 select.POLLERR | select.POLLNVAL)
    p = select.poll()
    p.register(sock, READ_ONLY)
    events = p.poll(500)
    for fd, event in events:
      if fd == sock.fileno():
        self.assertEqual(event, expected)
      else:
        raise AssertionError("unexpected poll fd")


class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records its failure.

  An IOError or AssertionError raised by operation is stored in
  self.exception (None when nothing was raised) for the caller to inspect.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def stop(self):
    """Best-effort stop of a running thread.

    Python 2's threading.Thread has the private _Thread__stop method. Later
    Pythons offer no supported way to stop a thread, so there this is a no-op
    and the daemon thread is left to finish on its own instead of raising
    AttributeError.
    """
    stop = getattr(self, "_Thread__stop", None)
    if stop is not None:
      stop()

  def run(self):
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as e:
      self.exception = e


if __name__ == '__main__':
  unittest.main()