#!/usr/bin/python
#
# Copyright 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the client side of TCP Fast Open via TCP_FASTOPEN_CONNECT."""

import unittest

from errno import *
from socket import *
from scapy import all as scapy

import multinetwork_base
import net_test
import packets
import tcp_metrics


# TCP option kind used by TCP Fast Open (RFC 7413).
TCPOPT_FASTOPEN = 34
# IPPROTO_TCP-level socket option that enables Fast Open on connect().
TCP_FASTOPEN_CONNECT = 30

# Remote port used for every connection in these tests.
REMOTE_PORT = 53

# TCP options carrying the TFO cookie. The fake server sends these in its
# SYN+ACK, and the test expects the client's next SYN to carry the same
# options. Kept as a tuple so it cannot be mutated by accident; copy it with
# list() before assigning it to a packet.
TFO_COOKIE_OPTIONS = (
    (TCPOPT_FASTOPEN, "helloT"), ("NOP", None), ("NOP", None))


class TcpFastOpenTest(multinetwork_base.MultiNetworkBaseTest):
  """Checks TFO client behaviour when enabled with TCP_FASTOPEN_CONNECT."""

  @classmethod
  def setUpClass(cls):
    super(TcpFastOpenTest, cls).setUpClass()
    cls.tcp_metrics = tcp_metrics.TcpMetrics()

  def TFOClientSocket(self, version, netid):
    """Returns a TCP client socket with TCP_FASTOPEN_CONNECT enabled.

    Args:
      version: IP version (4 or 6); selects the address family.
      netid: network the socket is steered to, via SelectInterface(..., "mark").
    """
    s = net_test.TCPSocket(net_test.GetAddressFamily(version))
    net_test.DisableFinWait(s)
    self.SelectInterface(s, netid, "mark")
    s.setsockopt(IPPROTO_TCP, TCP_FASTOPEN_CONNECT, 1)
    return s

  def assertSocketNotConnected(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is connected (getpeername() does not raise)."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def clearTcpMetrics(self, version, netid):
    """Deletes the TCP metrics entry for this netid's address -> remote.

    Then asserts that looking the entry up again fails with ESRCH.
    """
    saddr = self.MyAddress(version, netid)
    daddr = self.GetRemoteAddress(version)
    self.tcp_metrics.DelMetrics(saddr, daddr)
    with self.assertRaisesErrno(ESRCH):
      # Expected to raise; no output is wanted if it unexpectedly succeeds.
      self.tcp_metrics.GetMetrics(saddr, daddr)

  def assertNoTcpMetrics(self, version, netid):
    """Asserts that no TCP metrics exist for this netid's address -> remote."""
    saddr = self.MyAddress(version, netid)
    daddr = self.GetRemoteAddress(version)
    with self.assertRaisesErrno(ENOENT):
      self.tcp_metrics.GetMetrics(saddr, daddr)

  def CheckConnectOption(self, version):
    """Exercises TCP_FASTOPEN_CONNECT end to end for one IP version.

    First connect (no cached metrics):
      - connect() must fail with EINPROGRESS.
      - The SYN must carry an empty TFO option.
      - The server's SYN+ACK supplies a cookie, and the handshake completes.

    Second connect:
      - connect() must return without sending any packet.
      - The first send() must emit a SYN carrying the cookie and the data.

    Args:
      version: IP version (4 or 6).
    """
    ip_layer = {4: scapy.IP, 6: scapy.IPv6}[version]
    netid = self.RandomNetid()
    s = self.TFOClientSocket(version, netid)

    # Start from a clean metrics cache so the first connect has no cookie.
    self.clearTcpMetrics(version, netid)

    # Connect the first time.
    remoteaddr = self.GetRemoteAddress(version)
    with self.assertRaisesErrno(EINPROGRESS):
      s.connect((remoteaddr, REMOTE_PORT))
    self.assertSocketNotConnected(s)

    # Expect a SYN handshake with an empty TFO option.
    myaddr = self.MyAddress(version, netid)
    port = s.getsockname()[1]
    self.assertNotEqual(0, port)
    desc, syn = packets.SYN(REMOTE_PORT, version, myaddr, remoteaddr, port,
                            seq=None)
    syn.getlayer(scapy.TCP).options = [(TCPOPT_FASTOPEN, "")]
    msg = "Fastopen connect: expected %s" % desc
    syn = self.ExpectPacketOn(netid, msg, syn)
    # Round-trip the captured packet through its wire form before deriving
    # the SYN+ACK from it (presumably to materialize computed fields).
    syn = ip_layer(str(syn))

    # Receive a SYN+ACK with a TFO cookie and expect the connection to proceed
    # as normal.
    _, synack = packets.SYNACK(version, remoteaddr, myaddr, syn)
    synack.getlayer(scapy.TCP).options = list(TFO_COOKIE_OPTIONS)
    self.ReceivePacketOn(netid, synack)
    synack = ip_layer(str(synack))
    desc, ack = packets.ACK(version, myaddr, remoteaddr, synack)
    msg = "First connect: got SYN+ACK, expected %s" % desc
    self.ExpectPacketOn(netid, msg, ack)
    self.assertSocketConnected(s)

    # The socket was set up with DisableFinWait(); closing it is expected to
    # produce an RST.
    s.close()
    desc, rst = packets.RST(version, myaddr, remoteaddr, synack)
    msg = "Closing client socket, expecting %s" % desc
    self.ExpectPacketOn(netid, msg, rst)

    # Connect to the same destination again. Expect the connect to succeed
    # without sending a SYN packet.
    s = self.TFOClientSocket(version, netid)
    try:
      s.connect((remoteaddr, REMOTE_PORT))
      self.assertSocketNotConnected(s)
      self.ExpectNoPacketsOn(netid, "Second TFO connect, expected no packets")

      # Issue a write and expect a SYN carrying the cookie and the data.
      port = s.getsockname()[1]
      s.send(net_test.UDP_PAYLOAD)
      desc, syn = packets.SYN(REMOTE_PORT, version, myaddr, remoteaddr, port,
                              seq=None)
      t = syn.getlayer(scapy.TCP)
      t.options = list(TFO_COOKIE_OPTIONS)
      t.payload = scapy.Raw(net_test.UDP_PAYLOAD)
      msg = "TFO write, expected %s" % desc
      self.ExpectPacketOn(netid, msg, syn)
    finally:
      # Previously this socket was never closed.
      s.close()

  @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not yet backported")
  def testConnectOptionIPv4(self):
    """TCP_FASTOPEN_CONNECT over IPv4."""
    self.CheckConnectOption(4)

  @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not yet backported")
  def testConnectOptionIPv6(self):
    """TCP_FASTOPEN_CONNECT over IPv6."""
    self.CheckConnectOption(6)


if __name__ == "__main__":
  unittest.main()