• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for unix_events.py."""
2
3import contextlib
4import errno
5import io
6import os
7import pathlib
8import signal
9import socket
10import stat
11import sys
12import tempfile
13import threading
14import unittest
15from unittest import mock
16from test.support import os_helper
17from test.support import socket_helper
18
# asyncio.unix_events targets UNIX only; raising SkipTest while the module is
# being imported makes the test runner skip this whole module on Windows.
if sys.platform == 'win32':
    raise unittest.SkipTest('UNIX only')
21
22
23import asyncio
24from asyncio import log
25from asyncio import unix_events
26from test.test_asyncio import utils as test_utils
27
28
# Wildcard for mock call assertions: compares equal to any value.
MOCK_ANY = mock.ANY
30
31
def EXITCODE(exitcode):
    """Return the synthetic status value these tests use for *exitcode*.

    The encoding is ``32768 + exitcode``.  SIGNAL() counts down from the same
    base, so for non-negative exit codes the two value ranges never overlap.
    """
    base = 1 << 15  # 32768
    return base + exitcode
34
35
def SIGNAL(signum):
    """Return the synthetic status value these tests use for signal *signum*.

    The encoding is ``32768 - signum``.  Only signal numbers 1..68 are
    accepted; any other value raises AssertionError, so a bad test input
    fails loudly instead of producing a misleading status.
    """
    if 1 <= signum <= 68:
        return (1 << 15) - signum
    raise AssertionError(f'invalid signum {signum}')
40
41
def tearDownModule():
    """unittest module-level hook: reset asyncio's global event loop policy.

    Passing None restores the default policy, so a policy installed while
    this module's tests ran does not leak into later test modules.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
44
45
def close_pipe_transport(transport):
    """Close and detach *transport*'s pipe without calling transport.close().

    The event loop and selector used by these tests are mocked, so the
    regular close path is bypassed: the underlying pipe is closed directly
    and ``transport._pipe`` is cleared.  Calling this again is a no-op.
    """
    pipe = transport._pipe
    if pipe is not None:
        pipe.close()
        transport._pipe = None
53
54
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(test_utils.TestCase):
    """Tests for signal-handler management on SelectorEventLoop.

    Most tests replace the ``signal`` module as seen by asyncio.unix_events
    with a mock (``m_signal``).  They copy over only the real attributes the
    code under test needs (``NSIG``, ``valid_signals``, ...), so no real OS
    signal handlers are installed.

    Logging assertions patch ``asyncio.unix_events.logger``, the same module
    whose ``signal`` these tests patch.  unittest.mock must patch a name where
    it is looked up.  Patching ``asyncio.base_events.logger`` instead would
    leave the signal-handling log calls unobserved, and any ``info.called``
    check would then be vacuous.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.SelectorEventLoop()
        self.set_event_loop(self.loop)

    def test_check_signal(self):
        self.assertRaises(
            TypeError, self.loop._check_signal, '1')
        self.assertRaises(
            ValueError, self.loop._check_signal, signal.NSIG + 1)

    def test_handle_signal_no_handler(self):
        # Must not raise for a signal number with no registered handler.
        self.loop._handle_signal(signal.NSIG + 1)

    def test_handle_signal_cancelled_handler(self):
        h = asyncio.Handle(mock.Mock(), (),
                           loop=mock.Mock())
        h.cancel()
        self.loop._signal_handlers[signal.NSIG + 1] = h
        self.loop.remove_signal_handler = mock.Mock()
        self.loop._handle_signal(signal.NSIG + 1)
        # A cancelled handle is unregistered instead of being run.
        self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_setup_error(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals
        m_signal.set_wakeup_fd.side_effect = ValueError

        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_coroutine_error(self, m_signal):
        m_signal.NSIG = signal.NSIG

        async def simple_coroutine():
            pass

        # callback must not be a coroutine function
        coro_func = simple_coroutine
        coro_obj = coro_func()
        self.addCleanup(coro_obj.close)
        for func in (coro_func, coro_obj):
            self.assertRaisesRegex(
                TypeError, 'coroutines cannot be used with add_signal_handler',
                self.loop.add_signal_handler,
                signal.SIGINT, func)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals

        cb = lambda: True
        self.loop.add_signal_handler(signal.SIGHUP, cb)
        h = self.loop._signal_handlers.get(signal.SIGHUP)
        self.assertIsInstance(h, asyncio.Handle)
        self.assertEqual(h._callback, cb)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_install_error(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals

        def set_wakeup_fd(fd):
            if fd == -1:
                raise ValueError()
        m_signal.set_wakeup_fd = set_wakeup_fd

        class Err(OSError):
            errno = errno.EFAULT
        m_signal.signal.side_effect = Err

        # An errno other than EINVAL propagates unchanged.
        self.assertRaises(
            Err,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)

    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.unix_events.logger')
    def test_add_signal_handler_install_error2(self, m_logging, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals

        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err

        # Another handler stays registered, so the wakeup fd is not reset:
        # set_wakeup_fd is only called once (during setup).
        self.loop._signal_handlers[signal.SIGHUP] = lambda: True
        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)
        self.assertFalse(m_logging.info.called)
        self.assertEqual(1, m_signal.set_wakeup_fd.call_count)

    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.unix_events.logger')
    def test_add_signal_handler_install_error3(self, m_logging, m_signal):
        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals

        # No handler remains, so the wakeup fd is set up and then reset.
        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)
        self.assertFalse(m_logging.info.called)
        self.assertEqual(2, m_signal.set_wakeup_fd.call_count)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals

        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        self.assertTrue(
            self.loop.remove_signal_handler(signal.SIGHUP))
        self.assertTrue(m_signal.set_wakeup_fd.called)
        self.assertTrue(m_signal.signal.called)
        self.assertEqual(
            (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_2(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.SIGINT = signal.SIGINT
        m_signal.valid_signals = signal.valid_signals

        self.loop.add_signal_handler(signal.SIGINT, lambda: True)
        self.loop._signal_handlers[signal.SIGHUP] = object()
        m_signal.set_wakeup_fd.reset_mock()

        self.assertTrue(
            self.loop.remove_signal_handler(signal.SIGINT))
        self.assertFalse(m_signal.set_wakeup_fd.called)
        self.assertTrue(m_signal.signal.called)
        # SIGINT is restored to Python's default_int_handler, not SIG_DFL.
        self.assertEqual(
            (signal.SIGINT, m_signal.default_int_handler),
            m_signal.signal.call_args[0])

    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.unix_events.logger')
    def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        # Removing the last handler resets the wakeup fd; make that fail.
        m_signal.set_wakeup_fd.side_effect = ValueError

        self.loop.remove_signal_handler(signal.SIGHUP)
        # The failure is logged rather than propagated.  Check ``.called``:
        # a Mock attribute by itself is always truthy.
        self.assertTrue(m_logging.info.called)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_error(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        m_signal.signal.side_effect = OSError

        self.assertRaises(
            OSError, self.loop.remove_signal_handler, signal.SIGHUP)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_error2(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err

        self.assertRaises(
            RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)

    @mock.patch('asyncio.unix_events.signal')
    def test_close(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals

        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
        self.loop.add_signal_handler(signal.SIGCHLD, lambda: True)

        self.assertEqual(len(self.loop._signal_handlers), 2)

        m_signal.set_wakeup_fd.reset_mock()

        self.loop.close()

        self.assertEqual(len(self.loop._signal_handlers), 0)
        m_signal.set_wakeup_fd.assert_called_once_with(-1)

    @mock.patch('asyncio.unix_events.sys')
    @mock.patch('asyncio.unix_events.signal')
    def test_close_on_finalizing(self, m_signal, m_sys):
        m_signal.NSIG = signal.NSIG
        m_signal.valid_signals = signal.valid_signals
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        self.assertEqual(len(self.loop._signal_handlers), 1)
        m_sys.is_finalizing.return_value = True
        m_signal.signal.reset_mock()

        # During interpreter finalization the handlers are dropped with a
        # ResourceWarning instead of calling signal.signal().
        with self.assertWarnsRegex(ResourceWarning,
                                   "skipping signal handlers removal"):
            self.loop.close()

        self.assertEqual(len(self.loop._signal_handlers), 0)
        self.assertFalse(m_signal.signal.called)
275
276
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
                     'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
    """Tests for create_unix_server() and create_unix_connection().

    Covers argument validation and error paths on SelectorEventLoop, plus a
    few real binds to filesystem socket paths.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.SelectorEventLoop()
        self.set_event_loop(self.loop)

    @socket_helper.skip_unless_bind_unix_socket
    def test_create_unix_server_existing_path_sock(self):
        # A stale (closed) socket file at the path must not block the bind.
        with test_utils.unix_socket_path() as path:
            sock = socket.socket(socket.AF_UNIX)
            sock.bind(path)
            sock.listen(1)
            sock.close()

            coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    @socket_helper.skip_unless_bind_unix_socket
    def test_create_unix_server_pathlib(self):
        with test_utils.unix_socket_path() as path:
            path = pathlib.Path(path)
            srv_coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(srv_coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_create_unix_connection_pathlib(self):
        with test_utils.unix_socket_path() as path:
            path = pathlib.Path(path)
            coro = self.loop.create_unix_connection(lambda: None, path)
            with self.assertRaises(FileNotFoundError):
                # If pathlib.Path wasn't supported, the exception would be
                # different.
                self.loop.run_until_complete(coro)

    def test_create_unix_server_existing_path_nonsock(self):
        with tempfile.NamedTemporaryFile() as file:
            coro = self.loop.create_unix_server(lambda: None, file.name)
            with self.assertRaisesRegex(OSError,
                                        'Address.*is already in use'):
                self.loop.run_until_complete(coro)

    def test_create_unix_server_ssl_bool(self):
        coro = self.loop.create_unix_server(lambda: None, path='spam',
                                            ssl=True)
        with self.assertRaisesRegex(TypeError,
                                    'ssl argument must be an SSLContext'):
            self.loop.run_until_complete(coro)

    def test_create_unix_server_nopath_nosock(self):
        coro = self.loop.create_unix_server(lambda: None, path=None)
        with self.assertRaisesRegex(ValueError,
                                    'path was not specified, and no sock'):
            self.loop.run_until_complete(coro)

    def test_create_unix_server_path_inetsock(self):
        sock = socket.socket()
        with sock:
            coro = self.loop.create_unix_server(lambda: None, path=None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    def test_create_unix_server_path_dgram(self):
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        with sock:
            coro = self.loop.create_unix_server(lambda: None, path=None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'no socket.SOCK_NONBLOCK (linux only)')
    @socket_helper.skip_unless_bind_unix_socket
    def test_create_unix_server_path_stream_bittype(self):
        # The extra SOCK_NONBLOCK type bit must not make the socket look
        # like a non-stream socket.
        sock = socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
        with tempfile.NamedTemporaryFile() as file:
            fn = file.name
        try:
            with sock:
                sock.bind(fn)
                coro = self.loop.create_unix_server(lambda: None, path=None,
                                                    sock=sock)
                srv = self.loop.run_until_complete(coro)
                srv.close()
                self.loop.run_until_complete(srv.wait_closed())
        finally:
            os.unlink(fn)

    def test_create_unix_server_ssl_timeout_with_plain_sock(self):
        coro = self.loop.create_unix_server(lambda: None, path='spam',
                                            ssl_handshake_timeout=1)
        with self.assertRaisesRegex(
                ValueError,
                'ssl_handshake_timeout is only meaningful with ssl'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_path_inetsock(self):
        sock = socket.socket()
        with sock:
            coro = self.loop.create_unix_connection(lambda: None,
                                                    sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @mock.patch('asyncio.unix_events.socket')
    def test_create_unix_server_bind_error(self, m_socket):
        # Ensure that the socket is closed on any bind error
        sock = mock.Mock()
        m_socket.socket.return_value = sock

        sock.bind.side_effect = OSError
        coro = self.loop.create_unix_server(lambda: None, path="/test")
        with self.assertRaises(OSError):
            self.loop.run_until_complete(coro)
        self.assertTrue(sock.close.called)

        # Forget the close() recorded above.  Otherwise the check below
        # would pass even if a non-OSError bind failure leaked the socket.
        sock.close.reset_mock()
        sock.bind.side_effect = MemoryError
        coro = self.loop.create_unix_server(lambda: None, path="/test")
        with self.assertRaises(MemoryError):
            self.loop.run_until_complete(coro)
        self.assertTrue(sock.close.called)

    def test_create_unix_connection_path_sock(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, sock=object())
        with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_nopath_nosock(self):
        coro = self.loop.create_unix_connection(
            lambda: None, None)
        with self.assertRaisesRegex(ValueError,
                                    'no path and sock were specified'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_nossl_serverhost(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, server_hostname='spam')
        with self.assertRaisesRegex(ValueError,
                                    'server_hostname is only meaningful'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_ssl_noserverhost(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, ssl=True)

        with self.assertRaisesRegex(
                ValueError, 'you have to pass server_hostname when using ssl'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_ssl_timeout_with_plain_sock(self):
        coro = self.loop.create_unix_connection(lambda: None, path='spam',
                                                ssl_handshake_timeout=1)
        with self.assertRaisesRegex(
                ValueError,
                'ssl_handshake_timeout is only meaningful with ssl'):
            self.loop.run_until_complete(coro)
445
446
@unittest.skipUnless(hasattr(os, 'sendfile'),
                     'sendfile is not supported')
class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
    """Tests for the loop's os.sendfile()-based socket sendfile helpers.

    A 160 KiB data file is written once per class.  Each test connects a
    non-blocking client socket to a local TCP server and then drives the
    private ``_sock_sendfile_native*`` helpers directly.
    """

    DATA = b"12345abcde" * 16 * 1024  # 160 KiB

    class MyProto(asyncio.Protocol):
        """Server-side protocol recording received data and lifecycle."""

        def __init__(self, loop):
            self.started = False
            self.closed = False
            self.data = bytearray()
            # Resolved by connection_lost(); awaited by wait_closed().
            self.fut = loop.create_future()
            self.transport = None
            # Resolved by connection_made().
            self._ready = loop.create_future()

        def connection_made(self, transport):
            self.started = True
            self.transport = transport
            self._ready.set_result(None)

        def data_received(self, data):
            self.data.extend(data)

        def connection_lost(self, exc):
            self.closed = True
            self.fut.set_result(None)

        async def wait_closed(self):
            await self.fut

    @classmethod
    def setUpClass(cls):
        with open(os_helper.TESTFN, 'wb') as fp:
            fp.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)
        super().tearDownClass()

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)
        self.file = open(os_helper.TESTFN, 'rb')
        self.addCleanup(self.file.close)
        super().setUp()

    def make_socket(self, cleanup=True):
        """Return a non-blocking TCP socket with 1 KiB send/recv buffers.

        If *cleanup* is true, the socket is closed at test cleanup.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        # Small buffers, presumably so the 160 KiB payload cannot be
        # transferred in a single call.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def run_loop(self, coro):
        return self.loop.run_until_complete(coro)

    def prepare(self):
        """Start a local server and return a connected ``(sock, proto)``.

        Closing the server transport and the server itself is registered as
        test cleanup.
        """
        sock = self.make_socket()
        proto = self.MyProto(self.loop)
        # The listening socket is handed to the server, and server.close()
        # in cleanup() below shuts it, so it is not registered for cleanup.
        srv_sock = self.make_socket(cleanup=False)
        # bind_port() binds to port 0 and reports the port the OS chose.
        # Probing with find_unused_port() and binding afterwards is racy:
        # another process can take the port in between.
        port = socket_helper.bind_port(srv_sock)
        server = self.run_loop(self.loop.create_server(
            lambda: proto, sock=srv_sock))
        self.run_loop(self.loop.sock_connect(sock, (socket_helper.HOST, port)))
        self.run_loop(proto._ready)

        def cleanup():
            proto.transport.close()
            self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_not_available(self):
        sock, proto = self.prepare()
        with mock.patch('asyncio.unix_events.os', spec=[]):
            with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                        "os[.]sendfile[(][)] is not available"):
                self.run_loop(self.loop._sock_sendfile_native(sock, self.file,
                                                              0, None))
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_not_a_file(self):
        sock, proto = self.prepare()
        f = object()
        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not a regular file"):
            self.run_loop(self.loop._sock_sendfile_native(sock, f,
                                                          0, None))
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_iobuffer(self):
        sock, proto = self.prepare()
        f = io.BytesIO()
        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not a regular file"):
            self.run_loop(self.loop._sock_sendfile_native(sock, f,
                                                          0, None))
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_not_regular_file(self):
        sock, proto = self.prepare()
        f = mock.Mock()
        f.fileno.return_value = -1
        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not a regular file"):
            self.run_loop(self.loop._sock_sendfile_native(sock, f,
                                                          0, None))
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_cancel1(self):
        sock, proto = self.prepare()

        fut = self.loop.create_future()
        fileno = self.file.fileno()
        self.loop._sock_sendfile_native_impl(fut, None, sock, fileno,
                                             0, None, len(self.DATA), 0)
        fut.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            self.run_loop(fut)
        # Cancellation must leave no writer registered for the socket.
        with self.assertRaises(KeyError):
            self.loop._selector.get_key(sock)

    def test_sock_sendfile_cancel2(self):
        sock, proto = self.prepare()

        fut = self.loop.create_future()
        fileno = self.file.fileno()
        self.loop._sock_sendfile_native_impl(fut, None, sock, fileno,
                                             0, None, len(self.DATA), 0)
        fut.cancel()
        # Re-entering with an already-cancelled future must unregister.
        self.loop._sock_sendfile_native_impl(fut, sock.fileno(), sock, fileno,
                                             0, None, len(self.DATA), 0)
        with self.assertRaises(KeyError):
            self.loop._selector.get_key(sock)

    def test_sock_sendfile_blocking_error(self):
        sock, proto = self.prepare()

        fileno = self.file.fileno()
        fut = mock.Mock()
        fut.cancelled.return_value = False
        with mock.patch('os.sendfile', side_effect=BlockingIOError()):
            self.loop._sock_sendfile_native_impl(fut, None, sock, fileno,
                                                 0, None, len(self.DATA), 0)
        key = self.loop._selector.get_key(sock)
        self.assertIsNotNone(key)
        fut.add_done_callback.assert_called_once_with(mock.ANY)

    def test_sock_sendfile_os_error_first_call(self):
        sock, proto = self.prepare()

        fileno = self.file.fileno()
        fut = self.loop.create_future()
        with mock.patch('os.sendfile', side_effect=OSError()):
            self.loop._sock_sendfile_native_impl(fut, None, sock, fileno,
                                                 0, None, len(self.DATA), 0)
        with self.assertRaises(KeyError):
            self.loop._selector.get_key(sock)
        # An error before anything was sent is reported as "not available".
        exc = fut.exception()
        self.assertIsInstance(exc, asyncio.SendfileNotAvailableError)
        self.assertEqual(0, self.file.tell())

    def test_sock_sendfile_os_error_next_call(self):
        sock, proto = self.prepare()

        fileno = self.file.fileno()
        fut = self.loop.create_future()
        err = OSError()
        with mock.patch('os.sendfile', side_effect=err):
            self.loop._sock_sendfile_native_impl(fut, sock.fileno(),
                                                 sock, fileno,
                                                 1000, None, len(self.DATA),
                                                 1000)
        with self.assertRaises(KeyError):
            self.loop._selector.get_key(sock)
        # Once data has been sent, the original error is propagated.
        exc = fut.exception()
        self.assertIs(exc, err)
        self.assertEqual(1000, self.file.tell())

    def test_sock_sendfile_exception(self):
        sock, proto = self.prepare()

        fileno = self.file.fileno()
        fut = self.loop.create_future()
        err = asyncio.SendfileNotAvailableError()
        with mock.patch('os.sendfile', side_effect=err):
            self.loop._sock_sendfile_native_impl(fut, sock.fileno(),
                                                 sock, fileno,
                                                 1000, None, len(self.DATA),
                                                 1000)
        with self.assertRaises(KeyError):
            self.loop._selector.get_key(sock)
        exc = fut.exception()
        self.assertIs(exc, err)
        self.assertEqual(1000, self.file.tell())
652
653
class UnixReadPipeTransportTests(test_utils.TestCase):
    """Tests for _UnixReadPipeTransport on a test loop with a fake pipe.

    The pipe is a ``RawIOBase`` mock whose fileno() is 5.  ``os.set_blocking``
    is patched out, and ``os.fstat`` is patched to report a FIFO.  Reader
    registration is observed through ``self.loop.readers``.
    """

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        blocking_patcher = mock.patch('os.set_blocking')
        blocking_patcher.start()
        self.addCleanup(blocking_patcher.stop)

        fstat_patcher = mock.patch('os.fstat')
        m_fstat = fstat_patcher.start()
        st = mock.Mock()
        st.st_mode = stat.S_IFIFO
        m_fstat.return_value = st
        self.addCleanup(fstat_patcher.stop)

    def read_pipe_transport(self, waiter=None):
        transport = unix_events._UnixReadPipeTransport(self.loop, self.pipe,
                                                       self.protocol,
                                                       waiter=waiter)
        # The loop is a test loop, so close the pipe directly at teardown.
        self.addCleanup(close_pipe_transport, transport)
        return transport

    def test_ctor(self):
        waiter = self.loop.create_future()
        tr = self.read_pipe_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.protocol.connection_made.assert_called_with(tr)
        self.loop.assert_reader(5, tr._read_ready)
        self.assertIsNone(waiter.result())

    @mock.patch('os.read')
    def test__read_ready(self, m_read):
        tr = self.read_pipe_transport()
        m_read.return_value = b'data'
        tr._read_ready()

        m_read.assert_called_with(5, tr.max_size)
        self.protocol.data_received.assert_called_with(b'data')

    @mock.patch('os.read')
    def test__read_ready_eof(self, m_read):
        tr = self.read_pipe_transport()
        m_read.return_value = b''
        tr._read_ready()

        m_read.assert_called_with(5, tr.max_size)
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.eof_received.assert_called_with()
        self.protocol.connection_lost.assert_called_with(None)

    @mock.patch('os.read')
    def test__read_ready_blocked(self, m_read):
        tr = self.read_pipe_transport()
        m_read.side_effect = BlockingIOError
        tr._read_ready()

        m_read.assert_called_with(5, tr.max_size)
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.data_received.called)

    @mock.patch('asyncio.log.logger.error')
    @mock.patch('os.read')
    def test__read_ready_error(self, m_read, m_logexc):
        tr = self.read_pipe_transport()
        err = OSError()
        m_read.side_effect = err
        tr._close = mock.Mock()
        tr._read_ready()

        m_read.assert_called_with(5, tr.max_size)
        tr._close.assert_called_with(err)
        m_logexc.assert_called_with(
            test_utils.MockPattern(
                'Fatal read error on pipe transport'
                '\nprotocol:.*\ntransport:.*'),
            exc_info=(OSError, MOCK_ANY, MOCK_ANY))

    @mock.patch('os.read')
    def test_pause_reading(self, m_read):
        tr = self.read_pipe_transport()
        m = mock.Mock()
        self.loop.add_reader(5, m)
        tr.pause_reading()
        self.assertFalse(self.loop.readers)

    @mock.patch('os.read')
    def test_resume_reading(self, m_read):
        tr = self.read_pipe_transport()
        tr.pause_reading()
        tr.resume_reading()
        self.loop.assert_reader(5, tr._read_ready)

    @mock.patch('os.read')
    def test_close(self, m_read):
        tr = self.read_pipe_transport()
        tr._close = mock.Mock()
        tr.close()
        tr._close.assert_called_with(None)

    @mock.patch('os.read')
    def test_close_already_closing(self, m_read):
        tr = self.read_pipe_transport()
        tr._closing = True
        tr._close = mock.Mock()
        tr.close()
        self.assertFalse(tr._close.called)

    @mock.patch('os.read')
    def test__close(self, m_read):
        tr = self.read_pipe_transport()
        err = object()
        tr._close(err)
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(err)

    def test__call_connection_lost(self):
        tr = self.read_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = None
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        # References are dropped once the connection is lost.
        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test__call_connection_lost_with_err(self):
        tr = self.read_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = OSError()
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test_pause_reading_on_closed_pipe(self):
        tr = self.read_pipe_transport()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertIsNone(tr._loop)
        # Must be a harmless no-op even though the loop reference is gone.
        tr.pause_reading()
        self.assertFalse(tr.is_reading())

    def test_pause_reading_on_paused_pipe(self):
        tr = self.read_pipe_transport()
        tr.pause_reading()
        # the second call should do nothing
        tr.pause_reading()
        self.assertFalse(tr.is_reading())

    def test_resume_reading_on_closed_pipe(self):
        tr = self.read_pipe_transport()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertIsNone(tr._loop)
        # Must be a harmless no-op and must not restart reading.
        tr.resume_reading()
        self.assertFalse(tr.is_reading())

    def test_resume_reading_on_paused_pipe(self):
        tr = self.read_pipe_transport()
        # Despite the test name, the pipe is not paused here:
        # resuming should do nothing, and the transport keeps reading.
        tr.resume_reading()
        self.assertTrue(tr.is_reading())
829
830
class UnixWritePipeTransportTests(test_utils.TestCase):
    """Tests for unix_events._UnixWritePipeTransport.

    The "pipe" is a mock.Mock(spec_set=io.RawIOBase) whose fileno() is 5.
    os.set_blocking and os.fstat are patched for the whole test (fstat
    reports a socket) and os.write is patched per test, so no real file
    descriptor is ever touched.
    """

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        # NOTE(review): presumably the transport switches the descriptor to
        # non-blocking mode on construction; the call is faked so it never
        # reaches the fake descriptor 5.
        blocking_patcher = mock.patch('os.set_blocking')
        blocking_patcher.start()
        self.addCleanup(blocking_patcher.stop)

        # Report the descriptor as a socket (st_mode = S_IFSOCK).  test_ctor
        # expects the transport to register a reader for it.
        fstat_patcher = mock.patch('os.fstat')
        m_fstat = fstat_patcher.start()
        st = mock.Mock()
        st.st_mode = stat.S_IFSOCK
        m_fstat.return_value = st
        self.addCleanup(fstat_patcher.stop)

    def write_pipe_transport(self, waiter=None):
        """Create a transport over self.pipe; the pipe is closed at cleanup."""
        transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
                                                        self.protocol,
                                                        waiter=waiter)
        self.addCleanup(close_pipe_transport, transport)
        return transport

    def test_ctor(self):
        waiter = self.loop.create_future()
        tr = self.write_pipe_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.protocol.connection_made.assert_called_with(tr)
        # The write side also watches for readability, which signals that
        # the peer closed the pipe (see test__read_ready).
        self.loop.assert_reader(5, tr._read_ready)
        self.assertIsNone(waiter.result())

    def test_can_write_eof(self):
        tr = self.write_pipe_transport()
        self.assertTrue(tr.can_write_eof())

    @mock.patch('os.write')
    def test_write(self, m_write):
        # A complete write goes straight to os.write and buffers nothing.
        tr = self.write_pipe_transport()
        m_write.return_value = 4
        tr.write(b'data')
        m_write.assert_called_with(5, b'data')
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test_write_no_data(self, m_write):
        tr = self.write_pipe_transport()
        tr.write(b'')
        self.assertFalse(m_write.called)
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test_write_partial(self, m_write):
        # The unwritten tail is buffered and a writer is registered.
        tr = self.write_pipe_transport()
        m_write.return_value = 2
        tr.write(b'data')
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'ta'), tr._buffer)

    @mock.patch('os.write')
    def test_write_buffer(self, m_write):
        # With data already pending, new data is appended to the buffer
        # instead of being written immediately.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'previous')
        tr.write(b'data')
        self.assertFalse(m_write.called)
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'previousdata'), tr._buffer)

    @mock.patch('os.write')
    def test_write_again(self, m_write):
        # BlockingIOError means nothing was written: buffer everything.
        tr = self.write_pipe_transport()
        m_write.side_effect = BlockingIOError()
        tr.write(b'data')
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('asyncio.unix_events.logger')
    @mock.patch('os.write')
    def test_write_err(self, m_write, m_log):
        tr = self.write_pipe_transport()
        err = OSError()
        m_write.side_effect = err
        tr._fatal_error = mock.Mock()
        tr.write(b'data')
        m_write.assert_called_with(5, b'data')
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)
        tr._fatal_error.assert_called_with(
                            err,
                            'Fatal write error on pipe transport')
        self.assertEqual(1, tr._conn_lost)

        # Each further write only increments _conn_lost; after several of
        # them the transport logs a warning.
        tr.write(b'data')
        self.assertEqual(2, tr._conn_lost)
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        # This is a bit overspecified. :-(
        m_log.warning.assert_called_with(
            'pipe closed by peer or os.write(pipe, data) raised exception.')
        tr.close()

    @mock.patch('os.write')
    def test_write_close(self, m_write):
        tr = self.write_pipe_transport()
        tr._read_ready()  # pipe was closed by peer

        tr.write(b'data')
        self.assertEqual(tr._conn_lost, 1)
        tr.write(b'data')
        self.assertEqual(tr._conn_lost, 2)

    def test__read_ready(self):
        # Readability on the write side means the peer closed the pipe:
        # the transport closes and reports a clean connection loss.
        tr = self.write_pipe_transport()
        tr._read_ready()
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)
        self.assertTrue(tr.is_closing())
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    @mock.patch('os.write')
    def test__write_ready(self, m_write):
        # Flushing the whole buffer unregisters the writer.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 4
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_partial(self, m_write):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 3
        tr._write_ready()
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'a'), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_again(self, m_write):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.side_effect = BlockingIOError()
        tr._write_ready()
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_empty(self, m_write):
        # A zero-byte write keeps the whole buffer and the writer.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 0
        tr._write_ready()
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('asyncio.log.logger.error')
    @mock.patch('os.write')
    def test__write_ready_err(self, m_write, m_logexc):
        # A write error closes the transport and passes the error to
        # connection_lost(), without logging it through logger.error.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.side_effect = err = OSError()
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertFalse(self.loop.readers)
        self.assertEqual(bytearray(), tr._buffer)
        self.assertTrue(tr.is_closing())
        m_logexc.assert_not_called()
        self.assertEqual(1, tr._conn_lost)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(err)

    @mock.patch('os.write')
    def test__write_ready_closing(self, m_write):
        # Flushing the last data of a closing transport finishes the close
        # without running the loop: connection_lost() and pipe.close() have
        # already been called when _write_ready() returns.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._closing = True
        tr._buffer = bytearray(b'data')
        m_write.return_value = 4
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertFalse(self.loop.readers)
        self.assertEqual(bytearray(), tr._buffer)
        self.protocol.connection_lost.assert_called_with(None)
        self.pipe.close.assert_called_with()

    @mock.patch('os.write')
    def test_abort(self, m_write):
        # abort() drops pending data without writing it.  The buffer uses
        # the same bytearray type the transport itself uses (see test_write).
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        self.loop.add_reader(5, tr._read_ready)
        tr._buffer = bytearray(b'data')
        tr.abort()
        self.assertFalse(m_write.called)
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)
        self.assertTrue(tr.is_closing())
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test__call_connection_lost(self):
        tr = self.write_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = None
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        # The protocol and loop references are cleared afterwards.
        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test__call_connection_lost_with_err(self):
        tr = self.write_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = OSError()
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        # The protocol and loop references are cleared afterwards.
        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test_close(self):
        # close() is implemented in terms of write_eof().
        tr = self.write_pipe_transport()
        tr.write_eof = mock.Mock()
        tr.close()
        tr.write_eof.assert_called_with()

        # closing the transport twice must not fail
        tr.close()

    def test_close_closing(self):
        tr = self.write_pipe_transport()
        tr.write_eof = mock.Mock()
        tr._closing = True
        tr.close()
        self.assertFalse(tr.write_eof.called)

    def test_write_eof(self):
        # With nothing buffered, write_eof() closes the transport.
        tr = self.write_pipe_transport()
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test_write_eof_pending(self):
        # With data still buffered (a bytearray, as in the real transport),
        # write_eof() marks the transport as closing but does not report
        # the connection as lost yet.
        tr = self.write_pipe_transport()
        tr._buffer = bytearray(b'data')
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.protocol.connection_lost.called)
1105
1106
class AbstractChildWatcherTests(unittest.TestCase):
    """Every method of AbstractChildWatcher must be left unimplemented."""

    def test_not_implemented(self):
        dummy = mock.Mock()
        watcher = asyncio.AbstractChildWatcher()
        # (bound method, positional arguments) for each abstract operation,
        # checked in order.
        abstract_calls = (
            (watcher.add_child_handler, (dummy, dummy)),
            (watcher.remove_child_handler, (dummy,)),
            (watcher.attach_loop, (dummy,)),
            (watcher.close, ()),
            (watcher.is_active, ()),
            (watcher.__enter__, ()),
            (watcher.__exit__, (dummy, dummy, dummy)),
        )
        for method, call_args in abstract_calls:
            self.assertRaises(NotImplementedError, method, *call_args)
1126
1127
class BaseChildWatcherTests(unittest.TestCase):
    """BaseChildWatcher leaves _do_waitpid() to its subclasses."""

    def test_not_implemented(self):
        placeholder = mock.Mock()
        base_watcher = unix_events.BaseChildWatcher()
        with self.assertRaises(NotImplementedError):
            base_watcher._do_waitpid(placeholder)
1135
1136
class ChildWatcherTestsMixin:
    """Tests shared by the child watcher implementations.

    Concrete subclasses mix this into test_utils.TestCase and provide
    create_watcher().  No real child processes are involved: os.waitpid is
    replaced by self.waitpid(), which serves the statuses registered with
    add_zombie(), and unix_events.waitstatus_to_exitcode is replaced by
    self.waitstatus_to_exitcode(), which decodes the EXITCODE()/SIGNAL()
    encoding used in this module.
    """

    # Used around steps that feed the watcher statuses for children it does
    # not (or no longer) track, which it may log warnings about.
    ignore_warnings = mock.patch.object(log.logger, "warning")

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        # Fake process table: while running is true, waitpid() on a pid
        # without a zombie status reports "still running" instead of
        # raising ChildProcessError.
        self.running = False
        self.zombies = {}

        # attach_loop() registers the SIGCHLD handler on the loop; record
        # that call so test_create_watcher can check it.
        with mock.patch.object(
                self.loop, "add_signal_handler") as self.m_add_signal_handler:
            self.watcher = self.create_watcher()
            self.watcher.attach_loop(self.loop)

    def waitpid(self, pid, flags):
        """Fake os.waitpid() backed by self.zombies.

        pid -1 reaps an arbitrary zombie (the most recently added one, via
        dict.popitem()); a positive pid reaps that child only.  When no
        matching zombie exists, return (0, 0) while self.running is true,
        otherwise raise ChildProcessError (no child left to wait for).
        """
        # SafeChildWatcher must only ever wait for specific children; the
        # other watchers may additionally pass -1.
        if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1:
            self.assertGreater(pid, 0)
        try:
            if pid < 0:
                return self.zombies.popitem()
            else:
                return pid, self.zombies.pop(pid)
        except KeyError:
            pass
        if self.running:
            return 0, 0
        else:
            raise ChildProcessError()

    def add_zombie(self, pid, status):
        """Make *status* the pending wait status of child *pid*."""
        self.zombies[pid] = status

    def waitstatus_to_exitcode(self, status):
        """Decode a fake wait status built by EXITCODE() or SIGNAL().

        EXITCODE(n) == 32768 + n decodes to n (n >= 0) and
        SIGNAL(n) == 32768 - n decodes to -n (1 <= n <= 68).  Any other
        value (see test_sigchld_unknown_status) is returned unchanged.
        """
        # Both encodings decode by subtracting 32768.  The lower bound is
        # inclusive so that SIGNAL(68) == 32700 is decoded, and EXITCODE(0)
        # == 32768 decodes to 0 rather than passing through undecoded.
        if status >= 32700:
            return status - 32768
        return status

    def test_create_watcher(self):
        self.m_add_signal_handler.assert_called_once_with(
            signal.SIGCHLD, self.watcher._sig_chld)

    def waitpid_mocks(func):
        """Decorator: run a test with os.waitpid and waitstatus_to_exitcode faked.

        Both are replaced by Mocks wrapping self.waitpid and
        self.waitstatus_to_exitcode, so calls are recorded; the os.waitpid
        mock is passed to the test as its m_waitpid argument.  The wrapper
        keeps the test's name and docstring.
        """
        # functools is not among this module's top-level imports.
        import functools

        @functools.wraps(func)
        def wrapped_func(self):
            def patch(target, wrapper):
                return mock.patch(target, wraps=wrapper,
                                  new_callable=mock.Mock)

            with patch('asyncio.unix_events.waitstatus_to_exitcode', self.waitstatus_to_exitcode), \
                 patch('os.waitpid', self.waitpid) as m_waitpid:
                func(self, m_waitpid)
        return wrapped_func

    @waitpid_mocks
    def test_sigchld(self, m_waitpid):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(42, callback, 9, 10, 14)

        self.assertFalse(callback.called)

        # child is running
        self.watcher._sig_chld()

        self.assertFalse(callback.called)

        # child terminates (returncode 12)
        self.running = False
        self.add_zombie(42, EXITCODE(12))
        self.watcher._sig_chld()

        callback.assert_called_once_with(42, 12, 9, 10, 14)

        callback.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(42, EXITCODE(13))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

        # sigchld called again
        self.zombies.clear()
        self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_sigchld_two_children(self, m_waitpid):
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        # register child 1
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(43, callback1, 7, 8)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # register child 2
        with self.watcher:
            self.watcher.add_child_handler(44, callback2, 147, 18)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # children are running
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # child 1 terminates (signal 3)
        self.add_zombie(43, SIGNAL(3))
        self.watcher._sig_chld()

        callback1.assert_called_once_with(43, -3, 7, 8)
        self.assertFalse(callback2.called)

        callback1.reset_mock()

        # child 2 still running
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # child 2 terminates (code 108)
        self.add_zombie(44, EXITCODE(108))
        self.running = False
        self.watcher._sig_chld()

        callback2.assert_called_once_with(44, 108, 147, 18)
        self.assertFalse(callback1.called)

        callback2.reset_mock()

        # ensure that the children are effectively reaped
        self.add_zombie(43, EXITCODE(14))
        self.add_zombie(44, EXITCODE(15))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # sigchld called again
        self.zombies.clear()
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

    @waitpid_mocks
    def test_sigchld_two_children_terminating_together(self, m_waitpid):
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        # register child 1
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(45, callback1, 17, 8)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # register child 2
        with self.watcher:
            self.watcher.add_child_handler(46, callback2, 1147, 18)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # children are running
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # child 1 terminates (code 78)
        # child 2 terminates (signal 5)
        self.add_zombie(45, EXITCODE(78))
        self.add_zombie(46, SIGNAL(5))
        self.running = False
        self.watcher._sig_chld()

        callback1.assert_called_once_with(45, 78, 17, 8)
        callback2.assert_called_once_with(46, -5, 1147, 18)

        callback1.reset_mock()
        callback2.reset_mock()

        # ensure that the children are effectively reaped
        self.add_zombie(45, EXITCODE(14))
        self.add_zombie(46, EXITCODE(15))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

    @waitpid_mocks
    def test_sigchld_race_condition(self, m_waitpid):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            # child terminates before being registered
            self.add_zombie(50, EXITCODE(4))
            self.watcher._sig_chld()

            self.watcher.add_child_handler(50, callback, 1, 12)

        callback.assert_called_once_with(50, 4, 1, 12)
        callback.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(50, SIGNAL(1))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_sigchld_replace_handler(self, m_waitpid):
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(51, callback1, 19)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # register the same child again
        with self.watcher:
            self.watcher.add_child_handler(51, callback2, 21)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

        # child terminates (signal 8)
        self.running = False
        self.add_zombie(51, SIGNAL(8))
        self.watcher._sig_chld()

        callback2.assert_called_once_with(51, -8, 21)
        self.assertFalse(callback1.called)

        callback2.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(51, EXITCODE(13))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)

    @waitpid_mocks
    def test_sigchld_remove_handler(self, m_waitpid):
        callback = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(52, callback, 1984)

        self.assertFalse(callback.called)

        # unregister the child
        self.watcher.remove_child_handler(52)

        self.assertFalse(callback.called)

        # child terminates (code 99)
        self.running = False
        self.add_zombie(52, EXITCODE(99))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_sigchld_unknown_status(self, m_waitpid):
        callback = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(53, callback, -19)

        self.assertFalse(callback.called)

        # terminate with a status that is neither EXITCODE() nor SIGNAL();
        # it reaches the callback unchanged
        self.add_zombie(53, 1178)
        self.running = False
        self.watcher._sig_chld()

        callback.assert_called_once_with(53, 1178, -19)

        callback.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(53, EXITCODE(101))
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_remove_child_handler(self, m_waitpid):
        callback1 = mock.Mock()
        callback2 = mock.Mock()
        callback3 = mock.Mock()

        # register children
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(54, callback1, 1)
            self.watcher.add_child_handler(55, callback2, 2)
            self.watcher.add_child_handler(56, callback3, 3)

        # remove child handler 1
        self.assertTrue(self.watcher.remove_child_handler(54))

        # remove child handler 2 multiple times
        self.assertTrue(self.watcher.remove_child_handler(55))
        self.assertFalse(self.watcher.remove_child_handler(55))
        self.assertFalse(self.watcher.remove_child_handler(55))

        # all children terminate
        self.add_zombie(54, EXITCODE(0))
        self.add_zombie(55, EXITCODE(1))
        self.add_zombie(56, EXITCODE(2))
        self.running = False
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        callback3.assert_called_once_with(56, 2, 3)

    @waitpid_mocks
    def test_sigchld_unhandled_exception(self, m_waitpid):
        callback = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(57, callback)

        # raise an exception
        m_waitpid.side_effect = ValueError

        # _sig_chld() must not propagate the error; it is logged instead
        with mock.patch.object(log.logger,
                               'error') as m_error:

            self.assertEqual(self.watcher._sig_chld(), None)
            self.assertTrue(m_error.called)

    @waitpid_mocks
    def test_sigchld_child_reaped_elsewhere(self, m_waitpid):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(58, callback)

        self.assertFalse(callback.called)

        # child terminates
        self.running = False
        self.add_zombie(58, EXITCODE(4))

        # waitpid is called elsewhere
        os.waitpid(58, os.WNOHANG)

        m_waitpid.reset_mock()

        # sigchld
        with self.ignore_warnings:
            self.watcher._sig_chld()

        if isinstance(self.watcher, asyncio.FastChildWatcher):
            # here the FastChildWatcher enters a deadlock
            # (there is no way to prevent it)
            self.assertFalse(callback.called)
        else:
            # the real status was consumed elsewhere, so the watcher is
            # expected to report 255
            callback.assert_called_once_with(58, 255)

    @waitpid_mocks
    def test_sigchld_unknown_pid_during_registration(self, m_waitpid):
        # register two children
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        with self.ignore_warnings, self.watcher:
            self.running = True
            # child 1 terminates
            self.add_zombie(591, EXITCODE(7))
            # an unknown child terminates
            self.add_zombie(593, EXITCODE(17))

            self.watcher._sig_chld()

            self.watcher.add_child_handler(591, callback1)
            self.watcher.add_child_handler(592, callback2)

        callback1.assert_called_once_with(591, 7)
        self.assertFalse(callback2.called)

    @waitpid_mocks
    def test_set_loop(self, m_waitpid):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(60, callback)

        # attach a new loop
        old_loop = self.loop
        self.loop = self.new_test_loop()
        patch = mock.patch.object

        with patch(old_loop, "remove_signal_handler") as m_old_remove, \
             patch(self.loop, "add_signal_handler") as m_new_add:

            self.watcher.attach_loop(self.loop)

            m_old_remove.assert_called_once_with(
                signal.SIGCHLD)
            m_new_add.assert_called_once_with(
                signal.SIGCHLD, self.watcher._sig_chld)

        # child terminates
        self.running = False
        self.add_zombie(60, EXITCODE(9))
        self.watcher._sig_chld()

        callback.assert_called_once_with(60, 9)

    @waitpid_mocks
    def test_set_loop_race_condition(self, m_waitpid):
        # register 3 children
        callback1 = mock.Mock()
        callback2 = mock.Mock()
        callback3 = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(61, callback1)
            self.watcher.add_child_handler(62, callback2)
            self.watcher.add_child_handler(622, callback3)

        # detach the loop
        old_loop = self.loop
        self.loop = None

        with mock.patch.object(
                old_loop, "remove_signal_handler") as m_remove_signal_handler:

            with self.assertWarnsRegex(
                    RuntimeWarning, 'A loop is being detached'):
                self.watcher.attach_loop(None)

            m_remove_signal_handler.assert_called_once_with(
                signal.SIGCHLD)

        # child 1 & 2 terminate
        self.add_zombie(61, EXITCODE(11))
        self.add_zombie(62, SIGNAL(5))

        # SIGCHLD was not caught
        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self.assertFalse(callback3.called)

        # attach a new loop
        self.loop = self.new_test_loop()

        with mock.patch.object(
                self.loop, "add_signal_handler") as m_add_signal_handler:

            self.watcher.attach_loop(self.loop)

            m_add_signal_handler.assert_called_once_with(
                signal.SIGCHLD, self.watcher._sig_chld)
            callback1.assert_called_once_with(61, 11)  # race condition!
            callback2.assert_called_once_with(62, -5)  # race condition!
            self.assertFalse(callback3.called)

        callback1.reset_mock()
        callback2.reset_mock()

        # child 3 terminates
        self.running = False
        self.add_zombie(622, EXITCODE(19))
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        callback3.assert_called_once_with(622, 19)

    @waitpid_mocks
    def test_close(self, m_waitpid):
        # register two children
        callback1 = mock.Mock()

        with self.watcher:
            self.running = True
            # child 1 terminates
            self.add_zombie(63, EXITCODE(9))
            # other child terminates
            self.add_zombie(65, EXITCODE(18))
            self.watcher._sig_chld()

            self.watcher.add_child_handler(63, callback1)
            self.watcher.add_child_handler(64, callback1)

            self.assertEqual(len(self.watcher._callbacks), 1)
            if isinstance(self.watcher, asyncio.FastChildWatcher):
                self.assertEqual(len(self.watcher._zombies), 1)

            with mock.patch.object(
                    self.loop,
                    "remove_signal_handler") as m_remove_signal_handler:

                self.watcher.close()

                m_remove_signal_handler.assert_called_once_with(
                    signal.SIGCHLD)
                self.assertFalse(self.watcher._callbacks)
                if isinstance(self.watcher, asyncio.FastChildWatcher):
                    self.assertFalse(self.watcher._zombies)
1684
1685
class SafeChildWatcherTests(ChildWatcherTestsMixin, test_utils.TestCase):
    """Run the shared child watcher tests against SafeChildWatcher."""

    def create_watcher(self):
        """Return a new SafeChildWatcher for the mixin's setUp()."""
        watcher = asyncio.SafeChildWatcher()
        return watcher
1689
1690
class FastChildWatcherTests(ChildWatcherTestsMixin, test_utils.TestCase):
    """Run the shared child watcher tests against FastChildWatcher."""

    def create_watcher(self):
        """Return a new FastChildWatcher for the mixin's setUp()."""
        watcher = asyncio.FastChildWatcher()
        return watcher
1694
1695
class PolicyTests(unittest.TestCase):
    """Child-watcher handling of a private DefaultEventLoopPolicy."""

    def create_policy(self):
        return asyncio.DefaultEventLoopPolicy()

    def test_get_default_child_watcher(self):
        policy = self.create_policy()
        # No watcher exists until one is first requested.
        self.assertIsNone(policy._watcher)

        watcher = policy.get_child_watcher()
        self.assertIsInstance(watcher, asyncio.ThreadedChildWatcher)

        self.assertIs(policy._watcher, watcher)

        # Subsequent requests return the same cached watcher.
        self.assertIs(watcher, policy.get_child_watcher())

    def test_get_child_watcher_after_set(self):
        policy = self.create_policy()
        watcher = asyncio.FastChildWatcher()

        policy.set_child_watcher(watcher)
        self.assertIs(policy._watcher, watcher)
        self.assertIs(watcher, policy.get_child_watcher())

    def test_get_child_watcher_thread(self):
        # An assertion that fails inside a worker thread does not fail the
        # test by itself (it only reaches threading.excepthook), so the
        # worker records any exception and the main thread re-raises it.
        errors = []

        def f():
            try:
                loop = policy.new_event_loop()
                try:
                    policy.set_event_loop(loop)

                    self.assertIsInstance(policy.get_event_loop(),
                                          asyncio.AbstractEventLoop)
                    watcher = policy.get_child_watcher()

                    self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
                    # Setting a loop from a non-main thread must not attach
                    # the watcher to it.
                    self.assertIsNone(watcher._loop)
                finally:
                    loop.close()
            except BaseException as exc:
                errors.append(exc)

        policy = self.create_policy()
        policy.set_child_watcher(asyncio.SafeChildWatcher())

        th = threading.Thread(target=f)
        th.start()
        th.join()

        if errors:
            raise errors[0]

    def test_child_watcher_replace_mainloop_existing(self):
        policy = self.create_policy()
        loop = policy.get_event_loop()
        # Close the loop even if an assertion below fails.
        self.addCleanup(loop.close)

        # Explicitly setup SafeChildWatcher,
        # default ThreadedChildWatcher has no _loop property
        watcher = asyncio.SafeChildWatcher()
        policy.set_child_watcher(watcher)
        watcher.attach_loop(loop)

        self.assertIs(watcher._loop, loop)

        new_loop = policy.new_event_loop()
        self.addCleanup(new_loop.close)
        policy.set_event_loop(new_loop)

        # Replacing the main-thread loop re-attaches the watcher.
        self.assertIs(watcher._loop, new_loop)

        policy.set_event_loop(None)

        self.assertIs(watcher._loop, None)
1764
1765
class TestFunctional(unittest.TestCase):
    """Reader/writer registration checks against a real event loop."""

    def setUp(self):
        loop = asyncio.new_event_loop()
        self.loop = loop
        asyncio.set_event_loop(loop)

    def tearDown(self):
        self.loop.close()
        asyncio.set_event_loop(None)

    def test_add_reader_invalid_argument(self):
        # Objects without a usable file descriptor must be rejected by every
        # add/remove reader/writer entry point.
        def noop():
            return None

        loop = self.loop
        attempts = (
            (loop.add_reader, True),
            (loop.add_writer, True),
            (loop.remove_reader, False),
            (loop.remove_writer, False),
        )
        for method, takes_callback in attempts:
            args = (object(), noop) if takes_callback else (object(),)
            with self.assertRaisesRegex(ValueError, r'Invalid file object'):
                method(*args)

    def test_add_reader_or_writer_transport_fd(self):
        # While a transport owns a socket, raw reader/writer registration on
        # it (by socket object or by fd number) must be refused.
        rsock, wsock = socket.socketpair()

        async def runner():
            transport, _protocol = await self.loop.create_connection(
                asyncio.Protocol, sock=rsock)
            try:
                def noop():
                    return None

                loop = self.loop
                fd = rsock.fileno()
                operations = (
                    (loop.add_reader, True),
                    (loop.remove_reader, False),
                    (loop.add_writer, True),
                    (loop.remove_writer, False),
                )
                for method, takes_callback in operations:
                    for target in (rsock, fd):
                        if takes_callback:
                            args = (target, noop)
                        else:
                            args = (target,)
                        with self.assertRaisesRegex(
                                RuntimeError,
                                r'File descriptor .* is used by transport'):
                            method(*args)
            finally:
                transport.close()

        try:
            self.loop.run_until_complete(runner())
        finally:
            rsock.close()
            wsock.close()
1834
1835
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
1838