"""Functional tests for asyncio.BufferedProtocol used with create_connection()."""

import asyncio
import unittest

from test.test_asyncio import functional as func_tests


def tearDownModule():
    # Passing None resets the event loop policy to the default one.
    asyncio.set_event_loop_policy(None)


class ReceiveStuffProto(asyncio.BufferedProtocol):
    """Buffered protocol that hands every received chunk to a callback.

    *cb* is called with the freshly received bytes after each
    buffer_updated() notification.  *con_lost_fut* is completed when the
    connection is lost: with None on a clean close, or with the exception
    passed to connection_lost().
    """

    def __init__(self, cb, con_lost_fut):
        self.cb = cb
        self.con_lost_fut = con_lost_fut

    def get_buffer(self, sizehint):
        # A new fixed-size buffer is handed out on every call; *sizehint*
        # is deliberately ignored so that a large payload is delivered in
        # many small pieces.
        self.buffer = bytearray(100)
        return self.buffer

    def buffer_updated(self, nbytes):
        # Only the first *nbytes* bytes of the buffer were written.
        self.cb(self.buffer[:nbytes])

    def connection_lost(self, exc):
        fut = self.con_lost_fut
        if fut.done():
            # The waiter may already have been cancelled (e.g. by a
            # wait_for() timeout); setting a result on a done future would
            # raise InvalidStateError inside an event loop callback.
            return
        if exc is None:
            fut.set_result(None)
        else:
            fut.set_exception(exc)


class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
    """Loop-agnostic BufferedProtocol tests.

    Concrete test classes override new_loop() to select the event loop
    implementation.  NOTE(review): self.loop is presumably created from
    new_loop() by FunctionalTestCaseMixin, which is not visible in this
    file -- confirm there.
    """

    def new_loop(self):
        raise NotImplementedError

    def test_buffered_proto_create_connection(self):
        # 9 KiB payload: much larger than the 100-byte buffer returned by
        # ReceiveStuffProto.get_buffer(), so it arrives in many chunks.
        NOISE = b'12345678+' * 1024

        async def client(addr):
            data = b''

            def on_buf(buf):
                nonlocal data
                data += buf
                if data == NOISE:
                    # Everything arrived intact; tell the server it may
                    # close the connection.
                    tr.write(b'1')

            conn_lost_fut = self.loop.create_future()

            tr, _ = await self.loop.create_connection(
                lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr)

            try:
                await conn_lost_fut
            finally:
                # No-op once the connection is already lost; otherwise
                # (timeout or error) make sure the transport is not leaked.
                tr.close()

        async def on_server_client(reader, writer):
            writer.write(NOISE)
            # Wait for the client's one-byte acknowledgement.
            await reader.readexactly(1)
            writer.close()
            await writer.wait_closed()

        srv = self.loop.run_until_complete(
            asyncio.start_server(
                on_server_client, '127.0.0.1', 0))

        try:
            # Port 0 was requested, so read back the address actually bound.
            addr = srv.sockets[0].getsockname()
            self.loop.run_until_complete(
                asyncio.wait_for(client(addr), 5))
        finally:
            # Shut the server down even when the client fails or times out.
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())


class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the BufferedProtocol tests on a SelectorEventLoop."""

    def new_loop(self):
        return asyncio.SelectorEventLoop()


@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the BufferedProtocol tests on a ProactorEventLoop."""

    def new_loop(self):
        return asyncio.ProactorEventLoop()


if __name__ == '__main__':
    unittest.main()