• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import asyncio
2import asyncio.events
3import contextlib
4import os
5import pprint
6import select
7import socket
8import tempfile
9import threading
10from test import support
11
12
class FunctionalTestCaseMixin:
    """Mixin for ``unittest.TestCase`` subclasses that drive a private loop.

    ``setUp()`` creates a new event loop (leaving no loop installed as the
    current one) and records every call to the loop's exception handler;
    ``tearDown()`` closes the loop and fails the test if anything was
    recorded.  The ``tcp_*`` / ``unix_*`` helpers build threaded servers
    and clients that talk over ordinary blocking sockets.
    """

    def new_loop(self):
        """Return the event loop used by the test (override to customise)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* for tearDown() and delegate to the default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        # Create the list before installing the handler that appends to it.
        self.__unhandled_exceptions = []
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)
        self.loop.set_exception_handler(self.loop_exception_handler)

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.set_event_loop(None)
            self.loop = None

    @staticmethod
    def _check_socket_timeout(timeout):
        """Raise RuntimeError unless *timeout* is a positive number."""
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=support.LOOPBACK_TIMEOUT,
                   backlog=1,
                   max_clients=10):
        """Return a (not yet started) TestThreadedServer running *server_prog*.

        *addr* defaults to an ephemeral loopback port, or to a fresh
        filesystem path when *family* is AF_UNIX.  *timeout* must be a
        positive number; it is applied to the listening socket and to
        every accepted connection.  Use the result as a context manager
        to start and stop the server thread.

        Raises RuntimeError for a missing or non-positive *timeout*.
        """
        # Validate before creating the socket so a bad timeout cannot
        # leak an open listening socket.
        self._check_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Reserve a unique path; the temporary file is removed when
                # the block exits, leaving the name free for bind().
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        try:
            sock.settimeout(timeout)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            # Nobody owns the socket yet; close it rather than leak it.
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=support.LOOPBACK_TIMEOUT):
        """Return a (not yet started) TestThreadedClient running *client_prog*.

        The client receives an unconnected blocking stream socket of
        *family* whose timeout is *timeout* (must be a positive number).

        Raises RuntimeError for a missing or non-positive *timeout*.
        """
        # Validate first so that a bad timeout does not leak the socket.
        self._check_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """Like tcp_server() with family=AF_UNIX; NotImplementedError if absent."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """Like tcp_client() with family=AF_UNIX; NotImplementedError if absent."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a path named ``sock`` inside a fresh temporary directory."""
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                # Best effort: the socket file may never have been created.
                with contextlib.suppress(OSError):
                    os.unlink(fn)

    def _abort_socket_test(self, ex):
        """Stop the loop and fail the test with *ex* (called from socket threads)."""
        # NOTE(review): this runs on a worker thread; asyncio documents most
        # loop methods as not thread-safe.  Consider
        # self.loop.call_soon_threadsafe(self.loop.stop), which also wakes a
        # loop blocked in select().
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
111
112
##############################################################################
# Socket Testing Utilities
##############################################################################
116
117
class TestSocketWrapper:
    """Thin proxy around a blocking socket handed to test client/server progs.

    Attribute lookups not defined here are forwarded to the wrapped socket,
    so the wrapper can be used like the socket itself.  start_tls() swaps
    the wrapped socket for an SSL-wrapped one in place.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if the peer closes the connection
        before *n* bytes have arrived.
        """
        # A bytearray grows in place; repeated ``bytes +=`` would copy the
        # accumulated data on every chunk.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Perform a blocking TLS handshake and switch to the SSL socket.

        The plain socket object is closed whether or not the handshake
        succeeds; on failure the SSL socket is closed too and the error
        propagates.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself.
        return getattr(self.__sock, name)

    def __repr__(self):
        return f'<{type(self).__name__} {self.__sock!r}>'
156
157
class SocketThread(threading.Thread):
    """Thread base class usable as a context manager.

    Entering the ``with`` block starts the thread; leaving it calls
    stop(), which clears the ``_active`` flag (subclass loops may poll it)
    and then waits for the thread to finish.
    """

    def stop(self):
        """Clear ``_active`` and block until the thread has exited."""
        self._active = False
        return self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info):
        # stop() returns None, so exceptions from the with-body propagate.
        return self.stop()
170
171
class TestThreadedClient(SocketThread):
    """Daemon thread named ``test-client`` that runs one client program.

    The program is called with the socket wrapped in TestSocketWrapper.
    Any Exception it raises is reported through the owning test's
    ``_abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client', daemon=True)

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            wrapped = TestSocketWrapper(self._sock)
            self._prog(wrapped)
        except Exception as exc:
            self._test._abort_socket_test(exc)
189
190
class TestThreadedServer(SocketThread):
    """Daemon thread named ``test-server`` that serves accepted connections.

    Connections are accepted on the (non-blocking) listening socket and
    handed, one at a time, to the server program wrapped in
    TestSocketWrapper.  The loop ends once ``max_clients`` connections have
    been accepted, when stop() is called, or when a program raises.  A
    socketpair lets stop() wake the select() call.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server', daemon=True)

        self._clients = 0
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # _s1 is watched by select(); stop() writes to _s2 to wake it.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the accept loop through the socketpair, then join."""
        try:
            waker = self._s2
            if waker and waker.fileno() != -1:
                # The thread may already have closed the pair; ignore that.
                with contextlib.suppress(OSError):
                    waker.send(b'stop')
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock as listener:
                listener.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in readable:
                # stop() signalled through the socketpair.
                return

            if self._sock not in readable:
                # select() timed out; re-check the loop conditions.
                continue

            try:
                conn, _peer = self._sock.accept()
            except BlockingIOError:
                continue
            except TimeoutError:
                if self._active:
                    raise
                return

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except Exception as ex:
                self._active = False
                try:
                    raise
                finally:
                    # Report the failure; the re-raised error becomes context.
                    self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """The address the listening socket is bound to."""
        return self._sock.getsockname()
270