import asyncio
import asyncio.events
import contextlib
import os
import pprint
import select
import socket
import tempfile
import threading
from test import support


def _check_socket_timeout(timeout):
    """Raise RuntimeError unless *timeout* is a positive number of seconds.

    The helper threads below only support blocking sockets with a finite
    timeout.  Callers must run this check *before* creating any socket so
    that an invalid timeout cannot leak an open socket.
    """
    if timeout is None:
        raise RuntimeError('timeout is required')
    if timeout <= 0:
        raise RuntimeError('only blocking sockets are supported')


class FunctionalTestCaseMixin:
    """Mixin for TestCase classes that drive an asyncio loop against real
    sockets served by helper threads.

    setUp() creates a fresh event loop (it is NOT installed as the current
    loop) and records every call made to the loop's exception handler.
    tearDown() closes the loop and fails the test if anything was recorded.
    """

    def new_loop(self):
        """Return the event loop used by the test; override to customize."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run the test loop until a sleep of *delay* seconds completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context*, then delegate to the loop's default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)

        # Create the record list before installing the handler that
        # appends to it.
        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.set_event_loop(None)
            self.loop = None

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=support.LOOPBACK_TIMEOUT,
                   backlog=1,
                   max_clients=10):
        """Create a listening socket wrapped in a (not yet started)
        TestThreadedServer.

        *server_prog* is called once per accepted connection with a
        TestSocketWrapper around it.  The returned thread can be used as a
        context manager, which starts it on enter and stops it on exit.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate first: the original order created the socket and then
        # raised, leaking the listening socket.
        _check_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Only the generated *name* is used; the file itself is
                # removed when the ``with`` exits.
                # NOTE(review): another process could claim the path before
                # bind() — acceptable for tests, but not race-free.
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        sock.settimeout(timeout)

        return TestThreadedServer(
            self, sock, server_prog, timeout, max_clients)

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=support.LOOPBACK_TIMEOUT):
        """Create a stream socket wrapped in a (not yet started)
        TestThreadedClient.

        The socket is not connected here; *client_prog* receives it inside
        a TestSocketWrapper when the thread runs.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate before creating the socket so it cannot leak on error.
        _check_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """tcp_server() with family=AF_UNIX; NotImplementedError if absent."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """tcp_client() with family=AF_UNIX; NotImplementedError if absent."""
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a socket path inside a fresh temporary directory.

        The path is unlinked (if it exists) on exit, before the directory
        itself is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    pass

    def _abort_socket_test(self, ex):
        """Stop the test loop and fail the test with *ex*.

        Called from the helper threads when their program raises.
        NOTE(review): loop.stop() is invoked from a non-loop thread here;
        it may not wake a loop blocked in its selector until the next
        event — confirm whether call_soon_threadsafe is wanted.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)


##############################################################################
# Socket Testing Utilities
##############################################################################


class TestSocketWrapper:
    """Thin proxy around a socket that adds recv_all() and start_tls().

    Any attribute not defined here is looked up on the wrapped socket.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if the peer closes the connection
        before *n* bytes have arrived.
        """
        # A bytearray avoids the quadratic cost of repeated ``bytes +=``.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the socket with *ssl_context* and perform the handshake.

        On success, the wrapper delegates to the new SSL socket from then
        on.  If the handshake raises, the SSL socket is closed and the
        exception propagates.  The original socket object is closed in
        either case.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            # NOTE(review): wrap_socket() presumably took over the file
            # descriptor, so this only releases the plain socket object.
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)


class SocketThread(threading.Thread):
    """Base for the helper threads.

    stop() clears ``_active`` (which TestThreadedServer polls) and joins
    the thread.  As a context manager the thread is started on enter and
    stopped on exit.
    """

    def stop(self):
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()


class TestThreadedClient(SocketThread):
    """Daemon thread that runs *prog* once on a TestSocketWrapper around
    *sock*.

    If *prog* raises, the owning test is aborted through
    ``test._abort_socket_test``.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, None, None, 'test-client')
        self.daemon = True

        self._timeout = timeout
        self._sock = sock
        self._active = True
        self._prog = prog
        self._test = test

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except Exception as ex:
            self._test._abort_socket_test(ex)


class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections on *sock* and runs *prog*
    on each one.

    Connections are handled one at a time, inside this thread.  It keeps
    going until *max_clients* connections have been accepted, stop() is
    called, or *prog* raises.  A socketpair (``_s1``/``_s2``) lets stop()
    wake the select() call in _run() immediately.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        self.daemon = True

        self._clients = 0
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # Writing to _s2 makes _s1 readable, which ends the select() loop.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        try:
            # Best-effort wake-up: run() may already have closed _s2.
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    pass
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            # On timeout r is empty and the loop re-checks _active.
            r, w, x = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in r:
                # stop() was called.
                return

            if self._sock in r:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    continue
                except TimeoutError:
                    # NOTE(review): run() made the listening socket
                    # non-blocking, so this branch looks unreachable —
                    # confirm before removing.
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except Exception as ex:
                        self._active = False
                        try:
                            raise
                        finally:
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        return self._sock.getsockname()