• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for events.py."""
2
3import collections.abc
4import concurrent.futures
5import functools
6import io
7import os
8import platform
9import re
10import signal
11import socket
12try:
13    import ssl
14except ImportError:
15    ssl = None
16import subprocess
17import sys
18import threading
19import time
20import errno
21import unittest
22from unittest import mock
23import weakref
24
25if sys.platform not in ('win32', 'vxworks'):
26    import tty
27
28import asyncio
29from asyncio import coroutines
30from asyncio import events
31from asyncio import proactor_events
32from asyncio import selector_events
33from test.test_asyncio import utils as test_utils
34from test import support
35from test.support import socket_helper
36from test.support import threading_helper
37from test.support import ALWAYS_EQ, LARGEST, SMALLEST
38
39
def tearDownModule():
    """Reset the asyncio event loop policy after this module's tests."""
    # Passing None asks asyncio to go back to its default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
42
43
def broken_unix_getsockname():
    """Return True on AIX and on Mac OS 10.4 or older, False elsewhere."""
    if sys.platform.startswith("aix"):
        return True
    if sys.platform == 'darwin':
        # Compare the numeric release components against 10.5.
        release = platform.mac_ver()[0]
        components = tuple(int(piece) for piece in release.split('.'))
        return components < (10, 5)
    return False
53
54
55def _test_get_event_loop_new_process__sub_proc():
56    async def doit():
57        return 'hello'
58
59    loop = asyncio.new_event_loop()
60    asyncio.set_event_loop(loop)
61    return loop.run_until_complete(doit())
62
63
class CoroLike:
    """Object exposing the coroutine method names; every method is a no-op."""

    def send(self, v):
        """Ignore *v* and return None."""
        return None

    def throw(self, *exc):
        """Ignore the exception arguments and return None."""
        return None

    def close(self):
        """Do nothing and return None."""
        return None

    def __await__(self):
        """Do nothing and return None."""
        return None
76
77
class MyBaseProto(asyncio.Protocol):
    """Stream protocol that records its lifecycle state and bytes received.

    The state moves INITIAL -> CONNECTED -> (EOF) -> CLOSED and every
    callback checks the current state before acting.  When a loop is
    given, ``connected`` and ``done`` are futures resolved by
    connection_made() and connection_lost() respectively.
    """
    # Class-level defaults, used when the protocol is built without a loop.
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = loop.create_future()
        self.done = loop.create_future()

    def _assert_state(self, *expected):
        """Raise AssertionError unless the state is one of *expected*."""
        if self.state in expected:
            return
        raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        """Store the transport, switch to CONNECTED, resolve ``connected``."""
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        connected = self.connected
        if connected:
            connected.set_result(None)

    def data_received(self, data):
        """Add the size of *data* to the byte counter."""
        self._assert_state('CONNECTED')
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        """Switch from CONNECTED to EOF."""
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        """Switch to CLOSED and resolve ``done``."""
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        done = self.done
        if done:
            done.set_result(None)
114
115
class MyProto(MyBaseProto):
    """MyBaseProto variant that writes an HTTP/1.0 GET request on connect."""

    def connection_made(self, transport):
        super().connection_made(transport)
        request = b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n'
        transport.write(request)
120
121
class MyDatagramProto(asyncio.DatagramProtocol):
    """Datagram protocol recording its state and the bytes it received.

    The state goes INITIAL -> INITIALIZED -> CLOSED.  When a loop is
    given, ``done`` is a future resolved by connection_lost().
    """
    # Stays None when the protocol is built without a loop.
    done = None

    def __init__(self, loop=None):
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.done = loop.create_future()

    def _assert_state(self, expected):
        """Raise AssertionError unless the state equals *expected*."""
        if self.state == expected:
            return
        raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        """Store the transport and switch to INITIALIZED."""
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'INITIALIZED'

    def datagram_received(self, data, addr):
        """Add the size of *data* to the byte counter."""
        self._assert_state('INITIALIZED')
        self.nbytes = self.nbytes + len(data)

    def error_received(self, exc):
        """Only check that the protocol is in the INITIALIZED state."""
        self._assert_state('INITIALIZED')

    def connection_lost(self, exc):
        """Switch to CLOSED and resolve ``done``."""
        self._assert_state('INITIALIZED')
        self.state = 'CLOSED'
        done = self.done
        if done:
            done.set_result(None)
152
153
class MyReadPipeProto(asyncio.Protocol):
    """Read-pipe protocol keeping the full history of its states in a list.

    ``state`` starts as ['INITIAL'] and grows with 'CONNECTED', 'EOF' and
    'CLOSED'.  When a loop is given, ``done`` is a future resolved by
    connection_lost().
    """
    # Stays None when the protocol is built without a loop.
    done = None

    def __init__(self, loop=None):
        self.state = ['INITIAL']
        self.nbytes = 0
        self.transport = None
        if loop is None:
            return
        self.done = loop.create_future()

    def _assert_state(self, expected):
        """Raise AssertionError unless the state history equals *expected*."""
        if self.state == expected:
            return
        raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        """Store the transport and record CONNECTED."""
        self.transport = transport
        self._assert_state(['INITIAL'])
        self.state.append('CONNECTED')

    def data_received(self, data):
        """Add the size of *data* to the byte counter."""
        self._assert_state(['INITIAL', 'CONNECTED'])
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        """Record EOF."""
        self._assert_state(['INITIAL', 'CONNECTED'])
        self.state.append('EOF')

    def connection_lost(self, exc):
        """Record CLOSED (adding a missed EOF first) and resolve ``done``."""
        history = self.state
        if 'EOF' not in history:
            history.append('EOF')  # It is okay if EOF is missed.
        self._assert_state(['INITIAL', 'CONNECTED', 'EOF'])
        history.append('CLOSED')
        done = self.done
        if done:
            done.set_result(None)
188
189
class MyWritePipeProto(asyncio.BaseProtocol):
    """Write-pipe protocol tracking INITIAL -> CONNECTED -> CLOSED.

    When a loop is given, ``done`` is a future resolved by
    connection_lost().
    """
    # Stays None when the protocol is built without a loop.
    done = None

    def __init__(self, loop=None):
        self.state = 'INITIAL'
        self.transport = None
        if loop is None:
            return
        self.done = loop.create_future()

    def _assert_state(self, expected):
        """Raise AssertionError unless the state equals *expected*."""
        if self.state == expected:
            return
        raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        """Store the transport and switch to CONNECTED."""
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'

    def connection_lost(self, exc):
        """Switch to CLOSED and resolve ``done``."""
        self._assert_state('CONNECTED')
        self.state = 'CLOSED'
        done = self.done
        if done:
            done.set_result(None)
213
214
class MySubprocessProtocol(asyncio.SubprocessProtocol):
    """Subprocess protocol that records pipe data, disconnects and exit code.

    ``connected`` and ``completed`` are futures resolved by
    connection_made() and connection_lost(); ``disconnects`` maps fds
    0-2 to futures settled by pipe_connection_lost(); ``data`` and
    ``got_data`` collect and signal output received on fds 1 and 2.
    """

    def __init__(self, loop):
        self.state = 'INITIAL'
        self.transport = None
        self.connected = loop.create_future()
        self.completed = loop.create_future()
        self.disconnects = {
            pipe_fd: loop.create_future() for pipe_fd in (0, 1, 2)}
        self.data = {1: b'', 2: b''}
        self.returncode = None
        self.got_data = {pipe_fd: asyncio.Event() for pipe_fd in (1, 2)}

    def _assert_state(self, expected):
        """Raise AssertionError unless the state equals *expected*."""
        if self.state == expected:
            return
        raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        """Store the transport, switch to CONNECTED, resolve ``connected``."""
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        self.connected.set_result(None)

    def connection_lost(self, exc):
        """Switch to CLOSED and resolve ``completed``."""
        self._assert_state('CONNECTED')
        self.state = 'CLOSED'
        self.completed.set_result(None)

    def pipe_data_received(self, fd, data):
        """Append *data* to the buffer of *fd* and set its event."""
        self._assert_state('CONNECTED')
        self.data[fd] = self.data[fd] + data
        self.got_data[fd].set()

    def pipe_connection_lost(self, fd, exc):
        """Settle the disconnect future of *fd* with *exc* or a result."""
        self._assert_state('CONNECTED')
        waiter = self.disconnects[fd]
        if not exc:
            waiter.set_result(exc)
        else:
            waiter.set_exception(exc)

    def process_exited(self):
        """Store the return code reported by the transport."""
        self._assert_state('CONNECTED')
        self.returncode = self.transport.get_returncode()
258
259
260class EventLoopTestsMixin:
261
262    def setUp(self):
263        super().setUp()
264        self.loop = self.create_event_loop()
265        self.set_event_loop(self.loop)
266
267    def tearDown(self):
268        # just in case if we have transport close callbacks
269        if not self.loop.is_closed():
270            test_utils.run_briefly(self.loop)
271
272        self.doCleanups()
273        support.gc_collect()
274        super().tearDown()
275
276    def test_run_until_complete_nesting(self):
277        async def coro1():
278            await asyncio.sleep(0)
279
280        async def coro2():
281            self.assertTrue(self.loop.is_running())
282            self.loop.run_until_complete(coro1())
283
284        with self.assertWarnsRegex(
285            RuntimeWarning,
286            r"coroutine \S+ was never awaited"
287        ):
288            self.assertRaises(
289                RuntimeError, self.loop.run_until_complete, coro2())
290
291    # Note: because of the default Windows timing granularity of
292    # 15.6 msec, we use fairly long sleep times here (~100 msec).
293
294    def test_run_until_complete(self):
295        t0 = self.loop.time()
296        self.loop.run_until_complete(asyncio.sleep(0.1))
297        t1 = self.loop.time()
298        self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
299
300    def test_run_until_complete_stopped(self):
301
302        async def cb():
303            self.loop.stop()
304            await asyncio.sleep(0.1)
305        task = cb()
306        self.assertRaises(RuntimeError,
307                          self.loop.run_until_complete, task)
308
309    def test_call_later(self):
310        results = []
311
312        def callback(arg):
313            results.append(arg)
314            self.loop.stop()
315
316        self.loop.call_later(0.1, callback, 'hello world')
317        self.loop.run_forever()
318        self.assertEqual(results, ['hello world'])
319
320    def test_call_soon(self):
321        results = []
322
323        def callback(arg1, arg2):
324            results.append((arg1, arg2))
325            self.loop.stop()
326
327        self.loop.call_soon(callback, 'hello', 'world')
328        self.loop.run_forever()
329        self.assertEqual(results, [('hello', 'world')])
330
331    def test_call_soon_threadsafe(self):
332        results = []
333        lock = threading.Lock()
334
335        def callback(arg):
336            results.append(arg)
337            if len(results) >= 2:
338                self.loop.stop()
339
340        def run_in_thread():
341            self.loop.call_soon_threadsafe(callback, 'hello')
342            lock.release()
343
344        lock.acquire()
345        t = threading.Thread(target=run_in_thread)
346        t.start()
347
348        with lock:
349            self.loop.call_soon(callback, 'world')
350            self.loop.run_forever()
351        t.join()
352        self.assertEqual(results, ['hello', 'world'])
353
354    def test_call_soon_threadsafe_same_thread(self):
355        results = []
356
357        def callback(arg):
358            results.append(arg)
359            if len(results) >= 2:
360                self.loop.stop()
361
362        self.loop.call_soon_threadsafe(callback, 'hello')
363        self.loop.call_soon(callback, 'world')
364        self.loop.run_forever()
365        self.assertEqual(results, ['hello', 'world'])
366
367    def test_run_in_executor(self):
368        def run(arg):
369            return (arg, threading.get_ident())
370        f2 = self.loop.run_in_executor(None, run, 'yo')
371        res, thread_id = self.loop.run_until_complete(f2)
372        self.assertEqual(res, 'yo')
373        self.assertNotEqual(thread_id, threading.get_ident())
374
375    def test_run_in_executor_cancel(self):
376        called = False
377
378        def patched_call_soon(*args):
379            nonlocal called
380            called = True
381
382        def run():
383            time.sleep(0.05)
384
385        f2 = self.loop.run_in_executor(None, run)
386        f2.cancel()
387        self.loop.run_until_complete(
388                self.loop.shutdown_default_executor())
389        self.loop.close()
390        self.loop.call_soon = patched_call_soon
391        self.loop.call_soon_threadsafe = patched_call_soon
392        time.sleep(0.4)
393        self.assertFalse(called)
394
395    def test_reader_callback(self):
396        r, w = socket.socketpair()
397        r.setblocking(False)
398        bytes_read = bytearray()
399
400        def reader():
401            try:
402                data = r.recv(1024)
403            except BlockingIOError:
404                # Spurious readiness notifications are possible
405                # at least on Linux -- see man select.
406                return
407            if data:
408                bytes_read.extend(data)
409            else:
410                self.assertTrue(self.loop.remove_reader(r.fileno()))
411                r.close()
412
413        self.loop.add_reader(r.fileno(), reader)
414        self.loop.call_soon(w.send, b'abc')
415        test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3)
416        self.loop.call_soon(w.send, b'def')
417        test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6)
418        self.loop.call_soon(w.close)
419        self.loop.call_soon(self.loop.stop)
420        self.loop.run_forever()
421        self.assertEqual(bytes_read, b'abcdef')
422
423    def test_writer_callback(self):
424        r, w = socket.socketpair()
425        w.setblocking(False)
426
427        def writer(data):
428            w.send(data)
429            self.loop.stop()
430
431        data = b'x' * 1024
432        self.loop.add_writer(w.fileno(), writer, data)
433        self.loop.run_forever()
434
435        self.assertTrue(self.loop.remove_writer(w.fileno()))
436        self.assertFalse(self.loop.remove_writer(w.fileno()))
437
438        w.close()
439        read = r.recv(len(data) * 2)
440        r.close()
441        self.assertEqual(read, data)
442
443    @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
444    def test_add_signal_handler(self):
445        caught = 0
446
447        def my_handler():
448            nonlocal caught
449            caught += 1
450
451        # Check error behavior first.
452        self.assertRaises(
453            TypeError, self.loop.add_signal_handler, 'boom', my_handler)
454        self.assertRaises(
455            TypeError, self.loop.remove_signal_handler, 'boom')
456        self.assertRaises(
457            ValueError, self.loop.add_signal_handler, signal.NSIG+1,
458            my_handler)
459        self.assertRaises(
460            ValueError, self.loop.remove_signal_handler, signal.NSIG+1)
461        self.assertRaises(
462            ValueError, self.loop.add_signal_handler, 0, my_handler)
463        self.assertRaises(
464            ValueError, self.loop.remove_signal_handler, 0)
465        self.assertRaises(
466            ValueError, self.loop.add_signal_handler, -1, my_handler)
467        self.assertRaises(
468            ValueError, self.loop.remove_signal_handler, -1)
469        self.assertRaises(
470            RuntimeError, self.loop.add_signal_handler, signal.SIGKILL,
471            my_handler)
472        # Removing SIGKILL doesn't raise, since we don't call signal().
473        self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL))
474        # Now set a handler and handle it.
475        self.loop.add_signal_handler(signal.SIGINT, my_handler)
476
477        os.kill(os.getpid(), signal.SIGINT)
478        test_utils.run_until(self.loop, lambda: caught)
479
480        # Removing it should restore the default handler.
481        self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT))
482        self.assertEqual(signal.getsignal(signal.SIGINT),
483                         signal.default_int_handler)
484        # Removing again returns False.
485        self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT))
486
487    @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
488    @unittest.skipUnless(hasattr(signal, 'setitimer'),
489                         'need signal.setitimer()')
490    def test_signal_handling_while_selecting(self):
491        # Test with a signal actually arriving during a select() call.
492        caught = 0
493
494        def my_handler():
495            nonlocal caught
496            caught += 1
497            self.loop.stop()
498
499        self.loop.add_signal_handler(signal.SIGALRM, my_handler)
500
501        signal.setitimer(signal.ITIMER_REAL, 0.01, 0)  # Send SIGALRM once.
502        self.loop.call_later(60, self.loop.stop)
503        self.loop.run_forever()
504        self.assertEqual(caught, 1)
505
506    @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
507    @unittest.skipUnless(hasattr(signal, 'setitimer'),
508                         'need signal.setitimer()')
509    def test_signal_handling_args(self):
510        some_args = (42,)
511        caught = 0
512
513        def my_handler(*args):
514            nonlocal caught
515            caught += 1
516            self.assertEqual(args, some_args)
517            self.loop.stop()
518
519        self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args)
520
521        signal.setitimer(signal.ITIMER_REAL, 0.1, 0)  # Send SIGALRM once.
522        self.loop.call_later(60, self.loop.stop)
523        self.loop.run_forever()
524        self.assertEqual(caught, 1)
525
526    def _basetest_create_connection(self, connection_fut, check_sockname=True):
527        tr, pr = self.loop.run_until_complete(connection_fut)
528        self.assertIsInstance(tr, asyncio.Transport)
529        self.assertIsInstance(pr, asyncio.Protocol)
530        self.assertIs(pr.transport, tr)
531        if check_sockname:
532            self.assertIsNotNone(tr.get_extra_info('sockname'))
533        self.loop.run_until_complete(pr.done)
534        self.assertGreater(pr.nbytes, 0)
535        tr.close()
536
537    def test_create_connection(self):
538        with test_utils.run_test_server() as httpd:
539            conn_fut = self.loop.create_connection(
540                lambda: MyProto(loop=self.loop), *httpd.address)
541            self._basetest_create_connection(conn_fut)
542
543    @socket_helper.skip_unless_bind_unix_socket
544    def test_create_unix_connection(self):
545        # Issue #20682: On Mac OS X Tiger, getsockname() returns a
546        # zero-length address for UNIX socket.
547        check_sockname = not broken_unix_getsockname()
548
549        with test_utils.run_test_unix_server() as httpd:
550            conn_fut = self.loop.create_unix_connection(
551                lambda: MyProto(loop=self.loop), httpd.address)
552            self._basetest_create_connection(conn_fut, check_sockname)
553
554    def check_ssl_extra_info(self, client, check_sockname=True,
555                             peername=None, peercert={}):
556        if check_sockname:
557            self.assertIsNotNone(client.get_extra_info('sockname'))
558        if peername:
559            self.assertEqual(peername,
560                             client.get_extra_info('peername'))
561        else:
562            self.assertIsNotNone(client.get_extra_info('peername'))
563        self.assertEqual(peercert,
564                         client.get_extra_info('peercert'))
565
566        # test SSL cipher
567        cipher = client.get_extra_info('cipher')
568        self.assertIsInstance(cipher, tuple)
569        self.assertEqual(len(cipher), 3, cipher)
570        self.assertIsInstance(cipher[0], str)
571        self.assertIsInstance(cipher[1], str)
572        self.assertIsInstance(cipher[2], int)
573
574        # test SSL object
575        sslobj = client.get_extra_info('ssl_object')
576        self.assertIsNotNone(sslobj)
577        self.assertEqual(sslobj.compression(),
578                         client.get_extra_info('compression'))
579        self.assertEqual(sslobj.cipher(),
580                         client.get_extra_info('cipher'))
581        self.assertEqual(sslobj.getpeercert(),
582                         client.get_extra_info('peercert'))
583        self.assertEqual(sslobj.compression(),
584                         client.get_extra_info('compression'))
585
586    def _basetest_create_ssl_connection(self, connection_fut,
587                                        check_sockname=True,
588                                        peername=None):
589        tr, pr = self.loop.run_until_complete(connection_fut)
590        self.assertIsInstance(tr, asyncio.Transport)
591        self.assertIsInstance(pr, asyncio.Protocol)
592        self.assertTrue('ssl' in tr.__class__.__name__.lower())
593        self.check_ssl_extra_info(tr, check_sockname, peername)
594        self.loop.run_until_complete(pr.done)
595        self.assertGreater(pr.nbytes, 0)
596        tr.close()
597
598    def _test_create_ssl_connection(self, httpd, create_connection,
599                                    check_sockname=True, peername=None):
600        conn_fut = create_connection(ssl=test_utils.dummy_ssl_context())
601        self._basetest_create_ssl_connection(conn_fut, check_sockname,
602                                             peername)
603
604        # ssl.Purpose was introduced in Python 3.4
605        if hasattr(ssl, 'Purpose'):
606            def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *,
607                                          cafile=None, capath=None,
608                                          cadata=None):
609                """
610                A ssl.create_default_context() replacement that doesn't enable
611                cert validation.
612                """
613                self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH)
614                return test_utils.dummy_ssl_context()
615
616            # With ssl=True, ssl.create_default_context() should be called
617            with mock.patch('ssl.create_default_context',
618                            side_effect=_dummy_ssl_create_context) as m:
619                conn_fut = create_connection(ssl=True)
620                self._basetest_create_ssl_connection(conn_fut, check_sockname,
621                                                     peername)
622                self.assertEqual(m.call_count, 1)
623
624        # With the real ssl.create_default_context(), certificate
625        # validation will fail
626        with self.assertRaises(ssl.SSLError) as cm:
627            conn_fut = create_connection(ssl=True)
628            # Ignore the "SSL handshake failed" log in debug mode
629            with test_utils.disable_logger():
630                self._basetest_create_ssl_connection(conn_fut, check_sockname,
631                                                     peername)
632
633        self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED')
634
635    @unittest.skipIf(ssl is None, 'No ssl module')
636    def test_create_ssl_connection(self):
637        with test_utils.run_test_server(use_ssl=True) as httpd:
638            create_connection = functools.partial(
639                self.loop.create_connection,
640                lambda: MyProto(loop=self.loop),
641                *httpd.address)
642            self._test_create_ssl_connection(httpd, create_connection,
643                                             peername=httpd.address)
644
645    @socket_helper.skip_unless_bind_unix_socket
646    @unittest.skipIf(ssl is None, 'No ssl module')
647    def test_create_ssl_unix_connection(self):
648        # Issue #20682: On Mac OS X Tiger, getsockname() returns a
649        # zero-length address for UNIX socket.
650        check_sockname = not broken_unix_getsockname()
651
652        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
653            create_connection = functools.partial(
654                self.loop.create_unix_connection,
655                lambda: MyProto(loop=self.loop), httpd.address,
656                server_hostname='127.0.0.1')
657
658            self._test_create_ssl_connection(httpd, create_connection,
659                                             check_sockname,
660                                             peername=httpd.address)
661
662    def test_create_connection_local_addr(self):
663        with test_utils.run_test_server() as httpd:
664            port = socket_helper.find_unused_port()
665            f = self.loop.create_connection(
666                lambda: MyProto(loop=self.loop),
667                *httpd.address, local_addr=(httpd.address[0], port))
668            tr, pr = self.loop.run_until_complete(f)
669            expected = pr.transport.get_extra_info('sockname')[1]
670            self.assertEqual(port, expected)
671            tr.close()
672
673    def test_create_connection_local_addr_in_use(self):
674        with test_utils.run_test_server() as httpd:
675            f = self.loop.create_connection(
676                lambda: MyProto(loop=self.loop),
677                *httpd.address, local_addr=httpd.address)
678            with self.assertRaises(OSError) as cm:
679                self.loop.run_until_complete(f)
680            self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
681            self.assertIn(str(httpd.address), cm.exception.strerror)
682
    def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
        """Check connect_accepted_socket() on a socket accepted by hand.

        A client thread connects to a local listening socket, sends a
        message and reads the reply; the accepted socket is handed to the
        loop together with a protocol that answers any received data.

        Args:
            server_ssl: value passed as ``ssl=`` to
                connect_accepted_socket().
            client_ssl: when not None, its ``wrap_socket()`` wraps the
                client socket before connecting.
        """
        loop = self.loop

        class MyProto(MyBaseProto):

            def connection_lost(self, exc):
                super().connection_lost(exc)
                # Ends the run_forever() call below.
                loop.call_soon(loop.stop)

            def data_received(self, data):
                super().data_received(data)
                # Reply to whatever was received.
                self.transport.write(expected_response)

        # Port 0: the actual address is read back with getsockname().
        lsock = socket.create_server(('127.0.0.1', 0), backlog=1)
        addr = lsock.getsockname()

        message = b'test data'
        response = None
        expected_response = b'roger'

        def client():
            # Runs in a separate thread; stores the reply in ``response``.
            nonlocal response
            try:
                csock = socket.socket()
                if client_ssl is not None:
                    csock = client_ssl.wrap_socket(csock)
                csock.connect(addr)
                csock.sendall(message)
                response = csock.recv(99)
                csock.close()
            except Exception as exc:
                # Failures are only printed; the final assertions on
                # ``response`` are what make the test fail.
                print(
                    "Failure in client thread in test_connect_accepted_socket",
                    exc)

        thread = threading.Thread(target=client, daemon=True)
        thread.start()

        # Blocking accept of the connection made by the client thread.
        conn, _ = lsock.accept()
        proto = MyProto(loop=loop)
        proto.loop = loop
        loop.run_until_complete(
            loop.connect_accepted_socket(
                (lambda: proto), conn, ssl=server_ssl))
        # Runs until MyProto.connection_lost() schedules loop.stop().
        loop.run_forever()
        proto.transport.close()
        lsock.close()

        threading_helper.join_thread(thread)
        self.assertFalse(thread.is_alive())
        self.assertEqual(proto.state, 'CLOSED')
        self.assertEqual(proto.nbytes, len(message))
        self.assertEqual(response, expected_response)
736
737    @unittest.skipIf(ssl is None, 'No ssl module')
738    def test_ssl_connect_accepted_socket(self):
739        if (sys.platform == 'win32' and
740            sys.version_info < (3, 5) and
741            isinstance(self.loop, proactor_events.BaseProactorEventLoop)
742            ):
743            raise unittest.SkipTest(
744                'SSL not supported with proactor event loops before Python 3.5'
745                )
746
747        server_context = test_utils.simple_server_sslcontext()
748        client_context = test_utils.simple_client_sslcontext()
749
750        self.test_connect_accepted_socket(server_context, client_context)
751
752    def test_connect_accepted_socket_ssl_timeout_for_plain_socket(self):
753        sock = socket.socket()
754        self.addCleanup(sock.close)
755        coro = self.loop.connect_accepted_socket(
756            MyProto, sock, ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
757        with self.assertRaisesRegex(
758                ValueError,
759                'ssl_handshake_timeout is only meaningful with ssl'):
760            self.loop.run_until_complete(coro)
761
762    @mock.patch('asyncio.base_events.socket')
763    def create_server_multiple_hosts(self, family, hosts, mock_sock):
764        async def getaddrinfo(host, port, *args, **kw):
765            if family == socket.AF_INET:
766                return [(family, socket.SOCK_STREAM, 6, '', (host, port))]
767            else:
768                return [(family, socket.SOCK_STREAM, 6, '', (host, port, 0, 0))]
769
770        def getaddrinfo_task(*args, **kwds):
771            return self.loop.create_task(getaddrinfo(*args, **kwds))
772
773        unique_hosts = set(hosts)
774
775        if family == socket.AF_INET:
776            mock_sock.socket().getsockbyname.side_effect = [
777                (host, 80) for host in unique_hosts]
778        else:
779            mock_sock.socket().getsockbyname.side_effect = [
780                (host, 80, 0, 0) for host in unique_hosts]
781        self.loop.getaddrinfo = getaddrinfo_task
782        self.loop._start_serving = mock.Mock()
783        self.loop._stop_serving = mock.Mock()
784        f = self.loop.create_server(lambda: MyProto(self.loop), hosts, 80)
785        server = self.loop.run_until_complete(f)
786        self.addCleanup(server.close)
787        server_hosts = {sock.getsockbyname()[0] for sock in server.sockets}
788        self.assertEqual(server_hosts, unique_hosts)
789
790    def test_create_server_multiple_hosts_ipv4(self):
791        self.create_server_multiple_hosts(socket.AF_INET,
792                                          ['1.2.3.4', '5.6.7.8', '1.2.3.4'])
793
794    def test_create_server_multiple_hosts_ipv6(self):
795        self.create_server_multiple_hosts(socket.AF_INET6,
796                                          ['::1', '::2', '::1'])
797
798    def test_create_server(self):
799        proto = MyProto(self.loop)
800        f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
801        server = self.loop.run_until_complete(f)
802        self.assertEqual(len(server.sockets), 1)
803        sock = server.sockets[0]
804        host, port = sock.getsockname()
805        self.assertEqual(host, '0.0.0.0')
806        client = socket.socket()
807        client.connect(('127.0.0.1', port))
808        client.sendall(b'xxx')
809
810        self.loop.run_until_complete(proto.connected)
811        self.assertEqual('CONNECTED', proto.state)
812
813        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
814        self.assertEqual(3, proto.nbytes)
815
816        # extra info is available
817        self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
818        self.assertEqual('127.0.0.1',
819                         proto.transport.get_extra_info('peername')[0])
820
821        # close connection
822        proto.transport.close()
823        self.loop.run_until_complete(proto.done)
824
825        self.assertEqual('CLOSED', proto.state)
826
827        # the client socket must be closed after to avoid ECONNRESET upon
828        # recv()/send() on the serving socket
829        client.close()
830
831        # close server
832        server.close()
833
834    @unittest.skipUnless(hasattr(socket, 'SO_REUSEPORT'), 'No SO_REUSEPORT')
835    def test_create_server_reuse_port(self):
836        proto = MyProto(self.loop)
837        f = self.loop.create_server(
838            lambda: proto, '0.0.0.0', 0)
839        server = self.loop.run_until_complete(f)
840        self.assertEqual(len(server.sockets), 1)
841        sock = server.sockets[0]
842        self.assertFalse(
843            sock.getsockopt(
844                socket.SOL_SOCKET, socket.SO_REUSEPORT))
845        server.close()
846
847        test_utils.run_briefly(self.loop)
848
849        proto = MyProto(self.loop)
850        f = self.loop.create_server(
851            lambda: proto, '0.0.0.0', 0, reuse_port=True)
852        server = self.loop.run_until_complete(f)
853        self.assertEqual(len(server.sockets), 1)
854        sock = server.sockets[0]
855        self.assertTrue(
856            sock.getsockopt(
857                socket.SOL_SOCKET, socket.SO_REUSEPORT))
858        server.close()
859
860    def _make_unix_server(self, factory, **kwargs):
861        path = test_utils.gen_unix_socket_path()
862        self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
863
864        f = self.loop.create_unix_server(factory, path, **kwargs)
865        server = self.loop.run_until_complete(f)
866
867        return server, path
868
869    @socket_helper.skip_unless_bind_unix_socket
870    def test_create_unix_server(self):
871        proto = MyProto(loop=self.loop)
872        server, path = self._make_unix_server(lambda: proto)
873        self.assertEqual(len(server.sockets), 1)
874
875        client = socket.socket(socket.AF_UNIX)
876        client.connect(path)
877        client.sendall(b'xxx')
878
879        self.loop.run_until_complete(proto.connected)
880        self.assertEqual('CONNECTED', proto.state)
881        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
882        self.assertEqual(3, proto.nbytes)
883
884        # close connection
885        proto.transport.close()
886        self.loop.run_until_complete(proto.done)
887
888        self.assertEqual('CLOSED', proto.state)
889
890        # the client socket must be closed after to avoid ECONNRESET upon
891        # recv()/send() on the serving socket
892        client.close()
893
894        # close server
895        server.close()
896
897    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
898    def test_create_unix_server_path_socket_error(self):
899        proto = MyProto(loop=self.loop)
900        sock = socket.socket()
901        with sock:
902            f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock)
903            with self.assertRaisesRegex(ValueError,
904                                        'path and sock can not be specified '
905                                        'at the same time'):
906                self.loop.run_until_complete(f)
907
908    def _create_ssl_context(self, certfile, keyfile=None):
909        sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
910        sslcontext.options |= ssl.OP_NO_SSLv2
911        sslcontext.load_cert_chain(certfile, keyfile)
912        return sslcontext
913
914    def _make_ssl_server(self, factory, certfile, keyfile=None):
915        sslcontext = self._create_ssl_context(certfile, keyfile)
916
917        f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
918        server = self.loop.run_until_complete(f)
919
920        sock = server.sockets[0]
921        host, port = sock.getsockname()
922        self.assertEqual(host, '127.0.0.1')
923        return server, host, port
924
925    def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
926        sslcontext = self._create_ssl_context(certfile, keyfile)
927        return self._make_unix_server(factory, ssl=sslcontext)
928
929    @unittest.skipIf(ssl is None, 'No ssl module')
930    def test_create_server_ssl(self):
931        proto = MyProto(loop=self.loop)
932        server, host, port = self._make_ssl_server(
933            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)
934
935        f_c = self.loop.create_connection(MyBaseProto, host, port,
936                                          ssl=test_utils.dummy_ssl_context())
937        client, pr = self.loop.run_until_complete(f_c)
938
939        client.write(b'xxx')
940        self.loop.run_until_complete(proto.connected)
941        self.assertEqual('CONNECTED', proto.state)
942
943        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
944        self.assertEqual(3, proto.nbytes)
945
946        # extra info is available
947        self.check_ssl_extra_info(client, peername=(host, port))
948
949        # close connection
950        proto.transport.close()
951        self.loop.run_until_complete(proto.done)
952        self.assertEqual('CLOSED', proto.state)
953
954        # the client socket must be closed after to avoid ECONNRESET upon
955        # recv()/send() on the serving socket
956        client.close()
957
958        # stop serving
959        server.close()
960
961    @socket_helper.skip_unless_bind_unix_socket
962    @unittest.skipIf(ssl is None, 'No ssl module')
963    def test_create_unix_server_ssl(self):
964        proto = MyProto(loop=self.loop)
965        server, path = self._make_ssl_unix_server(
966            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)
967
968        f_c = self.loop.create_unix_connection(
969            MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
970            server_hostname='')
971
972        client, pr = self.loop.run_until_complete(f_c)
973
974        client.write(b'xxx')
975        self.loop.run_until_complete(proto.connected)
976        self.assertEqual('CONNECTED', proto.state)
977        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
978        self.assertEqual(3, proto.nbytes)
979
980        # close connection
981        proto.transport.close()
982        self.loop.run_until_complete(proto.done)
983        self.assertEqual('CLOSED', proto.state)
984
985        # the client socket must be closed after to avoid ECONNRESET upon
986        # recv()/send() on the serving socket
987        client.close()
988
989        # stop serving
990        server.close()
991
992    @unittest.skipIf(ssl is None, 'No ssl module')
993    def test_create_server_ssl_verify_failed(self):
994        proto = MyProto(loop=self.loop)
995        server, host, port = self._make_ssl_server(
996            lambda: proto, test_utils.SIGNED_CERTFILE)
997
998        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
999        sslcontext_client.options |= ssl.OP_NO_SSLv2
1000        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
1001        if hasattr(sslcontext_client, 'check_hostname'):
1002            sslcontext_client.check_hostname = True
1003
1004
1005        # no CA loaded
1006        f_c = self.loop.create_connection(MyProto, host, port,
1007                                          ssl=sslcontext_client)
1008        with mock.patch.object(self.loop, 'call_exception_handler'):
1009            with test_utils.disable_logger():
1010                with self.assertRaisesRegex(ssl.SSLError,
1011                                            '(?i)certificate.verify.failed'):
1012                    self.loop.run_until_complete(f_c)
1013
1014            # execute the loop to log the connection error
1015            test_utils.run_briefly(self.loop)
1016
1017        # close connection
1018        self.assertIsNone(proto.transport)
1019        server.close()
1020
1021    @socket_helper.skip_unless_bind_unix_socket
1022    @unittest.skipIf(ssl is None, 'No ssl module')
1023    def test_create_unix_server_ssl_verify_failed(self):
1024        proto = MyProto(loop=self.loop)
1025        server, path = self._make_ssl_unix_server(
1026            lambda: proto, test_utils.SIGNED_CERTFILE)
1027
1028        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1029        sslcontext_client.options |= ssl.OP_NO_SSLv2
1030        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
1031        if hasattr(sslcontext_client, 'check_hostname'):
1032            sslcontext_client.check_hostname = True
1033
1034        # no CA loaded
1035        f_c = self.loop.create_unix_connection(MyProto, path,
1036                                               ssl=sslcontext_client,
1037                                               server_hostname='invalid')
1038        with mock.patch.object(self.loop, 'call_exception_handler'):
1039            with test_utils.disable_logger():
1040                with self.assertRaisesRegex(ssl.SSLError,
1041                                            '(?i)certificate.verify.failed'):
1042                    self.loop.run_until_complete(f_c)
1043
1044            # execute the loop to log the connection error
1045            test_utils.run_briefly(self.loop)
1046
1047        # close connection
1048        self.assertIsNone(proto.transport)
1049        server.close()
1050
1051    @unittest.skipIf(ssl is None, 'No ssl module')
1052    def test_create_server_ssl_match_failed(self):
1053        proto = MyProto(loop=self.loop)
1054        server, host, port = self._make_ssl_server(
1055            lambda: proto, test_utils.SIGNED_CERTFILE)
1056
1057        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1058        sslcontext_client.options |= ssl.OP_NO_SSLv2
1059        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
1060        sslcontext_client.load_verify_locations(
1061            cafile=test_utils.SIGNING_CA)
1062        if hasattr(sslcontext_client, 'check_hostname'):
1063            sslcontext_client.check_hostname = True
1064
1065        # incorrect server_hostname
1066        f_c = self.loop.create_connection(MyProto, host, port,
1067                                          ssl=sslcontext_client)
1068        with mock.patch.object(self.loop, 'call_exception_handler'):
1069            with test_utils.disable_logger():
1070                with self.assertRaisesRegex(
1071                        ssl.CertificateError,
1072                        "IP address mismatch, certificate is not valid for "
1073                        "'127.0.0.1'"):
1074                    self.loop.run_until_complete(f_c)
1075
1076        # close connection
1077        # transport is None because TLS ALERT aborted the handshake
1078        self.assertIsNone(proto.transport)
1079        server.close()
1080
1081    @socket_helper.skip_unless_bind_unix_socket
1082    @unittest.skipIf(ssl is None, 'No ssl module')
1083    def test_create_unix_server_ssl_verified(self):
1084        proto = MyProto(loop=self.loop)
1085        server, path = self._make_ssl_unix_server(
1086            lambda: proto, test_utils.SIGNED_CERTFILE)
1087
1088        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1089        sslcontext_client.options |= ssl.OP_NO_SSLv2
1090        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
1091        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
1092        if hasattr(sslcontext_client, 'check_hostname'):
1093            sslcontext_client.check_hostname = True
1094
1095        # Connection succeeds with correct CA and server hostname.
1096        f_c = self.loop.create_unix_connection(MyProto, path,
1097                                               ssl=sslcontext_client,
1098                                               server_hostname='localhost')
1099        client, pr = self.loop.run_until_complete(f_c)
1100        self.loop.run_until_complete(proto.connected)
1101
1102        # close connection
1103        proto.transport.close()
1104        client.close()
1105        server.close()
1106        self.loop.run_until_complete(proto.done)
1107
1108    @unittest.skipIf(ssl is None, 'No ssl module')
1109    def test_create_server_ssl_verified(self):
1110        proto = MyProto(loop=self.loop)
1111        server, host, port = self._make_ssl_server(
1112            lambda: proto, test_utils.SIGNED_CERTFILE)
1113
1114        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1115        sslcontext_client.options |= ssl.OP_NO_SSLv2
1116        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
1117        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
1118        if hasattr(sslcontext_client, 'check_hostname'):
1119            sslcontext_client.check_hostname = True
1120
1121        # Connection succeeds with correct CA and server hostname.
1122        f_c = self.loop.create_connection(MyProto, host, port,
1123                                          ssl=sslcontext_client,
1124                                          server_hostname='localhost')
1125        client, pr = self.loop.run_until_complete(f_c)
1126        self.loop.run_until_complete(proto.connected)
1127
1128        # extra info is available
1129        self.check_ssl_extra_info(client, peername=(host, port),
1130                                  peercert=test_utils.PEERCERT)
1131
1132        # close connection
1133        proto.transport.close()
1134        client.close()
1135        server.close()
1136        self.loop.run_until_complete(proto.done)
1137
1138    def test_create_server_sock(self):
1139        proto = self.loop.create_future()
1140
1141        class TestMyProto(MyProto):
1142            def connection_made(self, transport):
1143                super().connection_made(transport)
1144                proto.set_result(self)
1145
1146        sock_ob = socket.create_server(('0.0.0.0', 0))
1147
1148        f = self.loop.create_server(TestMyProto, sock=sock_ob)
1149        server = self.loop.run_until_complete(f)
1150        sock = server.sockets[0]
1151        self.assertEqual(sock.fileno(), sock_ob.fileno())
1152
1153        host, port = sock.getsockname()
1154        self.assertEqual(host, '0.0.0.0')
1155        client = socket.socket()
1156        client.connect(('127.0.0.1', port))
1157        client.send(b'xxx')
1158        client.close()
1159        server.close()
1160
1161    def test_create_server_addr_in_use(self):
1162        sock_ob = socket.create_server(('0.0.0.0', 0))
1163
1164        f = self.loop.create_server(MyProto, sock=sock_ob)
1165        server = self.loop.run_until_complete(f)
1166        sock = server.sockets[0]
1167        host, port = sock.getsockname()
1168
1169        f = self.loop.create_server(MyProto, host=host, port=port)
1170        with self.assertRaises(OSError) as cm:
1171            self.loop.run_until_complete(f)
1172        self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
1173
1174        server.close()
1175
1176    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled')
1177    def test_create_server_dual_stack(self):
1178        f_proto = self.loop.create_future()
1179
1180        class TestMyProto(MyProto):
1181            def connection_made(self, transport):
1182                super().connection_made(transport)
1183                f_proto.set_result(self)
1184
1185        try_count = 0
1186        while True:
1187            try:
1188                port = socket_helper.find_unused_port()
1189                f = self.loop.create_server(TestMyProto, host=None, port=port)
1190                server = self.loop.run_until_complete(f)
1191            except OSError as ex:
1192                if ex.errno == errno.EADDRINUSE:
1193                    try_count += 1
1194                    self.assertGreaterEqual(5, try_count)
1195                    continue
1196                else:
1197                    raise
1198            else:
1199                break
1200        client = socket.socket()
1201        client.connect(('127.0.0.1', port))
1202        client.send(b'xxx')
1203        proto = self.loop.run_until_complete(f_proto)
1204        proto.transport.close()
1205        client.close()
1206
1207        f_proto = self.loop.create_future()
1208        client = socket.socket(socket.AF_INET6)
1209        client.connect(('::1', port))
1210        client.send(b'xxx')
1211        proto = self.loop.run_until_complete(f_proto)
1212        proto.transport.close()
1213        client.close()
1214
1215        server.close()
1216
1217    def test_server_close(self):
1218        f = self.loop.create_server(MyProto, '0.0.0.0', 0)
1219        server = self.loop.run_until_complete(f)
1220        sock = server.sockets[0]
1221        host, port = sock.getsockname()
1222
1223        client = socket.socket()
1224        client.connect(('127.0.0.1', port))
1225        client.send(b'xxx')
1226        client.close()
1227
1228        server.close()
1229
1230        client = socket.socket()
1231        self.assertRaises(
1232            ConnectionRefusedError, client.connect, ('127.0.0.1', port))
1233        client.close()
1234
1235    def _test_create_datagram_endpoint(self, local_addr, family):
1236        class TestMyDatagramProto(MyDatagramProto):
1237            def __init__(inner_self):
1238                super().__init__(loop=self.loop)
1239
1240            def datagram_received(self, data, addr):
1241                super().datagram_received(data, addr)
1242                self.transport.sendto(b'resp:'+data, addr)
1243
1244        coro = self.loop.create_datagram_endpoint(
1245            TestMyDatagramProto, local_addr=local_addr, family=family)
1246        s_transport, server = self.loop.run_until_complete(coro)
1247        sockname = s_transport.get_extra_info('sockname')
1248        host, port = socket.getnameinfo(
1249            sockname, socket.NI_NUMERICHOST|socket.NI_NUMERICSERV)
1250
1251        self.assertIsInstance(s_transport, asyncio.Transport)
1252        self.assertIsInstance(server, TestMyDatagramProto)
1253        self.assertEqual('INITIALIZED', server.state)
1254        self.assertIs(server.transport, s_transport)
1255
1256        coro = self.loop.create_datagram_endpoint(
1257            lambda: MyDatagramProto(loop=self.loop),
1258            remote_addr=(host, port))
1259        transport, client = self.loop.run_until_complete(coro)
1260
1261        self.assertIsInstance(transport, asyncio.Transport)
1262        self.assertIsInstance(client, MyDatagramProto)
1263        self.assertEqual('INITIALIZED', client.state)
1264        self.assertIs(client.transport, transport)
1265
1266        transport.sendto(b'xxx')
1267        test_utils.run_until(self.loop, lambda: server.nbytes)
1268        self.assertEqual(3, server.nbytes)
1269        test_utils.run_until(self.loop, lambda: client.nbytes)
1270
1271        # received
1272        self.assertEqual(8, client.nbytes)
1273
1274        # extra info is available
1275        self.assertIsNotNone(transport.get_extra_info('sockname'))
1276
1277        # close connection
1278        transport.close()
1279        self.loop.run_until_complete(client.done)
1280        self.assertEqual('CLOSED', client.state)
1281        server.transport.close()
1282
1283    def test_create_datagram_endpoint(self):
1284        self._test_create_datagram_endpoint(('127.0.0.1', 0), socket.AF_INET)
1285
1286    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled')
1287    def test_create_datagram_endpoint_ipv6(self):
1288        self._test_create_datagram_endpoint(('::1', 0), socket.AF_INET6)
1289
1290    def test_create_datagram_endpoint_sock(self):
1291        sock = None
1292        local_address = ('127.0.0.1', 0)
1293        infos = self.loop.run_until_complete(
1294            self.loop.getaddrinfo(
1295                *local_address, type=socket.SOCK_DGRAM))
1296        for family, type, proto, cname, address in infos:
1297            try:
1298                sock = socket.socket(family=family, type=type, proto=proto)
1299                sock.setblocking(False)
1300                sock.bind(address)
1301            except:
1302                pass
1303            else:
1304                break
1305        else:
1306            self.fail('Can not create socket.')
1307
1308        f = self.loop.create_datagram_endpoint(
1309            lambda: MyDatagramProto(loop=self.loop), sock=sock)
1310        tr, pr = self.loop.run_until_complete(f)
1311        self.assertIsInstance(tr, asyncio.Transport)
1312        self.assertIsInstance(pr, MyDatagramProto)
1313        tr.close()
1314        self.loop.run_until_complete(pr.done)
1315
1316    def test_internal_fds(self):
1317        loop = self.create_event_loop()
1318        if not isinstance(loop, selector_events.BaseSelectorEventLoop):
1319            loop.close()
1320            self.skipTest('loop is not a BaseSelectorEventLoop')
1321
1322        self.assertEqual(1, loop._internal_fds)
1323        loop.close()
1324        self.assertEqual(0, loop._internal_fds)
1325        self.assertIsNone(loop._csock)
1326        self.assertIsNone(loop._ssock)
1327
1328    @unittest.skipUnless(sys.platform != 'win32',
1329                         "Don't support pipes for Windows")
1330    def test_read_pipe(self):
1331        proto = MyReadPipeProto(loop=self.loop)
1332
1333        rpipe, wpipe = os.pipe()
1334        pipeobj = io.open(rpipe, 'rb', 1024)
1335
1336        async def connect():
1337            t, p = await self.loop.connect_read_pipe(
1338                lambda: proto, pipeobj)
1339            self.assertIs(p, proto)
1340            self.assertIs(t, proto.transport)
1341            self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
1342            self.assertEqual(0, proto.nbytes)
1343
1344        self.loop.run_until_complete(connect())
1345
1346        os.write(wpipe, b'1')
1347        test_utils.run_until(self.loop, lambda: proto.nbytes >= 1)
1348        self.assertEqual(1, proto.nbytes)
1349
1350        os.write(wpipe, b'2345')
1351        test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
1352        self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
1353        self.assertEqual(5, proto.nbytes)
1354
1355        os.close(wpipe)
1356        self.loop.run_until_complete(proto.done)
1357        self.assertEqual(
1358            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
1359        # extra info is available
1360        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
1361
1362    @unittest.skipUnless(sys.platform != 'win32',
1363                         "Don't support pipes for Windows")
1364    def test_unclosed_pipe_transport(self):
1365        # This test reproduces the issue #314 on GitHub
1366        loop = self.create_event_loop()
1367        read_proto = MyReadPipeProto(loop=loop)
1368        write_proto = MyWritePipeProto(loop=loop)
1369
1370        rpipe, wpipe = os.pipe()
1371        rpipeobj = io.open(rpipe, 'rb', 1024)
1372        wpipeobj = io.open(wpipe, 'w', 1024, encoding="utf-8")
1373
1374        async def connect():
1375            read_transport, _ = await loop.connect_read_pipe(
1376                lambda: read_proto, rpipeobj)
1377            write_transport, _ = await loop.connect_write_pipe(
1378                lambda: write_proto, wpipeobj)
1379            return read_transport, write_transport
1380
1381        # Run and close the loop without closing the transports
1382        read_transport, write_transport = loop.run_until_complete(connect())
1383        loop.close()
1384
1385        # These 'repr' calls used to raise an AttributeError
1386        # See Issue #314 on GitHub
1387        self.assertIn('open', repr(read_transport))
1388        self.assertIn('open', repr(write_transport))
1389
1390        # Clean up (avoid ResourceWarning)
1391        rpipeobj.close()
1392        wpipeobj.close()
1393        read_transport._pipe = None
1394        write_transport._pipe = None
1395
1396    @unittest.skipUnless(sys.platform != 'win32',
1397                         "Don't support pipes for Windows")
1398    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
1399    def test_read_pty_output(self):
1400        proto = MyReadPipeProto(loop=self.loop)
1401
1402        master, slave = os.openpty()
1403        master_read_obj = io.open(master, 'rb', 0)
1404
1405        async def connect():
1406            t, p = await self.loop.connect_read_pipe(lambda: proto,
1407                                                     master_read_obj)
1408            self.assertIs(p, proto)
1409            self.assertIs(t, proto.transport)
1410            self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
1411            self.assertEqual(0, proto.nbytes)
1412
1413        self.loop.run_until_complete(connect())
1414
1415        os.write(slave, b'1')
1416        test_utils.run_until(self.loop, lambda: proto.nbytes)
1417        self.assertEqual(1, proto.nbytes)
1418
1419        os.write(slave, b'2345')
1420        test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
1421        self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
1422        self.assertEqual(5, proto.nbytes)
1423
1424        os.close(slave)
1425        proto.transport.close()
1426        self.loop.run_until_complete(proto.done)
1427        self.assertEqual(
1428            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
1429        # extra info is available
1430        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
1431
1432    @unittest.skipUnless(sys.platform != 'win32',
1433                         "Don't support pipes for Windows")
1434    def test_write_pipe(self):
1435        rpipe, wpipe = os.pipe()
1436        pipeobj = io.open(wpipe, 'wb', 1024)
1437
1438        proto = MyWritePipeProto(loop=self.loop)
1439        connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
1440        transport, p = self.loop.run_until_complete(connect)
1441        self.assertIs(p, proto)
1442        self.assertIs(transport, proto.transport)
1443        self.assertEqual('CONNECTED', proto.state)
1444
1445        transport.write(b'1')
1446
1447        data = bytearray()
1448        def reader(data):
1449            chunk = os.read(rpipe, 1024)
1450            data += chunk
1451            return len(data)
1452
1453        test_utils.run_until(self.loop, lambda: reader(data) >= 1)
1454        self.assertEqual(b'1', data)
1455
1456        transport.write(b'2345')
1457        test_utils.run_until(self.loop, lambda: reader(data) >= 5)
1458        self.assertEqual(b'12345', data)
1459        self.assertEqual('CONNECTED', proto.state)
1460
1461        os.close(rpipe)
1462
1463        # extra info is available
1464        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
1465
1466        # close connection
1467        proto.transport.close()
1468        self.loop.run_until_complete(proto.done)
1469        self.assertEqual('CLOSED', proto.state)
1470
1471    @unittest.skipUnless(sys.platform != 'win32',
1472                         "Don't support pipes for Windows")
1473    def test_write_pipe_disconnect_on_close(self):
1474        rsock, wsock = socket.socketpair()
1475        rsock.setblocking(False)
1476        pipeobj = io.open(wsock.detach(), 'wb', 1024)
1477
1478        proto = MyWritePipeProto(loop=self.loop)
1479        connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
1480        transport, p = self.loop.run_until_complete(connect)
1481        self.assertIs(p, proto)
1482        self.assertIs(transport, proto.transport)
1483        self.assertEqual('CONNECTED', proto.state)
1484
1485        transport.write(b'1')
1486        data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024))
1487        self.assertEqual(b'1', data)
1488
1489        rsock.close()
1490
1491        self.loop.run_until_complete(proto.done)
1492        self.assertEqual('CLOSED', proto.state)
1493
1494    @unittest.skipUnless(sys.platform != 'win32',
1495                         "Don't support pipes for Windows")
1496    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
1497    # select, poll and kqueue don't support character devices (PTY) on Mac OS X
1498    # older than 10.6 (Snow Leopard)
1499    @support.requires_mac_ver(10, 6)
1500    def test_write_pty(self):
1501        master, slave = os.openpty()
1502        slave_write_obj = io.open(slave, 'wb', 0)
1503
1504        proto = MyWritePipeProto(loop=self.loop)
1505        connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj)
1506        transport, p = self.loop.run_until_complete(connect)
1507        self.assertIs(p, proto)
1508        self.assertIs(transport, proto.transport)
1509        self.assertEqual('CONNECTED', proto.state)
1510
1511        transport.write(b'1')
1512
1513        data = bytearray()
1514        def reader(data):
1515            chunk = os.read(master, 1024)
1516            data += chunk
1517            return len(data)
1518
1519        test_utils.run_until(self.loop, lambda: reader(data) >= 1,
1520                             timeout=support.SHORT_TIMEOUT)
1521        self.assertEqual(b'1', data)
1522
1523        transport.write(b'2345')
1524        test_utils.run_until(self.loop, lambda: reader(data) >= 5,
1525                             timeout=support.SHORT_TIMEOUT)
1526        self.assertEqual(b'12345', data)
1527        self.assertEqual('CONNECTED', proto.state)
1528
1529        os.close(master)
1530
1531        # extra info is available
1532        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
1533
1534        # close connection
1535        proto.transport.close()
1536        self.loop.run_until_complete(proto.done)
1537        self.assertEqual('CLOSED', proto.state)
1538
1539    @unittest.skipUnless(sys.platform != 'win32',
1540                         "Don't support pipes for Windows")
1541    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
1542    # select, poll and kqueue don't support character devices (PTY) on Mac OS X
1543    # older than 10.6 (Snow Leopard)
1544    @support.requires_mac_ver(10, 6)
1545    def test_bidirectional_pty(self):
1546        master, read_slave = os.openpty()
1547        write_slave = os.dup(read_slave)
1548        tty.setraw(read_slave)
1549
1550        slave_read_obj = io.open(read_slave, 'rb', 0)
1551        read_proto = MyReadPipeProto(loop=self.loop)
1552        read_connect = self.loop.connect_read_pipe(lambda: read_proto,
1553                                                   slave_read_obj)
1554        read_transport, p = self.loop.run_until_complete(read_connect)
1555        self.assertIs(p, read_proto)
1556        self.assertIs(read_transport, read_proto.transport)
1557        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
1558        self.assertEqual(0, read_proto.nbytes)
1559
1560
1561        slave_write_obj = io.open(write_slave, 'wb', 0)
1562        write_proto = MyWritePipeProto(loop=self.loop)
1563        write_connect = self.loop.connect_write_pipe(lambda: write_proto,
1564                                                     slave_write_obj)
1565        write_transport, p = self.loop.run_until_complete(write_connect)
1566        self.assertIs(p, write_proto)
1567        self.assertIs(write_transport, write_proto.transport)
1568        self.assertEqual('CONNECTED', write_proto.state)
1569
1570        data = bytearray()
1571        def reader(data):
1572            chunk = os.read(master, 1024)
1573            data += chunk
1574            return len(data)
1575
1576        write_transport.write(b'1')
1577        test_utils.run_until(self.loop, lambda: reader(data) >= 1,
1578                             timeout=support.SHORT_TIMEOUT)
1579        self.assertEqual(b'1', data)
1580        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
1581        self.assertEqual('CONNECTED', write_proto.state)
1582
1583        os.write(master, b'a')
1584        test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 1,
1585                             timeout=support.SHORT_TIMEOUT)
1586        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
1587        self.assertEqual(1, read_proto.nbytes)
1588        self.assertEqual('CONNECTED', write_proto.state)
1589
1590        write_transport.write(b'2345')
1591        test_utils.run_until(self.loop, lambda: reader(data) >= 5,
1592                             timeout=support.SHORT_TIMEOUT)
1593        self.assertEqual(b'12345', data)
1594        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
1595        self.assertEqual('CONNECTED', write_proto.state)
1596
1597        os.write(master, b'bcde')
1598        test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 5,
1599                             timeout=support.SHORT_TIMEOUT)
1600        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
1601        self.assertEqual(5, read_proto.nbytes)
1602        self.assertEqual('CONNECTED', write_proto.state)
1603
1604        os.close(master)
1605
1606        read_transport.close()
1607        self.loop.run_until_complete(read_proto.done)
1608        self.assertEqual(
1609            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], read_proto.state)
1610
1611        write_transport.close()
1612        self.loop.run_until_complete(write_proto.done)
1613        self.assertEqual('CLOSED', write_proto.state)
1614
1615    def test_prompt_cancellation(self):
1616        r, w = socket.socketpair()
1617        r.setblocking(False)
1618        f = self.loop.create_task(self.loop.sock_recv(r, 1))
1619        ov = getattr(f, 'ov', None)
1620        if ov is not None:
1621            self.assertTrue(ov.pending)
1622
1623        async def main():
1624            try:
1625                self.loop.call_soon(f.cancel)
1626                await f
1627            except asyncio.CancelledError:
1628                res = 'cancelled'
1629            else:
1630                res = None
1631            finally:
1632                self.loop.stop()
1633            return res
1634
1635        start = time.monotonic()
1636        t = self.loop.create_task(main())
1637        self.loop.run_forever()
1638        elapsed = time.monotonic() - start
1639
1640        self.assertLess(elapsed, 0.1)
1641        self.assertEqual(t.result(), 'cancelled')
1642        self.assertRaises(asyncio.CancelledError, f.result)
1643        if ov is not None:
1644            self.assertFalse(ov.pending)
1645        self.loop._stop_serving(r)
1646
1647        r.close()
1648        w.close()
1649
1650    def test_timeout_rounding(self):
1651        def _run_once():
1652            self.loop._run_once_counter += 1
1653            orig_run_once()
1654
1655        orig_run_once = self.loop._run_once
1656        self.loop._run_once_counter = 0
1657        self.loop._run_once = _run_once
1658
1659        async def wait():
1660            loop = self.loop
1661            await asyncio.sleep(1e-2)
1662            await asyncio.sleep(1e-4)
1663            await asyncio.sleep(1e-6)
1664            await asyncio.sleep(1e-8)
1665            await asyncio.sleep(1e-10)
1666
1667        self.loop.run_until_complete(wait())
1668        # The ideal number of call is 12, but on some platforms, the selector
1669        # may sleep at little bit less than timeout depending on the resolution
1670        # of the clock used by the kernel. Tolerate a few useless calls on
1671        # these platforms.
1672        self.assertLessEqual(self.loop._run_once_counter, 20,
1673            {'clock_resolution': self.loop._clock_resolution,
1674             'selector': self.loop._selector.__class__.__name__})
1675
1676    def test_remove_fds_after_closing(self):
1677        loop = self.create_event_loop()
1678        callback = lambda: None
1679        r, w = socket.socketpair()
1680        self.addCleanup(r.close)
1681        self.addCleanup(w.close)
1682        loop.add_reader(r, callback)
1683        loop.add_writer(w, callback)
1684        loop.close()
1685        self.assertFalse(loop.remove_reader(r))
1686        self.assertFalse(loop.remove_writer(w))
1687
1688    def test_add_fds_after_closing(self):
1689        loop = self.create_event_loop()
1690        callback = lambda: None
1691        r, w = socket.socketpair()
1692        self.addCleanup(r.close)
1693        self.addCleanup(w.close)
1694        loop.close()
1695        with self.assertRaises(RuntimeError):
1696            loop.add_reader(r, callback)
1697        with self.assertRaises(RuntimeError):
1698            loop.add_writer(w, callback)
1699
1700    def test_close_running_event_loop(self):
1701        async def close_loop(loop):
1702            self.loop.close()
1703
1704        coro = close_loop(self.loop)
1705        with self.assertRaises(RuntimeError):
1706            self.loop.run_until_complete(coro)
1707
1708    def test_close(self):
1709        self.loop.close()
1710
1711        async def test():
1712            pass
1713
1714        func = lambda: False
1715        coro = test()
1716        self.addCleanup(coro.close)
1717
1718        # operation blocked when the loop is closed
1719        with self.assertRaises(RuntimeError):
1720            self.loop.run_forever()
1721        with self.assertRaises(RuntimeError):
1722            fut = self.loop.create_future()
1723            self.loop.run_until_complete(fut)
1724        with self.assertRaises(RuntimeError):
1725            self.loop.call_soon(func)
1726        with self.assertRaises(RuntimeError):
1727            self.loop.call_soon_threadsafe(func)
1728        with self.assertRaises(RuntimeError):
1729            self.loop.call_later(1.0, func)
1730        with self.assertRaises(RuntimeError):
1731            self.loop.call_at(self.loop.time() + .0, func)
1732        with self.assertRaises(RuntimeError):
1733            self.loop.create_task(coro)
1734        with self.assertRaises(RuntimeError):
1735            self.loop.add_signal_handler(signal.SIGTERM, func)
1736
1737        # run_in_executor test is tricky: the method is a coroutine,
1738        # but run_until_complete cannot be called on closed loop.
1739        # Thus iterate once explicitly.
1740        with self.assertRaises(RuntimeError):
1741            it = self.loop.run_in_executor(None, func).__await__()
1742            next(it)
1743
1744
class SubprocessTestsMixin:
    # Subprocess transport tests shared by the concrete event loop test
    # classes.  The class this is mixed into has to provide self.loop and
    # the unittest assertion methods.

    def _protocol_factory(self):
        """Return a factory creating MySubprocessProtocol for self.loop."""
        return functools.partial(MySubprocessProtocol, self.loop)

    def _exec_script(self, script, **kwargs):
        """Run *script*, a file located next to this module.

        The script is started with sys.executable through
        loop.subprocess_exec(); extra keyword arguments are passed along.
        Return the (transport, protocol) pair after waiting for
        protocol.connected.
        """
        prog = os.path.join(os.path.dirname(__file__), script)
        connect = self.loop.subprocess_exec(
            self._protocol_factory(), sys.executable, prog, **kwargs)
        transport, protocol = self.loop.run_until_complete(connect)
        self.assertIsInstance(protocol, MySubprocessProtocol)
        self.loop.run_until_complete(protocol.connected)
        return transport, protocol

    def _run_shell(self, cmd, **kwargs):
        """Start *cmd* through loop.subprocess_shell().

        Return the (transport, protocol) pair; protocol.connected is not
        waited for.
        """
        connect = self.loop.subprocess_shell(
            self._protocol_factory(), cmd, **kwargs)
        transport, protocol = self.loop.run_until_complete(connect)
        self.assertIsInstance(protocol, MySubprocessProtocol)
        return transport, protocol

    def check_terminated(self, returncode):
        """Check the return code of a process stopped with terminate()."""
        if sys.platform != 'win32':
            self.assertEqual(-signal.SIGTERM, returncode)
            return
        # expect 1 but sometimes get 0
        self.assertIsInstance(returncode, int)

    def check_killed(self, returncode):
        """Check the return code of a process that has been killed."""
        if sys.platform != 'win32':
            self.assertEqual(-signal.SIGKILL, returncode)
            return
        # expect 1 but sometimes get 0
        self.assertIsInstance(returncode, int)

    def test_subprocess_exec(self):
        # Data written to stdin comes back on stdout; closing the transport
        # kills the child.
        transport, protocol = self._exec_script('echo.py')
        self.assertEqual('CONNECTED', protocol.state)

        payload = b'Python The Winner'
        stdin = transport.get_pipe_transport(0)
        stdin.write(payload)
        self.loop.run_until_complete(protocol.got_data[1].wait())
        with test_utils.disable_logger():
            transport.close()
        self.loop.run_until_complete(protocol.completed)
        self.check_killed(protocol.returncode)
        self.assertEqual(payload, protocol.data[1])

    def test_subprocess_interactive(self):
        # Same as test_subprocess_exec, but the data is sent in two writes
        # and the output is checked after each of them.
        transport, protocol = self._exec_script('echo.py')
        self.assertEqual('CONNECTED', protocol.state)

        stdin = transport.get_pipe_transport(0)
        first_chunk = b'Python '
        stdin.write(first_chunk)
        self.loop.run_until_complete(protocol.got_data[1].wait())
        # Reset the event so that the next wait() blocks until new data.
        protocol.got_data[1].clear()
        self.assertEqual(first_chunk, protocol.data[1])

        stdin.write(b'The Winner')
        self.loop.run_until_complete(protocol.got_data[1].wait())
        self.assertEqual(b'Python The Winner', protocol.data[1])

        with test_utils.disable_logger():
            transport.close()
        self.loop.run_until_complete(protocol.completed)
        self.check_killed(protocol.returncode)

    def test_subprocess_shell(self):
        transport, protocol = self._run_shell('echo Python')
        self.loop.run_until_complete(protocol.connected)

        transport.get_pipe_transport(0).close()
        self.loop.run_until_complete(protocol.completed)
        self.assertEqual(0, protocol.returncode)
        disconnect_futures = protocol.disconnects.values()
        self.assertTrue(all(fut.done() for fut in disconnect_futures))
        # The line ending written by the shell differs between platforms.
        self.assertEqual(protocol.data[1].rstrip(b'\r\n'), b'Python')
        self.assertEqual(protocol.data[2], b'')
        transport.close()

    def test_subprocess_exitcode(self):
        transport, protocol = self._run_shell(
            'exit 7', stdin=None, stdout=None, stderr=None)
        self.loop.run_until_complete(protocol.completed)
        self.assertEqual(7, protocol.returncode)
        transport.close()

    def test_subprocess_close_after_finish(self):
        transport, protocol = self._run_shell(
            'exit 7', stdin=None, stdout=None, stderr=None)
        # No pipe was requested, so no pipe transport is available.
        for fd in (0, 1, 2):
            self.assertIsNone(transport.get_pipe_transport(fd))
        self.loop.run_until_complete(protocol.completed)
        self.assertEqual(7, protocol.returncode)
        self.assertIsNone(transport.close())

    def test_subprocess_kill(self):
        transport, protocol = self._exec_script('echo.py')

        transport.kill()
        self.loop.run_until_complete(protocol.completed)
        self.check_killed(protocol.returncode)
        transport.close()

    def test_subprocess_terminate(self):
        transport, protocol = self._exec_script('echo.py')

        transport.terminate()
        self.loop.run_until_complete(protocol.completed)
        self.check_terminated(protocol.returncode)
        transport.close()

    @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP")
    def test_subprocess_send_signal(self):
        # bpo-31034: Make sure that we get the default signal handler (killing
        # the process). The parent process may have decided to ignore SIGHUP,
        # and signal handlers are inherited.
        old_handler = signal.signal(signal.SIGHUP, signal.SIG_DFL)
        try:
            transport, protocol = self._exec_script('echo.py')

            transport.send_signal(signal.SIGHUP)
            self.loop.run_until_complete(protocol.completed)
            self.assertEqual(-signal.SIGHUP, protocol.returncode)
            transport.close()
        finally:
            signal.signal(signal.SIGHUP, old_handler)

    def test_subprocess_stderr(self):
        transport, protocol = self._exec_script('echo2.py')

        stdin = transport.get_pipe_transport(0)
        stdin.write(b'test')

        self.loop.run_until_complete(protocol.completed)

        transport.close()
        self.assertEqual(b'OUT:test', protocol.data[1])
        stderr_data = protocol.data[2]
        self.assertTrue(stderr_data.startswith(b'ERR:test'), stderr_data)
        self.assertEqual(0, protocol.returncode)

    def test_subprocess_stderr_redirect_to_stdout(self):
        transport, protocol = self._exec_script(
            'echo2.py', stderr=subprocess.STDOUT)

        stdin = transport.get_pipe_transport(0)
        # stderr is merged into stdout: only the stdout pipe exists.
        self.assertIsNotNone(transport.get_pipe_transport(1))
        self.assertIsNone(transport.get_pipe_transport(2))

        stdin.write(b'test')
        self.loop.run_until_complete(protocol.completed)
        stdout_data = protocol.data[1]
        self.assertTrue(stdout_data.startswith(b'OUT:testERR:test'),
                        stdout_data)
        self.assertEqual(b'', protocol.data[2])

        transport.close()
        self.assertEqual(0, protocol.returncode)

    def test_subprocess_close_client_stream(self):
        transport, protocol = self._exec_script('echo3.py')

        stdin = transport.get_pipe_transport(0)
        stdout = transport.get_pipe_transport(1)
        stdin.write(b'test')
        self.loop.run_until_complete(protocol.got_data[1].wait())
        self.assertEqual(b'OUT:test', protocol.data[1])

        stdout.close()
        self.loop.run_until_complete(protocol.disconnects[1])
        stdin.write(b'xxx')
        self.loop.run_until_complete(protocol.got_data[2].wait())
        if sys.platform == 'win32':
            # After closing the read-end of a pipe, writing to the
            # write-end using os.write() fails with errno==EINVAL and
            # GetLastError()==ERROR_INVALID_NAME on Windows!?!  (Using
            # WriteFile() we get ERROR_BROKEN_PIPE as expected.)
            expected_error = b'ERR:OSError'
        else:
            expected_error = b'ERR:BrokenPipeError'
        self.assertEqual(expected_error, protocol.data[2])
        with test_utils.disable_logger():
            transport.close()
        self.loop.run_until_complete(protocol.completed)
        self.check_killed(protocol.returncode)

    def test_subprocess_wait_no_same_group(self):
        # start the new process in a new session
        transport, protocol = self._run_shell(
            'exit 7', stdin=None, stdout=None, stderr=None,
            start_new_session=True)
        self.loop.run_until_complete(protocol.completed)
        self.assertEqual(7, protocol.returncode)
        transport.close()

    def test_subprocess_exec_invalid_args(self):
        async def connect(**kwds):
            await self.loop.subprocess_exec(
                asyncio.SubprocessProtocol,
                'pwd', **kwds)

        # Each of these keyword arguments must be rejected with ValueError.
        rejected_kwargs = (
            {'universal_newlines': True},
            {'bufsize': 4096},
            {'shell': True},
        )
        for kwds in rejected_kwargs:
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(connect(**kwds))

    def test_subprocess_shell_invalid_args(self):

        async def connect(cmd=None, **kwds):
            await self.loop.subprocess_shell(
                asyncio.SubprocessProtocol,
                cmd or 'pwd', **kwds)

        # (positional arguments, keyword arguments) that must be rejected
        # with ValueError.
        rejected_calls = (
            ((['ls', '-l'],), {}),
            ((), {'universal_newlines': True}),
            ((), {'bufsize': 4096}),
            ((), {'shell': False}),
        )
        for args, kwds in rejected_calls:
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(connect(*args, **kwds))
2029
2030
2031if sys.platform == 'win32':
2032
2033    class SelectEventLoopTests(EventLoopTestsMixin,
2034                               test_utils.TestCase):
2035
2036        def create_event_loop(self):
2037            return asyncio.SelectorEventLoop()
2038
2039    class ProactorEventLoopTests(EventLoopTestsMixin,
2040                                 SubprocessTestsMixin,
2041                                 test_utils.TestCase):
2042
2043        def create_event_loop(self):
2044            return asyncio.ProactorEventLoop()
2045
2046        def test_reader_callback(self):
2047            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
2048
2049        def test_reader_callback_cancel(self):
2050            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
2051
2052        def test_writer_callback(self):
2053            raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
2054
2055        def test_writer_callback_cancel(self):
2056            raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
2057
2058        def test_remove_fds_after_closing(self):
2059            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
2060else:
2061    import selectors
2062
2063    class UnixEventLoopTestsMixin(EventLoopTestsMixin):
2064        def setUp(self):
2065            super().setUp()
2066            watcher = asyncio.SafeChildWatcher()
2067            watcher.attach_loop(self.loop)
2068            asyncio.set_child_watcher(watcher)
2069
2070        def tearDown(self):
2071            asyncio.set_child_watcher(None)
2072            super().tearDown()
2073
2074
2075    if hasattr(selectors, 'KqueueSelector'):
2076        class KqueueEventLoopTests(UnixEventLoopTestsMixin,
2077                                   SubprocessTestsMixin,
2078                                   test_utils.TestCase):
2079
2080            def create_event_loop(self):
2081                return asyncio.SelectorEventLoop(
2082                    selectors.KqueueSelector())
2083
2084            # kqueue doesn't support character devices (PTY) on Mac OS X older
2085            # than 10.9 (Maverick)
2086            @support.requires_mac_ver(10, 9)
2087            # Issue #20667: KqueueEventLoopTests.test_read_pty_output()
2088            # hangs on OpenBSD 5.5
2089            @unittest.skipIf(sys.platform.startswith('openbsd'),
2090                             'test hangs on OpenBSD')
2091            def test_read_pty_output(self):
2092                super().test_read_pty_output()
2093
2094            # kqueue doesn't support character devices (PTY) on Mac OS X older
2095            # than 10.9 (Maverick)
2096            @support.requires_mac_ver(10, 9)
2097            def test_write_pty(self):
2098                super().test_write_pty()
2099
2100    if hasattr(selectors, 'EpollSelector'):
2101        class EPollEventLoopTests(UnixEventLoopTestsMixin,
2102                                  SubprocessTestsMixin,
2103                                  test_utils.TestCase):
2104
2105            def create_event_loop(self):
2106                return asyncio.SelectorEventLoop(selectors.EpollSelector())
2107
2108    if hasattr(selectors, 'PollSelector'):
2109        class PollEventLoopTests(UnixEventLoopTestsMixin,
2110                                 SubprocessTestsMixin,
2111                                 test_utils.TestCase):
2112
2113            def create_event_loop(self):
2114                return asyncio.SelectorEventLoop(selectors.PollSelector())
2115
2116    # Should always exist.
2117    class SelectEventLoopTests(UnixEventLoopTestsMixin,
2118                               SubprocessTestsMixin,
2119                               test_utils.TestCase):
2120
2121        def create_event_loop(self):
2122            return asyncio.SelectorEventLoop(selectors.SelectSelector())
2123
2124
def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return None."""
    return None
2127
2128
class HandleTests(test_utils.TestCase):
    # Tests for asyncio.Handle.  Most of them use a mock in place of the
    # event loop.

    def setUp(self):
        super().setUp()
        self.loop = mock.Mock()
        self.loop.get_debug.return_value = True

    def test_handle(self):
        def callback(*args):
            return args

        args = ()
        h = asyncio.Handle(callback, args, self.loop)
        # The callback and its arguments are stored as is (same objects).
        self.assertIs(h._callback, callback)
        self.assertIs(h._args, args)
        self.assertFalse(h.cancelled())

        h.cancel()
        self.assertTrue(h.cancelled())

    def test_callback_with_exception(self):
        # An exception raised by the callback is reported to the loop's
        # call_exception_handler() instead of propagating out of _run().
        def callback():
            raise ValueError()

        self.loop = mock.Mock()
        self.loop.call_exception_handler = mock.Mock()

        h = asyncio.Handle(callback, (), self.loop)
        h._run()

        self.loop.call_exception_handler.assert_called_with({
            'message': test_utils.MockPattern('Exception in callback.*'),
            'exception': mock.ANY,
            'handle': h,
            'source_traceback': h._source_traceback,
        })

    def test_handle_weakref(self):
        wd = weakref.WeakValueDictionary()
        h = asyncio.Handle(lambda: None, (), self.loop)
        wd['h'] = h  # Would fail without __weakref__ slot.

    def test_handle_repr(self):
        self.loop.get_debug.return_value = False

        # simple function
        h = asyncio.Handle(noop, (1, 2), self.loop)
        filename, lineno = test_utils.get_function_source(noop)
        self.assertEqual(repr(h),
                        '<Handle noop(1, 2) at %s:%s>'
                        % (filename, lineno))

        # cancelled handle
        h.cancel()
        self.assertEqual(repr(h),
                        '<Handle cancelled>')

        # decorated function
        with self.assertWarns(DeprecationWarning):
            cb = asyncio.coroutine(noop)
        h = asyncio.Handle(cb, (), self.loop)
        self.assertEqual(repr(h),
                        '<Handle noop() at %s:%s>'
                        % (filename, lineno))

        # partial function
        cb = functools.partial(noop, 1, 2)
        h = asyncio.Handle(cb, (3,), self.loop)
        regex = (r'^<Handle noop\(1, 2\)\(3\) at %s:%s>$'
                 % (re.escape(filename), lineno))
        self.assertRegex(repr(h), regex)

        # partial function with keyword args
        cb = functools.partial(noop, x=1)
        h = asyncio.Handle(cb, (2, 3), self.loop)
        regex = (r'^<Handle noop\(x=1\)\(2, 3\) at %s:%s>$'
                 % (re.escape(filename), lineno))
        self.assertRegex(repr(h), regex)

        # partial method
        # (The former "sys.version_info >= (3, 4)" guard was always true
        # and has been dropped.)
        method = HandleTests.test_handle_repr
        cb = functools.partialmethod(method)
        filename, lineno = test_utils.get_function_source(method)
        h = asyncio.Handle(cb, (), self.loop)

        cb_regex = r'<function HandleTests.test_handle_repr .*>'
        cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex)
        regex = (r'^<Handle %s at %s:%s>$'
                 % (cb_regex, re.escape(filename), lineno))
        self.assertRegex(repr(h), regex)

    def test_handle_repr_debug(self):
        self.loop.get_debug.return_value = True

        # simple function
        create_filename = __file__
        # The handle must be created on the very next line: its number is
        # what "created at" is expected to report.
        create_lineno = sys._getframe().f_lineno + 1
        h = asyncio.Handle(noop, (1, 2), self.loop)
        filename, lineno = test_utils.get_function_source(noop)
        self.assertEqual(repr(h),
                        '<Handle noop(1, 2) at %s:%s created at %s:%s>'
                        % (filename, lineno, create_filename, create_lineno))

        # cancelled handle
        h.cancel()
        self.assertEqual(
            repr(h),
            '<Handle cancelled noop(1, 2) at %s:%s created at %s:%s>'
            % (filename, lineno, create_filename, create_lineno))

        # double cancellation won't overwrite _repr
        h.cancel()
        self.assertEqual(
            repr(h),
            '<Handle cancelled noop(1, 2) at %s:%s created at %s:%s>'
            % (filename, lineno, create_filename, create_lineno))

    def test_handle_source_traceback(self):
        loop = asyncio.get_event_loop_policy().new_event_loop()
        loop.set_debug(True)
        self.set_event_loop(loop)

        def check_source_traceback(h):
            # The handle has to be created on the line right before the
            # call to this helper, and on a single line.
            lineno = sys._getframe(1).f_lineno - 1
            self.assertIsInstance(h._source_traceback, list)
            self.assertEqual(h._source_traceback[-1][:3],
                             (__file__,
                              lineno,
                              'test_handle_source_traceback'))

        # call_soon
        h = loop.call_soon(noop)
        check_source_traceback(h)

        # call_soon_threadsafe
        h = loop.call_soon_threadsafe(noop)
        check_source_traceback(h)

        # call_later
        h = loop.call_later(0, noop)
        check_source_traceback(h)

        # call_at (this section used to call call_later() a second time,
        # leaving call_at() untested)
        h = loop.call_at(loop.time(), noop)
        check_source_traceback(h)

    @unittest.skipUnless(hasattr(collections.abc, 'Coroutine'),
                         'No collections.abc.Coroutine')
    def test_coroutine_like_object_debug_formatting(self):
        # Test that asyncio can format coroutines that are instances of
        # collections.abc.Coroutine, but lack cr_code or gi_code attributes
        # (such as ones compiled with Cython).

        coro = CoroLike()
        coro.__name__ = 'AAA'
        self.assertTrue(asyncio.iscoroutine(coro))
        self.assertEqual(coroutines._format_coroutine(coro), 'AAA()')

        # __qualname__ takes precedence over __name__.
        coro.__qualname__ = 'BBB'
        self.assertEqual(coroutines._format_coroutine(coro), 'BBB()')

        coro.cr_running = True
        self.assertEqual(coroutines._format_coroutine(coro), 'BBB() running')

        coro.__name__ = coro.__qualname__ = None
        self.assertEqual(coroutines._format_coroutine(coro),
                         '<CoroLike without __name__>() running')

        coro = CoroLike()
        coro.__qualname__ = 'CoroLike'
        # Some coroutines might not have '__name__', such as
        # built-in async_gen.asend().
        self.assertEqual(coroutines._format_coroutine(coro), 'CoroLike()')

        # A cr_code attribute set to None must not break the formatting.
        coro = CoroLike()
        coro.__qualname__ = 'AAA'
        coro.cr_code = None
        self.assertEqual(coroutines._format_coroutine(coro), 'AAA()')
2308
2309
2310class TimerTests(unittest.TestCase):
2311
2312    def setUp(self):
2313        super().setUp()
2314        self.loop = mock.Mock()
2315
2316    def test_hash(self):
2317        when = time.monotonic()
2318        h = asyncio.TimerHandle(when, lambda: False, (),
2319                                mock.Mock())
2320        self.assertEqual(hash(h), hash(when))
2321
2322    def test_when(self):
2323        when = time.monotonic()
2324        h = asyncio.TimerHandle(when, lambda: False, (),
2325                                mock.Mock())
2326        self.assertEqual(when, h.when())
2327
2328    def test_timer(self):
2329        def callback(*args):
2330            return args
2331
2332        args = (1, 2, 3)
2333        when = time.monotonic()
2334        h = asyncio.TimerHandle(when, callback, args, mock.Mock())
2335        self.assertIs(h._callback, callback)
2336        self.assertIs(h._args, args)
2337        self.assertFalse(h.cancelled())
2338
2339        # cancel
2340        h.cancel()
2341        self.assertTrue(h.cancelled())
2342        self.assertIsNone(h._callback)
2343        self.assertIsNone(h._args)
2344
2345        # when cannot be None
2346        self.assertRaises(AssertionError,
2347                          asyncio.TimerHandle, None, callback, args,
2348                          self.loop)
2349
2350    def test_timer_repr(self):
2351        self.loop.get_debug.return_value = False
2352
2353        # simple function
2354        h = asyncio.TimerHandle(123, noop, (), self.loop)
2355        src = test_utils.get_function_source(noop)
2356        self.assertEqual(repr(h),
2357                        '<TimerHandle when=123 noop() at %s:%s>' % src)
2358
2359        # cancelled handle
2360        h.cancel()
2361        self.assertEqual(repr(h),
2362                        '<TimerHandle cancelled when=123>')
2363
2364    def test_timer_repr_debug(self):
2365        self.loop.get_debug.return_value = True
2366
2367        # simple function
2368        create_filename = __file__
2369        create_lineno = sys._getframe().f_lineno + 1
2370        h = asyncio.TimerHandle(123, noop, (), self.loop)
2371        filename, lineno = test_utils.get_function_source(noop)
2372        self.assertEqual(repr(h),
2373                        '<TimerHandle when=123 noop() '
2374                        'at %s:%s created at %s:%s>'
2375                        % (filename, lineno, create_filename, create_lineno))
2376
2377        # cancelled handle
2378        h.cancel()
2379        self.assertEqual(repr(h),
2380                        '<TimerHandle cancelled when=123 noop() '
2381                        'at %s:%s created at %s:%s>'
2382                        % (filename, lineno, create_filename, create_lineno))
2383
2384
2385    def test_timer_comparison(self):
2386        def callback(*args):
2387            return args
2388
2389        when = time.monotonic()
2390
2391        h1 = asyncio.TimerHandle(when, callback, (), self.loop)
2392        h2 = asyncio.TimerHandle(when, callback, (), self.loop)
2393        # TODO: Use assertLess etc.
2394        self.assertFalse(h1 < h2)
2395        self.assertFalse(h2 < h1)
2396        self.assertTrue(h1 <= h2)
2397        self.assertTrue(h2 <= h1)
2398        self.assertFalse(h1 > h2)
2399        self.assertFalse(h2 > h1)
2400        self.assertTrue(h1 >= h2)
2401        self.assertTrue(h2 >= h1)
2402        self.assertTrue(h1 == h2)
2403        self.assertFalse(h1 != h2)
2404
2405        h2.cancel()
2406        self.assertFalse(h1 == h2)
2407
2408        h1 = asyncio.TimerHandle(when, callback, (), self.loop)
2409        h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop)
2410        self.assertTrue(h1 < h2)
2411        self.assertFalse(h2 < h1)
2412        self.assertTrue(h1 <= h2)
2413        self.assertFalse(h2 <= h1)
2414        self.assertFalse(h1 > h2)
2415        self.assertTrue(h2 > h1)
2416        self.assertFalse(h1 >= h2)
2417        self.assertTrue(h2 >= h1)
2418        self.assertFalse(h1 == h2)
2419        self.assertTrue(h1 != h2)
2420
2421        h3 = asyncio.Handle(callback, (), self.loop)
2422        self.assertIs(NotImplemented, h1.__eq__(h3))
2423        self.assertIs(NotImplemented, h1.__ne__(h3))
2424
2425        with self.assertRaises(TypeError):
2426            h1 < ()
2427        with self.assertRaises(TypeError):
2428            h1 > ()
2429        with self.assertRaises(TypeError):
2430            h1 <= ()
2431        with self.assertRaises(TypeError):
2432            h1 >= ()
2433        self.assertFalse(h1 == ())
2434        self.assertTrue(h1 != ())
2435
2436        self.assertTrue(h1 == ALWAYS_EQ)
2437        self.assertFalse(h1 != ALWAYS_EQ)
2438        self.assertTrue(h1 < LARGEST)
2439        self.assertFalse(h1 > LARGEST)
2440        self.assertTrue(h1 <= LARGEST)
2441        self.assertFalse(h1 >= LARGEST)
2442        self.assertFalse(h1 < SMALLEST)
2443        self.assertTrue(h1 > SMALLEST)
2444        self.assertFalse(h1 <= SMALLEST)
2445        self.assertTrue(h1 >= SMALLEST)
2446
2447
class AbstractEventLoopTests(unittest.TestCase):
    # Checks that the methods of a bare asyncio.AbstractEventLoop raise
    # NotImplementedError when called.

    def test_not_implemented(self):
        f = mock.Mock()
        loop = asyncio.AbstractEventLoop()

        # (method name, positional arguments) for every synchronous
        # method under test.  Each method appears exactly once.
        calls = [
            ('run_forever', ()),
            ('run_until_complete', (None,)),
            ('stop', ()),
            ('is_running', ()),
            ('is_closed', ()),
            ('close', ()),
            ('create_task', (None,)),
            ('call_later', (None, None)),
            ('call_at', (f, f)),
            ('call_soon', (None,)),
            ('time', ()),
            ('call_soon_threadsafe', (None,)),
            ('set_default_executor', (f,)),
            ('add_reader', (1, f)),
            ('remove_reader', (1,)),
            ('add_writer', (1, f)),
            ('remove_writer', (1,)),
            ('add_signal_handler', (1, f)),
            ('remove_signal_handler', (1,)),
            ('set_exception_handler', (f,)),
            ('default_exception_handler', (f,)),
            ('call_exception_handler', (f,)),
            ('get_debug', ()),
            ('set_debug', (f,)),
        ]
        for name, args in calls:
            # msg puts the offending method name into the failure report.
            with self.assertRaises(NotImplementedError, msg=name):
                getattr(loop, name)(*args)

    def test_not_implemented_async(self):
        f = mock.Mock()
        abstract_loop = asyncio.AbstractEventLoop()

        # (method name, positional arguments) for every method whose
        # result is awaited.
        calls = [
            ('run_in_executor', (f, f)),
            ('getaddrinfo', ('localhost', 8080)),
            ('getnameinfo', (('localhost', 8080),)),
            ('create_connection', (f,)),
            ('create_server', (f,)),
            ('create_datagram_endpoint', (f,)),
            ('sock_recv', (f, 10)),
            ('sock_recv_into', (f, 10)),
            ('sock_sendall', (f, 10)),
            ('sock_connect', (f, f)),
            ('sock_accept', (f,)),
            ('sock_sendfile', (f, f)),
            ('sendfile', (f, f)),
            ('connect_read_pipe', (f, mock.sentinel.pipe)),
            ('connect_write_pipe', (f, mock.sentinel.pipe)),
            ('subprocess_shell', (f, mock.sentinel)),
            ('subprocess_exec', (f,)),
        ]

        async def inner():
            for name, args in calls:
                with self.assertRaises(NotImplementedError, msg=name):
                    await getattr(abstract_loop, name)(*args)

        # A real loop drives the coroutine; register the close up front so
        # the loop is released even when inner() fails.
        loop = asyncio.new_event_loop()
        self.addCleanup(loop.close)
        loop.run_until_complete(inner())
2548
2549
class PolicyTests(unittest.TestCase):
    # Tests for the event loop policy objects and for the module-level
    # get/set_event_loop_policy() functions.

    def test_event_loop_policy(self):
        policy = asyncio.AbstractEventLoopPolicy()
        self.assertRaises(NotImplementedError, policy.get_event_loop)
        self.assertRaises(NotImplementedError, policy.set_event_loop, object())
        self.assertRaises(NotImplementedError, policy.new_event_loop)
        self.assertRaises(NotImplementedError, policy.get_child_watcher)
        self.assertRaises(NotImplementedError, policy.set_child_watcher,
                          object())

    def test_get_event_loop(self):
        policy = asyncio.DefaultEventLoopPolicy()
        self.assertIsNone(policy._local._loop)

        loop = policy.get_event_loop()
        # Close the loop even if one of the assertions below fails.
        self.addCleanup(loop.close)
        self.assertIsInstance(loop, asyncio.AbstractEventLoop)

        self.assertIs(policy._local._loop, loop)
        self.assertIs(loop, policy.get_event_loop())

    def test_get_event_loop_calls_set_event_loop(self):
        policy = asyncio.DefaultEventLoopPolicy()

        with mock.patch.object(
                policy, "set_event_loop",
                wraps=policy.set_event_loop) as m_set_event_loop:

            loop = policy.get_event_loop()
            # Close the loop even if the mock assertion below fails.
            self.addCleanup(loop.close)

            # policy._local._loop must be set through .set_event_loop()
            # (the unix DefaultEventLoopPolicy needs this call to attach
            # the child watcher correctly)
            m_set_event_loop.assert_called_with(loop)

    def test_get_event_loop_after_set_none(self):
        policy = asyncio.DefaultEventLoopPolicy()
        policy.set_event_loop(None)
        self.assertRaises(RuntimeError, policy.get_event_loop)

    @mock.patch('asyncio.events.threading.current_thread')
    def test_get_event_loop_thread(self, m_current_thread):

        def f():
            policy = asyncio.DefaultEventLoopPolicy()
            self.assertRaises(RuntimeError, policy.get_event_loop)

        th = threading.Thread(target=f)
        th.start()
        th.join()

    def test_new_event_loop(self):
        policy = asyncio.DefaultEventLoopPolicy()

        loop = policy.new_event_loop()
        self.addCleanup(loop.close)
        self.assertIsInstance(loop, asyncio.AbstractEventLoop)

    def test_set_event_loop(self):
        policy = asyncio.DefaultEventLoopPolicy()
        old_loop = policy.get_event_loop()
        self.addCleanup(old_loop.close)

        self.assertRaises(AssertionError, policy.set_event_loop, object())

        loop = policy.new_event_loop()
        # Cleanups run last-in first-out, so this loop is closed before
        # old_loop.
        self.addCleanup(loop.close)
        policy.set_event_loop(loop)
        self.assertIs(loop, policy.get_event_loop())
        self.assertIsNot(old_loop, policy.get_event_loop())

    def test_get_event_loop_policy(self):
        policy = asyncio.get_event_loop_policy()
        self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy)
        self.assertIs(policy, asyncio.get_event_loop_policy())

    def test_set_event_loop_policy(self):
        self.assertRaises(
            AssertionError, asyncio.set_event_loop_policy, object())

        old_policy = asyncio.get_event_loop_policy()
        # Restore the previous global policy afterwards so the policy
        # installed below does not leak into other tests.
        self.addCleanup(asyncio.set_event_loop_policy, old_policy)

        policy = asyncio.DefaultEventLoopPolicy()
        asyncio.set_event_loop_policy(policy)
        self.assertIs(policy, asyncio.get_event_loop_policy())
        self.assertIsNot(policy, old_policy)
2639
2640
class GetEventLoopTestsMixin:
    # Shared tests for get_event_loop() / get_running_loop().
    #
    # Concrete subclasses set the four *_impl attributes below.  setUp()
    # installs them on both asyncio.events and the asyncio package, and
    # tearDown() puts the saved originals back.

    _get_running_loop_impl = None
    _set_running_loop_impl = None
    get_running_loop_impl = None
    get_event_loop_impl = None

    def setUp(self):
        # Remember the functions currently published by asyncio.events.
        self._get_running_loop_saved = events._get_running_loop
        self._set_running_loop_saved = events._set_running_loop
        self.get_running_loop_saved = events.get_running_loop
        self.get_event_loop_saved = events.get_event_loop

        # Swap in the implementation under test.  The attributes are read
        # from the class (not the instance) so they stay plain functions.
        events._get_running_loop = type(self)._get_running_loop_impl
        events._set_running_loop = type(self)._set_running_loop_impl
        events.get_running_loop = type(self).get_running_loop_impl
        events.get_event_loop = type(self).get_event_loop_impl

        asyncio._get_running_loop = type(self)._get_running_loop_impl
        asyncio._set_running_loop = type(self)._set_running_loop_impl
        asyncio.get_running_loop = type(self).get_running_loop_impl
        asyncio.get_event_loop = type(self).get_event_loop_impl

        super().setUp()

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        if sys.platform != 'win32':
            watcher = asyncio.SafeChildWatcher()
            watcher.attach_loop(self.loop)
            asyncio.set_child_watcher(watcher)

    def tearDown(self):
        try:
            if sys.platform != 'win32':
                asyncio.set_child_watcher(None)

            super().tearDown()
        finally:
            self.loop.close()
            asyncio.set_event_loop(None)

            # Undo the swaps made in setUp().
            events._get_running_loop = self._get_running_loop_saved
            events._set_running_loop = self._set_running_loop_saved
            events.get_running_loop = self.get_running_loop_saved
            events.get_event_loop = self.get_event_loop_saved

            asyncio._get_running_loop = self._get_running_loop_saved
            asyncio._set_running_loop = self._set_running_loop_saved
            asyncio.get_running_loop = self.get_running_loop_saved
            asyncio.get_event_loop = self.get_event_loop_saved

    if sys.platform != 'win32':

        def test_get_event_loop_new_process(self):
            # bpo-32126: The multiprocessing module used by
            # ProcessPoolExecutor is not functional when the
            # multiprocessing.synchronize module cannot be imported.
            support.skip_if_broken_multiprocessing_synchronize()

            async def main():
                # The context manager shuts the pool down even when
                # run_in_executor() raises.
                with concurrent.futures.ProcessPoolExecutor() as pool:
                    return await self.loop.run_in_executor(
                        pool, _test_get_event_loop_new_process__sub_proc)

            self.assertEqual(
                self.loop.run_until_complete(main()),
                'hello')

    def test_get_event_loop_returns_running_loop(self):
        class TestError(Exception):
            pass

        class Policy(asyncio.DefaultEventLoopPolicy):
            def get_event_loop(self):
                raise TestError

        old_policy = asyncio.get_event_loop_policy()
        # Bound before the try block so the finally clause can always
        # test it, even if creating the loop fails.
        loop = None
        try:
            asyncio.set_event_loop_policy(Policy())
            loop = asyncio.new_event_loop()

            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)
            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

            with self.assertRaisesRegex(RuntimeError, 'no running'):
                asyncio.get_running_loop()
            self.assertIs(asyncio._get_running_loop(), None)

            async def func():
                # Inside a running loop, all three accessors return it.
                self.assertIs(asyncio.get_event_loop(), loop)
                self.assertIs(asyncio.get_running_loop(), loop)
                self.assertIs(asyncio._get_running_loop(), loop)

            loop.run_until_complete(func())

            asyncio.set_event_loop(loop)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

        finally:
            asyncio.set_event_loop_policy(old_policy)
            if loop is not None:
                loop.close()

        with self.assertRaisesRegex(RuntimeError, 'no running'):
            asyncio.get_running_loop()

        self.assertIs(asyncio._get_running_loop(), None)

    def test_get_event_loop_returns_running_loop2(self):
        old_policy = asyncio.get_event_loop_policy()
        # Bound before the try block so the finally clause can always
        # test it, even if creating the loop fails.
        loop = None
        try:
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
            loop = asyncio.new_event_loop()
            self.addCleanup(loop.close)

            with self.assertWarns(DeprecationWarning) as cm:
                loop2 = asyncio.get_event_loop()
            self.addCleanup(loop2.close)
            self.assertEqual(cm.warnings[0].filename, __file__)
            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaisesRegex(RuntimeError, 'no current'):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

            with self.assertRaisesRegex(RuntimeError, 'no running'):
                asyncio.get_running_loop()
            self.assertIs(asyncio._get_running_loop(), None)

            async def func():
                # Inside a running loop, all three accessors return it.
                self.assertIs(asyncio.get_event_loop(), loop)
                self.assertIs(asyncio.get_running_loop(), loop)
                self.assertIs(asyncio._get_running_loop(), loop)

            loop.run_until_complete(func())

            asyncio.set_event_loop(loop)
            with self.assertWarns(DeprecationWarning) as cm:
                self.assertIs(asyncio.get_event_loop(), loop)
            self.assertEqual(cm.warnings[0].filename, __file__)

            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaisesRegex(RuntimeError, 'no current'):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

        finally:
            asyncio.set_event_loop_policy(old_policy)
            if loop is not None:
                loop.close()

        with self.assertRaisesRegex(RuntimeError, 'no running'):
            asyncio.get_running_loop()

        self.assertIs(asyncio._get_running_loop(), None)
2817
2818
class TestPyGetEventLoop(GetEventLoopTestsMixin, unittest.TestCase):
    # Runs the GetEventLoopTestsMixin tests with the events._py_*
    # functions as the implementations under test.

    get_event_loop_impl = events._py_get_event_loop
    get_running_loop_impl = events._py_get_running_loop
    _set_running_loop_impl = events._py__set_running_loop
    _get_running_loop_impl = events._py__get_running_loop
2825
2826
try:
    import _asyncio  # NoQA
except ImportError:
    # Without the _asyncio module the class below is simply not defined.
    pass
else:

    class TestCGetEventLoop(GetEventLoopTestsMixin, unittest.TestCase):
        # Runs the GetEventLoopTestsMixin tests with the events._c_*
        # functions as the implementations under test.

        get_event_loop_impl = events._c_get_event_loop
        get_running_loop_impl = events._c_get_running_loop
        _set_running_loop_impl = events._c__set_running_loop
        _get_running_loop_impl = events._c__get_running_loop
2839
2840
class TestServer(unittest.TestCase):
    # Checks that a server returned by loop.create_server() reports the
    # loop it was created on.

    def test_get_loop(self):
        loop = asyncio.new_event_loop()
        self.addCleanup(loop.close)
        proto = MyProto(loop)
        server = loop.run_until_complete(
            loop.create_server(lambda: proto, '0.0.0.0', 0))
        try:
            self.assertEqual(server.get_loop(), loop)
        finally:
            # Always shut the server down, even if the assertion above
            # fails, so it is not left open on a failing run.
            server.close()
            loop.run_until_complete(server.wait_closed())
2851
2852
class TestAbstractServer(unittest.TestCase):
    # close(), wait_closed() and get_loop() of a bare
    # events.AbstractServer are expected to raise NotImplementedError.

    def test_close(self):
        server = events.AbstractServer()
        self.assertRaises(NotImplementedError, server.close)

    def test_wait_closed(self):
        loop = asyncio.new_event_loop()
        self.addCleanup(loop.close)
        server = events.AbstractServer()

        # The call and the loop run both stay inside the context manager.
        with self.assertRaises(NotImplementedError):
            loop.run_until_complete(server.wait_closed())

    def test_get_loop(self):
        server = events.AbstractServer()
        self.assertRaises(NotImplementedError, server.get_loop)
2869
2870
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
2873