• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for asyncio/sslproto.py."""
2
3import logging
4import socket
5from test import support
6import unittest
7import weakref
8from unittest import mock
9try:
10    import ssl
11except ImportError:
12    ssl = None
13
14import asyncio
15from asyncio import log
16from asyncio import protocols
17from asyncio import sslproto
18from test import support
19from test.test_asyncio import utils as test_utils
20from test.test_asyncio import functional as func_tests
21
22
def tearDownModule():
    """Reset the asyncio event loop policy after this module's tests.

    Passing None to asyncio.set_event_loop_policy() restores the default
    policy.
    """
    asyncio.set_event_loop_policy(None)
25
26
27@unittest.skipIf(ssl is None, 'No ssl module')
28class SslProtoHandshakeTests(test_utils.TestCase):
29
30    def setUp(self):
31        super().setUp()
32        self.loop = asyncio.new_event_loop()
33        self.set_event_loop(self.loop)
34
35    def ssl_protocol(self, *, waiter=None, proto=None):
36        sslcontext = test_utils.dummy_ssl_context()
37        if proto is None:  # app protocol
38            proto = asyncio.Protocol()
39        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
40                                         ssl_handshake_timeout=0.1)
41        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
42        self.addCleanup(ssl_proto._app_transport.close)
43        return ssl_proto
44
45    def connection_made(self, ssl_proto, *, do_handshake=None):
46        transport = mock.Mock()
47        sslpipe = mock.Mock()
48        sslpipe.shutdown.return_value = b''
49        if do_handshake:
50            sslpipe.do_handshake.side_effect = do_handshake
51        else:
52            def mock_handshake(callback):
53                return []
54            sslpipe.do_handshake.side_effect = mock_handshake
55        with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
56            ssl_proto.connection_made(transport)
57        return transport
58
59    def test_handshake_timeout_zero(self):
60        sslcontext = test_utils.dummy_ssl_context()
61        app_proto = mock.Mock()
62        waiter = mock.Mock()
63        with self.assertRaisesRegex(ValueError, 'a positive number'):
64            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
65                                 ssl_handshake_timeout=0)
66
67    def test_handshake_timeout_negative(self):
68        sslcontext = test_utils.dummy_ssl_context()
69        app_proto = mock.Mock()
70        waiter = mock.Mock()
71        with self.assertRaisesRegex(ValueError, 'a positive number'):
72            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
73                                 ssl_handshake_timeout=-10)
74
75    def test_eof_received_waiter(self):
76        waiter = self.loop.create_future()
77        ssl_proto = self.ssl_protocol(waiter=waiter)
78        self.connection_made(ssl_proto)
79        ssl_proto.eof_received()
80        test_utils.run_briefly(self.loop)
81        self.assertIsInstance(waiter.exception(), ConnectionResetError)
82
83    def test_fatal_error_no_name_error(self):
84        # From issue #363.
85        # _fatal_error() generates a NameError if sslproto.py
86        # does not import base_events.
87        waiter = self.loop.create_future()
88        ssl_proto = self.ssl_protocol(waiter=waiter)
89        # Temporarily turn off error logging so as not to spoil test output.
90        log_level = log.logger.getEffectiveLevel()
91        log.logger.setLevel(logging.FATAL)
92        try:
93            ssl_proto._fatal_error(None)
94        finally:
95            # Restore error logging.
96            log.logger.setLevel(log_level)
97
98    def test_connection_lost(self):
99        # From issue #472.
100        # yield from waiter hang if lost_connection was called.
101        waiter = self.loop.create_future()
102        ssl_proto = self.ssl_protocol(waiter=waiter)
103        self.connection_made(ssl_proto)
104        ssl_proto.connection_lost(ConnectionAbortedError)
105        test_utils.run_briefly(self.loop)
106        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
107
108    def test_close_during_handshake(self):
109        # bpo-29743 Closing transport during handshake process leaks socket
110        waiter = self.loop.create_future()
111        ssl_proto = self.ssl_protocol(waiter=waiter)
112
113        transport = self.connection_made(ssl_proto)
114        test_utils.run_briefly(self.loop)
115
116        ssl_proto._app_transport.close()
117        self.assertTrue(transport.abort.called)
118
119    def test_get_extra_info_on_closed_connection(self):
120        waiter = self.loop.create_future()
121        ssl_proto = self.ssl_protocol(waiter=waiter)
122        self.assertIsNone(ssl_proto._get_extra_info('socket'))
123        default = object()
124        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
125        self.connection_made(ssl_proto)
126        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
127        ssl_proto.connection_lost(None)
128        self.assertIsNone(ssl_proto._get_extra_info('socket'))
129
130    def test_set_new_app_protocol(self):
131        waiter = self.loop.create_future()
132        ssl_proto = self.ssl_protocol(waiter=waiter)
133        new_app_proto = asyncio.Protocol()
134        ssl_proto._app_transport.set_protocol(new_app_proto)
135        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
136        self.assertIs(ssl_proto._app_protocol, new_app_proto)
137
138    def test_data_received_after_closing(self):
139        ssl_proto = self.ssl_protocol()
140        self.connection_made(ssl_proto)
141        transp = ssl_proto._app_transport
142
143        transp.close()
144
145        # should not raise
146        self.assertIsNone(ssl_proto.data_received(b'data'))
147
148    def test_write_after_closing(self):
149        ssl_proto = self.ssl_protocol()
150        self.connection_made(ssl_proto)
151        transp = ssl_proto._app_transport
152        transp.close()
153
154        # should not raise
155        self.assertIsNone(transp.write(b'data'))
156
157
158##############################################################################
159# Start TLS Tests
160##############################################################################
161
162
163class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
164
165    PAYLOAD_SIZE = 1024 * 100
166    TIMEOUT = support.LONG_TIMEOUT
167
    def new_loop(self):
        """Return the event loop to test against.

        This base implementation always raises NotImplementedError; the
        concrete test classes in this module override it.
        """
        raise NotImplementedError
170
171    def test_buf_feed_data(self):
172
173        class Proto(asyncio.BufferedProtocol):
174
175            def __init__(self, bufsize, usemv):
176                self.buf = bytearray(bufsize)
177                self.mv = memoryview(self.buf)
178                self.data = b''
179                self.usemv = usemv
180
181            def get_buffer(self, sizehint):
182                if self.usemv:
183                    return self.mv
184                else:
185                    return self.buf
186
187            def buffer_updated(self, nsize):
188                if self.usemv:
189                    self.data += self.mv[:nsize]
190                else:
191                    self.data += self.buf[:nsize]
192
193        for usemv in [False, True]:
194            proto = Proto(1, usemv)
195            protocols._feed_data_to_buffered_proto(proto, b'12345')
196            self.assertEqual(proto.data, b'12345')
197
198            proto = Proto(2, usemv)
199            protocols._feed_data_to_buffered_proto(proto, b'12345')
200            self.assertEqual(proto.data, b'12345')
201
202            proto = Proto(2, usemv)
203            protocols._feed_data_to_buffered_proto(proto, b'1234')
204            self.assertEqual(proto.data, b'1234')
205
206            proto = Proto(4, usemv)
207            protocols._feed_data_to_buffered_proto(proto, b'1234')
208            self.assertEqual(proto.data, b'1234')
209
210            proto = Proto(100, usemv)
211            protocols._feed_data_to_buffered_proto(proto, b'12345')
212            self.assertEqual(proto.data, b'12345')
213
214            proto = Proto(0, usemv)
215            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
216                protocols._feed_data_to_buffered_proto(proto, b'12345')
217
    def test_start_tls_client_reg_proto_1(self):
        # A client using a regular asyncio.Protocol connects in plaintext,
        # upgrades the transport with loop.start_tls(), and exchanges data
        # with the server handler below.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server handler passed to self.tcp_server(): read the plaintext
            # payload, switch to TLS, send b'O', then read the payload again.
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The protocol instance is named ``proto`` here so that ``self``
            # still refers to the enclosing test case (via the closure).
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            # Plaintext payload first, then upgrade the same connection.
            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        # The name is deliberately rebound to the weak reference: client()
        # closes over this variable, so rebinding drops the strong references
        # held by both this frame and the closure.
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())
283
    def test_create_connection_memory_leak(self):
        # An SSL client created with loop.create_connection(ssl=...) must not
        # keep the SSLContext alive after the connection is closed, even when
        # the protocol stores its transport.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server handler passed to self.tcp_server(): start TLS right
            # away, send b'O', then read the client's payload.
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The protocol instance is named ``proto`` here so that ``self``
            # still refers to the enclosing test case (via the closure).
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        # The name is deliberately rebound to the weak reference: client()
        # closes over this variable, so rebinding drops the strong references
        # held by both this frame and the closure.
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())
347
    def test_start_tls_client_buf_proto_1(self):
        # Same flow as the regular-protocol start_tls test, but the client
        # starts with a BufferedProtocol and later swaps in a regular
        # Protocol on the TLS transport via set_protocol().
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Incremented by connection_made() of either client protocol below.
        client_con_made_calls = 0

        def serve(sock):
            # Server handler passed to self.tcp_server(): one plaintext read,
            # then TLS with two send-one-byte / read-payload rounds.
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                # One-byte receive buffer handed out by get_buffer().
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            # The protocol instance is named ``slf`` here so that ``self``
            # still refers to the enclosing test case (via the closure).
            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            # Plaintext payload first, then upgrade the same connection.
            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            # Swap the application protocol on the already-upgraded transport.
            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))
438
439    def test_start_tls_slow_client_cancel(self):
440        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
441
442        client_context = test_utils.simple_client_sslcontext()
443        server_waits_on_handshake = self.loop.create_future()
444
445        def serve(sock):
446            sock.settimeout(self.TIMEOUT)
447
448            data = sock.recv_all(len(HELLO_MSG))
449            self.assertEqual(len(data), len(HELLO_MSG))
450
451            try:
452                self.loop.call_soon_threadsafe(
453                    server_waits_on_handshake.set_result, None)
454                data = sock.recv_all(1024 * 1024)
455            except ConnectionAbortedError:
456                pass
457            finally:
458                sock.close()
459
460        class ClientProto(asyncio.Protocol):
461            def __init__(self, on_data, on_eof):
462                self.on_data = on_data
463                self.on_eof = on_eof
464                self.con_made_cnt = 0
465
466            def connection_made(proto, tr):
467                proto.con_made_cnt += 1
468                # Ensure connection_made gets called only once.
469                self.assertEqual(proto.con_made_cnt, 1)
470
471            def data_received(self, data):
472                self.on_data.set_result(data)
473
474            def eof_received(self):
475                self.on_eof.set_result(True)
476
477        async def client(addr):
478            await asyncio.sleep(0.5)
479
480            on_data = self.loop.create_future()
481            on_eof = self.loop.create_future()
482
483            tr, proto = await self.loop.create_connection(
484                lambda: ClientProto(on_data, on_eof), *addr)
485
486            tr.write(HELLO_MSG)
487
488            await server_waits_on_handshake
489
490            with self.assertRaises(asyncio.TimeoutError):
491                await asyncio.wait_for(
492                    self.loop.start_tls(tr, proto, client_context),
493                    0.5)
494
495        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
496            self.loop.run_until_complete(
497                asyncio.wait_for(client(srv.addr),
498                                 timeout=support.SHORT_TIMEOUT))
499
    def test_start_tls_server_1(self):
        # Server-side start_tls(): an asyncio server sends a plaintext
        # payload, upgrades the accepted connection to TLS, receives the
        # payload back over TLS and replies with ANSWER.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Filled in by client() with the bytes received over TLS.
        answer = None

        def client(sock, addr):
            # Client handler passed to self.tcp_client(): read the plaintext
            # payload, upgrade to TLS, echo the payload and read the answer.
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                # Resolve on_got_hello once a full payload has accumulated.
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing has been received from the client at this point.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            # Port 0: the actual address is read back from the listening
            # socket on the next line.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())
588
589    def test_start_tls_wrong_args(self):
590        async def main():
591            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
592                await self.loop.start_tls(None, None, None)
593
594            sslctx = test_utils.simple_server_sslcontext()
595            with self.assertRaisesRegex(TypeError, 'is not supported'):
596                await self.loop.start_tls(None, None, sslctx)
597
598        self.loop.run_until_complete(main())
599
    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect every context passed to the loop's exception handler.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            # Server handler passed to self.tcp_server(): it only reads and
            # never starts TLS; it records whether the read was aborted.
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5s wait_for() is given up long before the
            # ssl_handshake_timeout passed to create_connection().
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The pending handshake timeout (support.SHORT_TIMEOUT) should be
        # cancelled to free related objects without really waiting for it to
        # expire.  The name is deliberately rebound to the weak reference:
        # client() closes over this variable, so rebinding drops the strong
        # references held by both this frame and the closure.
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())
647
648    def test_create_connection_ssl_slow_handshake(self):
649        client_sslctx = test_utils.simple_client_sslcontext()
650
651        messages = []
652        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
653
654        def server(sock):
655            try:
656                sock.recv_all(1024 * 1024)
657            except ConnectionAbortedError:
658                pass
659            finally:
660                sock.close()
661
662        async def client(addr):
663            reader, writer = await asyncio.open_connection(
664                *addr,
665                ssl=client_sslctx,
666                server_hostname='',
667                ssl_handshake_timeout=1.0)
668
669        with self.tcp_server(server,
670                             max_clients=1,
671                             backlog=1) as srv:
672
673            with self.assertRaisesRegex(
674                    ConnectionAbortedError,
675                    r'SSL handshake.*is taking longer'):
676
677                self.loop.run_until_complete(client(srv.addr))
678
679        self.assertEqual(messages, [])
680
681    def test_create_connection_ssl_failed_certificate(self):
682        self.loop.set_exception_handler(lambda loop, ctx: None)
683
684        sslctx = test_utils.simple_server_sslcontext()
685        client_sslctx = test_utils.simple_client_sslcontext(
686            disable_verify=False)
687
688        def server(sock):
689            try:
690                sock.start_tls(
691                    sslctx,
692                    server_side=True)
693            except ssl.SSLError:
694                pass
695            except OSError:
696                pass
697            finally:
698                sock.close()
699
700        async def client(addr):
701            reader, writer = await asyncio.open_connection(
702                *addr,
703                ssl=client_sslctx,
704                server_hostname='',
705                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
706
707        with self.tcp_server(server,
708                             max_clients=1,
709                             backlog=1) as srv:
710
711            with self.assertRaises(ssl.SSLCertVerificationError):
712                self.loop.run_until_complete(client(srv.addr))
713
714    def test_start_tls_client_corrupted_ssl(self):
715        self.loop.set_exception_handler(lambda loop, ctx: None)
716
717        sslctx = test_utils.simple_server_sslcontext()
718        client_sslctx = test_utils.simple_client_sslcontext()
719
720        def server(sock):
721            orig_sock = sock.dup()
722            try:
723                sock.start_tls(
724                    sslctx,
725                    server_side=True)
726                sock.sendall(b'A\n')
727                sock.recv_all(1)
728                orig_sock.send(b'please corrupt the SSL connection')
729            except ssl.SSLError:
730                pass
731            finally:
732                orig_sock.close()
733                sock.close()
734
735        async def client(addr):
736            reader, writer = await asyncio.open_connection(
737                *addr,
738                ssl=client_sslctx,
739                server_hostname='')
740
741            self.assertEqual(await reader.readline(), b'A\n')
742            writer.write(b'B')
743            with self.assertRaises(ssl.SSLError):
744                await reader.readline()
745
746            writer.close()
747            return 'OK'
748
749        with self.tcp_server(server,
750                             max_clients=1,
751                             backlog=1) as srv:
752
753            res = self.loop.run_until_complete(client(srv.addr))
754
755        self.assertEqual(res, 'OK')
756
757
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS tests on asyncio.SelectorEventLoop."""

    def new_loop(self):
        # Overrides BaseStartTLS.new_loop(), which raises NotImplementedError.
        return asyncio.SelectorEventLoop()
763
764
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS tests on asyncio.ProactorEventLoop.

    Skipped when asyncio does not provide ProactorEventLoop.
    """

    def new_loop(self):
        # Overrides BaseStartTLS.new_loop(), which raises NotImplementedError.
        return asyncio.ProactorEventLoop()
771
772
# Run the tests with unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
775