• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import asyncio
2import asyncio.events
3import contextlib
4import os
5import pprint
6import select
7import socket
8import tempfile
9import threading
10
11
class FunctionalTestCaseMixin:
    """Helpers for asyncio functional tests; mix into a unittest.TestCase.

    setUp() gives every test a fresh event loop in ``self.loop`` without
    installing it as the current event loop, records every call made to the
    loop's exception handler, and patches
    ``asyncio.events._get_running_loop`` to always return None.  tearDown()
    closes the loop, undoes the patch, and fails the test if the exception
    handler was called.

    The tcp_*/unix_* helpers build blocking sockets that are driven by
    background threads (TestThreadedServer / TestThreadedClient).
    """

    def new_loop(self):
        """Return a new event loop for the test (override to customize)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until an ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context*, then pass it on to the loop's default handler.

        Installed on ``self.loop`` by setUp(); any recorded context makes
        tearDown() fail the test.
        """
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)

        # Create the list before installing the handler that appends to it.
        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

        # Disable `_get_running_loop` so it always reports "no running loop";
        # the original function is restored in tearDown().
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    @staticmethod
    def _check_socket_timeout(timeout):
        """Raise RuntimeError unless *timeout* is a positive number.

        The helper threads only support blocking sockets with a timeout.
        This runs before any socket is created so that a bad argument
        cannot leak one.
        """
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Create a listening socket wrapped in a (not yet started) server thread.

        *server_prog* is called with each accepted connection.  If *addr* is
        omitted, an ephemeral port on 127.0.0.1 is used, or a fresh temporary
        path when *family* is AF_UNIX.  *timeout* must be a positive number.
        Raises RuntimeError for an invalid *timeout*.
        """
        self._check_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # The temporary file is removed when the with-block exits,
                # leaving only its unique name for bind() to use.
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        try:
            sock.settimeout(timeout)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            # Do not leak the listening socket if the thread can't be built.
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create a client socket wrapped in a (not yet started) client thread.

        *client_prog* is called with the wrapped socket when the thread runs;
        connecting is left to the program.  *timeout* must be a positive
        number.  Raises RuntimeError for an invalid *timeout*.
        """
        self._check_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            return TestThreadedClient(
                self, sock, client_prog, timeout)
        except BaseException:
            sock.close()
            raise

    def unix_server(self, *args, **kwargs):
        """tcp_server() over an AF_UNIX socket (NotImplementedError if unsupported)."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """tcp_client() over an AF_UNIX socket (NotImplementedError if unsupported)."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a socket path inside a fresh temporary directory.

        On exit the path is unlinked (best effort; a missing file is fine)
        and the directory is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                with contextlib.suppress(OSError):
                    os.unlink(fn)

    def _abort_socket_test(self, ex):
        """Stop ``self.loop`` and fail the test with *ex*.

        Called by the helper threads when their program raises.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
115
116
117##############################################################################
118# Socket Testing Utilities
119##############################################################################
120
121
class TestSocketWrapper:
    """Thin proxy around a blocking socket used by test client/server programs.

    Attribute lookups that this class does not define are forwarded to the
    wrapped socket; start_tls() replaces the wrapped socket with an SSL one.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Read exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if the peer closes the connection
        before *n* bytes have been received.
        """
        # A bytearray grows in place; repeated ``bytes += ...`` would copy
        # the whole accumulated buffer for every chunk received.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Upgrade the wrapped socket to TLS and perform the handshake now.

        On success the SSL socket replaces the wrapped socket.  If the
        handshake fails, the SSL socket is closed and the error propagates.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            # NOTE(review): wrap_socket() appears to take over the plain
            # socket's file descriptor, so this should only release the old
            # Python socket object - confirm.
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
160
161
class SocketThread(threading.Thread):
    """Base thread for the socket helpers, usable as a context manager.

    Entering the ``with`` block starts the thread; leaving it clears
    ``self._active`` and waits for the thread to exit.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info):
        self.stop()
        # Falsy return value: exceptions raised in the with-block propagate.
        return None

    def stop(self):
        """Clear the ``_active`` flag and block until the thread finishes."""
        self._active = False
        self.join()
174
175
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a blocking client program on a socket.

    The program receives the socket wrapped in TestSocketWrapper.  Any
    Exception it raises is reported through ``test._abort_socket_test``.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, name='test-client', daemon=True)

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout  # stored; not read by this class itself
        self._active = True

    def run(self):
        prog, sock = self._prog, self._sock
        try:
            prog(TestSocketWrapper(sock))
        except Exception as exc:
            self._test._abort_socket_test(exc)
193
194
class TestThreadedServer(SocketThread):
    """Daemon thread that serves a listening socket with a blocking program.

    Connections are accepted one at a time; each accepted socket is given a
    timeout, wrapped in TestSocketWrapper, passed to *prog*, and closed when
    the program returns.  Serving ends after *max_clients* connections, when
    stop() is called, or when a program raises (the error is reported via
    ``test._abort_socket_test``).
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, name='test-server', daemon=True)

        self._active = True
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._max_clients = max_clients
        self._clients = 0
        self._finished_clients = 0  # not updated anywhere in this class

        # stop() writes to _s2 so that the select() in _run() sees _s1 ready.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def _send_stop_signal(self):
        """Best-effort write to the wake-up socket; skipped once it is closed."""
        waker = self._s2
        if not waker or waker.fileno() == -1:
            return
        try:
            waker.send(b'stop')
        except OSError:
            pass

    def stop(self):
        """Wake the accept loop, then clear ``_active`` and join the thread."""
        try:
            self._send_stop_signal()
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                self._sock.setblocking(0)
                self._run()
        finally:
            # Release the wake-up pair however the accept loop ended.
            for end in (self._s1, self._s2):
                end.close()

    def _run(self):
        watched = [self._sock, self._s1]
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(watched, [], [], self._timeout)

            if self._s1 in readable:
                # stop() was called.
                return
            if self._sock not in readable:
                # select() timed out; re-check the loop conditions.
                continue

            try:
                conn, _ = self._sock.accept()
            except BlockingIOError:
                # Nothing to accept after all; wait again.
                continue
            except socket.timeout:
                if self._active:
                    raise
                return

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except Exception as ex:
                self._active = False
                # Report the failure to the test while re-raising it; if
                # _abort_socket_test() raises, that exception propagates.
                try:
                    raise
                finally:
                    self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        wrapped = TestSocketWrapper(sock)
        self._prog(wrapped)

    @property
    def addr(self):
        """Address the listening socket is bound to."""
        return self._sock.getsockname()
274