• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import contextlib
18import errno
19import fcntl
20import resource
21import os
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import threading
25import time
26import unittest
27
28import csocket
29import cstruct
30import net_test
31
# Loopback endpoints the tests talk over and then nuke.
IPV4_LOOPBACK_ADDR = "127.0.0.1"
IPV6_LOOPBACK_ADDR = "::1"
LOOPBACK_DEV = "lo"    # interface name placed in the IPv4 Ifreq
LOOPBACK_IFINDEX = 1   # interface index placed in the IPv6 In6Ifreq

# Request number of the ioctl under test.
SIOCKILLADDR = 0x8939

# Message-exchange parameters; the 18-byte message fits in one 20-byte read.
# NOTE(review): DEFAULT_TCP_PORT is not referenced by the code visible here.
DEFAULT_TCP_PORT = 8001
DEFAULT_BUFFER_SIZE = 20
DEFAULT_TEST_MESSAGE = "TCP NUKE ADDR TEST"

# Iteration counts.
DEFAULT_TEST_RUNS = 100
HASH_TEST_RUNS = 4000
# RLIMIT_NOFILE for the hash test (16384): just above the
# HASH_TEST_RUNS * 2 pairs * 2 sockets = 16000 sockets it keeps open.
HASH_TEST_NOFILE = 2 ** 14
45
46
# Interface request: a 16-byte interface name followed by a 16-byte payload
# (KillAddrIoctl stores a packed sockaddr_in there for IPv4).
Ifreq = cstruct.Struct(
    "Ifreq",
    "=16s16s",
    "name data")
# IPv6 request: 16-byte address, unsigned prefix length, signed ifindex.
# NOTE(review): appears to mirror the kernel's struct in6_ifreq - confirm.
In6Ifreq = cstruct.Struct(
    "In6Ifreq",
    "=16sIi",
    "addr prefixlen ifindex")
49
@contextlib.contextmanager
def RunInBackground(thread):
  """Runs a thread alongside a with-block and joins it when the block exits.

  The thread is started on entry. On exit, whether normal or via an
  exception, this waits (with no timeout) for the thread to finish.

  Args:
    thread: A not yet started threading.Thread object.

  Yields:
    The started thread.

  Raises:
    Whatever thread.start() raises; in that case no join is attempted.
  """
  # Start outside the try: if start() fails there is nothing to join, and
  # joining a never-started thread raises RuntimeError, masking the original
  # error.
  thread.start()
  try:
    yield thread
  finally:
    thread.join()
62
63
def TcpAcceptAndReceive(listening_sock, buffer_size=DEFAULT_BUFFER_SIZE):
  """Accepts a single connection and blocks receiving data from it.

  The received data is discarded. The accepted connection is closed before
  returning.

  Args:
    listening_sock: A socket in LISTEN state.
    buffer_size: Maximum number of bytes read by the single recv() call.
  """
  connection, _ = listening_sock.accept()
  with contextlib.closing(connection):
    connection.recv(buffer_size)
74
75
def ExchangeMessage(addr_family, ip_addr):
  """Sends one message over a fresh TCP connection on ip_addr.

  A listening socket is bound to a kernel-chosen port on ip_addr. A background
  thread accepts one connection and reads from it, while a client socket
  connects and sends DEFAULT_TEST_MESSAGE. The background thread is joined and
  every socket is closed before returning.

  Args:
    addr_family: The address family (e.g. AF_INET6).
    ip_addr: The IP address (IPv4 or IPv6 depending on the addr_family).
  """
  # Port 0 lets the kernel pick a free port; read it back for connect().
  test_addr = (ip_addr, 0)
  with contextlib.closing(
      socket(addr_family, SOCK_STREAM)) as listening_socket:
    listening_socket.bind(test_addr)
    test_addr = listening_socket.getsockname()
    listening_socket.listen(1)
    with RunInBackground(threading.Thread(target=TcpAcceptAndReceive,
                                          args=(listening_socket,))):
      with contextlib.closing(
          socket(addr_family, SOCK_STREAM)) as client_socket:
        client_socket.connect(test_addr)
        # sendall, unlike send, cannot silently stop after a partial write.
        client_socket.sendall(DEFAULT_TEST_MESSAGE)
97
98
def KillAddrIoctl(addr):
  """Calls the SIOCKILLADDR ioctl on the provided IP address.

  The address family is detected from the numeric address string. An IPv6
  address is passed as an In6Ifreq with a /128 prefix on LOOPBACK_IFINDEX. An
  IPv4 address is passed as an Ifreq naming LOOPBACK_DEV and carrying a packed
  sockaddr_in.

  Args:
    addr: The numeric IPv4 or IPv6 address to pass to the ioctl.

  Raises:
    ValueError: If addr is of an unsupported address family.
    socket.gaierror: If addr is not a numeric IP address (AI_NUMERICHOST).
    IOError: If the ioctl call itself fails.
  """
  family = getaddrinfo(addr, None, AF_UNSPEC, SOCK_DGRAM, 0,
                       AI_NUMERICHOST)[0][0]
  if family == AF_INET6:
    packed = inet_pton(AF_INET6, addr)
    ifreq = In6Ifreq((packed, 128, LOOPBACK_IFINDEX)).Pack()
  elif family == AF_INET:
    packed = inet_pton(AF_INET, addr)
    sockaddr = csocket.SockaddrIn((AF_INET, 0, packed)).Pack()
    ifreq = Ifreq((LOOPBACK_DEV, sockaddr)).Pack()
  else:
    raise ValueError('Address family %r not supported.' % family)
  # Close the socket even when the ioctl raises, so failing calls do not
  # leak file descriptors.
  with contextlib.closing(socket(family, SOCK_DGRAM)) as datagram_socket:
    fcntl.ioctl(datagram_socket.fileno(), SIOCKILLADDR, ifreq)
122
123
class ExceptionalReadThread(threading.Thread):
  """Daemon thread that performs one blocking recv() and records any error.

  Attributes:
    sock: The socket read from.
    exception: The exception raised by recv(), or None if recv() returned
      normally (or the thread has not finished yet).
  """

  def __init__(self, sock):
    self.sock = sock
    self.exception = None
    super(ExceptionalReadThread, self).__init__()
    # Daemon, so a thread stuck in recv() does not keep the interpreter alive.
    self.daemon = True

  def run(self):
    try:
      self.sock.recv(4096)
    except Exception as e:  # pylint: disable=broad-except
      # Deliberately broad: callers inspect whatever recv() raised.
      self.exception = e
137
# Convenience wrappers that build a connected TCP socket pair on loopback.
def CreateIPv4SocketPair():
  """Returns a connected pair of IPv4 TCP sockets on the loopback address.

  Thin wrapper over net_test.CreateSocketPair.
  """
  loopback = IPV4_LOOPBACK_ADDR
  return net_test.CreateSocketPair(AF_INET, SOCK_STREAM, loopback)
141
def CreateIPv6SocketPair():
  """Returns a connected pair of IPv6 TCP sockets on the loopback address.

  Thin wrapper over net_test.CreateSocketPair.
  """
  loopback = IPV6_LOOPBACK_ADDR
  return net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, loopback)
144
145
class TcpNukeAddrTest(net_test.NetworkTest):
  """Tests for the SIOCKILLADDR ioctl on loopback TCP sockets."""

  def _CloseOnCleanup(self, socks):
    """Registers each socket in socks to be closed when the test finishes."""
    for sock in socks:
      self.addCleanup(sock.close)

  def testTimewaitSockets(self):
    """Tests that SIOCKILLADDR copes with just-closed loopback connections.

    Each iteration completes a short IPv6 and then IPv4 loopback exchange
    (whose sockets are presumably left in TIME_WAIT, per the test name) and
    then nukes the corresponding loopback address.

    Relevant kernel commits:
      https://www.codeaurora.org/cgit/quic/la/kernel/msm-3.18/commit/net/ipv4/tcp.c?h=aosp/android-3.10&id=1dcd3a1fa2fe78251cc91700eb1d384ab02e2dd6
    """
    for _ in xrange(DEFAULT_TEST_RUNS):
      ExchangeMessage(AF_INET6, IPV6_LOOPBACK_ADDR)
      KillAddrIoctl(IPV6_LOOPBACK_ADDR)
      ExchangeMessage(AF_INET, IPV4_LOOPBACK_ADDR)
      KillAddrIoctl(IPV4_LOOPBACK_ADDR)
      # Test passes if kernel does not crash.

  def testClosesIPv6Sockets(self):
    """Tests that SIOCKILLADDR closes IPv6 sockets and unblocks threads."""

    threadpairs = []

    for _ in xrange(DEFAULT_TEST_RUNS):
      clientsock, acceptedsock = CreateIPv6SocketPair()
      self._CloseOnCleanup((clientsock, acceptedsock))
      # Park a thread in a blocking recv() on each end of the connection.
      clientthread = ExceptionalReadThread(clientsock)
      clientthread.start()
      serverthread = ExceptionalReadThread(acceptedsock)
      serverthread.start()
      threadpairs.append((clientthread, serverthread))

    KillAddrIoctl(IPV6_LOOPBACK_ADDR)

    def CheckThreadException(thread):
      """Checks recv() failed with ETIMEDOUT and the socket acts closed."""
      # Give the aborted recv() up to 100 seconds to return.
      thread.join(100)
      self.assertFalse(thread.is_alive())
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError)
      self.assertEqual(errno.ETIMEDOUT, thread.exception.errno)
      self.assertRaisesErrno(errno.ENOTCONN, thread.sock.getpeername)
      self.assertRaisesErrno(errno.EISCONN, thread.sock.connect,
                             (IPV6_LOOPBACK_ADDR, 53))
      self.assertRaisesErrno(errno.EPIPE, thread.sock.send, "foo")

    for clientthread, serverthread in threadpairs:
      CheckThreadException(clientthread)
      CheckThreadException(serverthread)

  def assertSocketsClosed(self, socketpair):
    """Asserts getpeername() fails with ENOTCONN on every socket."""
    for sock in socketpair:
      self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)

  def assertSocketsNotClosed(self, socketpair):
    """Asserts getpeername() still succeeds on every socket."""
    for sock in socketpair:
      self.assertTrue(sock.getpeername())

  def testAddresses(self):
    """Tests that SIOCKILLADDR only closes sockets on the exact address given.

    Nuking other-family, wildcard, or different loopback addresses must leave
    the pair connected; nuking the pair's own address must close it.
    """
    socketpair = CreateIPv4SocketPair()
    self._CloseOnCleanup(socketpair)
    KillAddrIoctl("::")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("::1")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("127.0.0.3")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("0.0.0.0")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("127.0.0.1")
    self.assertSocketsClosed(socketpair)

    socketpair = CreateIPv6SocketPair()
    self._CloseOnCleanup(socketpair)
    KillAddrIoctl("0.0.0.0")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("127.0.0.1")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("::2")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("::")
    self.assertSocketsNotClosed(socketpair)
    KillAddrIoctl("::1")
    self.assertSocketsClosed(socketpair)
223
224
class TcpNukeAddrHashTest(net_test.NetworkTest):
  """Tests SIOCKILLADDR against many simultaneous loopback connections."""

  def setUp(self):
    # Each run keeps four sockets open (two pairs), so raise both the soft
    # and hard RLIMIT_NOFILE to HASH_TEST_NOFILE for this test.
    self.nofile = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (HASH_TEST_NOFILE,
                                                HASH_TEST_NOFILE))

  def tearDown(self):
    resource.setrlimit(resource.RLIMIT_NOFILE, self.nofile)

  def testClosesAllSockets(self):
    """Tests that one SIOCKILLADDR per family closes every loopback socket."""
    socketpairs = []
    try:
      for _ in xrange(HASH_TEST_RUNS):
        socketpairs.append(CreateIPv4SocketPair())
        socketpairs.append(CreateIPv6SocketPair())

      KillAddrIoctl(IPV4_LOOPBACK_ADDR)
      KillAddrIoctl(IPV6_LOOPBACK_ADDR)

      for socketpair in socketpairs:
        for sock in socketpair:
          self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
    finally:
      # Release the descriptors deterministically (rather than via garbage
      # collection) before tearDown lowers RLIMIT_NOFILE again.
      for socketpair in socketpairs:
        for sock in socketpair:
          sock.close()
247
248
if __name__ == "__main__":
  # Run the unittest test cases defined in this module.
  unittest.main()
251