"""Helpers for functional asyncio tests that exercise real sockets.

Provides a TestCase mixin that manages a fresh event loop per test, plus
thread-based blocking socket peers (``TestThreadedServer`` /
``TestThreadedClient``) for the loop under test to talk to.
"""

import asyncio
import asyncio.events
import contextlib
import os
import pprint
import select
import socket
import tempfile
import threading


class FunctionalTestCaseMixin:
    """Mixin giving each test a fresh event loop and threaded socket peers.

    Meant to be combined with ``unittest.TestCase``: it relies on
    ``self.fail`` and on ``setUp``/``tearDown`` being invoked around each test.
    """

    def new_loop(self):
        """Return the event loop for the test; override to customise."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until an ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* and delegate to the loop's default handler.

        Any recorded context makes ``tearDown`` fail the test.
        """
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        self.loop = self.new_loop()
        # Clear the global event loop so only ``self.loop`` is in play.
        asyncio.set_event_loop(None)

        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

        # Disable `_get_running_loop`; the original is restored in tearDown.
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            # Always undo setUp's global patches, even if close()/fail() raised.
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Create a listening socket wrapped in a (not yet started) server.

        *server_prog* is called with a ``TestSocketWrapper`` for each accepted
        connection, one at a time, in the server thread.  Use the returned
        ``TestThreadedServer`` as a context manager (or call ``start()``).

        Raises:
            RuntimeError: if *timeout* is ``None`` or not positive.
        """
        # Validate before any socket exists so a bad argument cannot leak an
        # open listening socket.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Borrow a unique path: the temp file is removed when the
                # ``with`` exits, leaving the name free for bind().
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                # Port 0 lets the OS choose a free port (see ``addr``).
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        try:
            sock.settimeout(timeout)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            # Don't leak the listening socket if setup fails part-way.
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create an unconnected client socket wrapped in a client thread.

        *client_prog* receives a ``TestSocketWrapper`` around the socket and
        is responsible for connecting it.  The returned thread is not started.

        Raises:
            RuntimeError: if *timeout* is ``None`` or not positive.
        """
        # Validate before creating the socket so it cannot be leaked.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            return TestThreadedClient(
                self, sock, client_prog, timeout)
        except BaseException:
            sock.close()
            raise

    def unix_server(self, *args, **kwargs):
        """``tcp_server`` on an AF_UNIX socket; NotImplementedError if absent."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """``tcp_client`` on an AF_UNIX socket; NotImplementedError if absent."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a socket path inside a temporary directory.

        The path is unlinked (best effort) and the directory removed on exit.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    # Nothing was bound to the path, or it is already gone.
                    pass

    def _abort_socket_test(self, ex):
        """Stop ``self.loop`` and fail the test with *ex*.

        Called from the peer threads when their program raises.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)


##############################################################################
# Socket Testing Utilities
##############################################################################


class TestSocketWrapper:
    """Proxy around a socket that adds ``recv_all`` and ``start_tls``.

    Any other attribute is delegated to the currently wrapped socket (which
    ``start_tls`` replaces with an SSL socket).
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises:
            ConnectionAbortedError: if the peer closes before *n* bytes arrive.
        """
        # bytearray grows in place; repeated ``bytes +=`` is quadratic.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Upgrade the wrapped socket to TLS and perform the handshake now.

        On handshake failure the SSL socket is closed and the error
        propagates.  The original socket object is closed in every case.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)


class SocketThread(threading.Thread):
    """Thread that is stopped cooperatively and usable as a context manager.

    Subclasses set ``_active`` and are expected to exit once it is cleared.
    Entering the context starts the thread; leaving it calls ``stop()``.
    """

    def stop(self):
        """Ask the thread to finish and wait for it."""
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()


class TestThreadedClient(SocketThread):
    """Daemon thread that runs *prog* once against a client socket."""

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, None, None, 'test-client')
        self.daemon = True

        self._timeout = timeout
        self._sock = sock
        self._active = True
        self._prog = prog
        self._test = test

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except Exception as ex:
            # Report the failure through the owning test case.
            self._test._abort_socket_test(ex)


class TestThreadedServer(SocketThread):
    """Daemon thread that accepts and serves up to *max_clients* connections.

    Connections are handled sequentially, each by calling *prog* with a
    ``TestSocketWrapper`` in this thread.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        self.daemon = True

        self._clients = 0
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # Wake-up channel: stop() writes to _s2, making _s1 readable, which
        # interrupts the select() in _run().
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the server loop (best effort), then stop and join."""
        try:
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    # run() may already have closed the socketpair.
                    pass
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                # Non-blocking so accept() never blocks after select();
                # spurious readiness yields BlockingIOError, which is retried.
                self._sock.setblocking(0)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept loop; returns on stop, inactivity, or the client limit."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            r, w, x = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in r:
                # stop() was called.
                return

            if self._sock in r:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    continue
                except socket.timeout:
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except Exception as ex:
                        self._active = False
                        try:
                            raise
                        finally:
                            # _abort_socket_test ends with self.fail(), which
                            # raises from this finally and supersedes the
                            # re-raise; *ex* is kept as the exception context.
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Local address of the listening socket (e.g. the OS-chosen port)."""
        return self._sock.getsockname()